1
0
Fork 0

read chain pipeline from JSON, remove new endpoint

This commit is contained in:
Sean Sube 2023-09-10 20:59:13 -05:00
parent 1a732d54b6
commit 1fb965633e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 202 additions and 77 deletions

View File

@@ -1,5 +1,6 @@
from .base import ChainPipeline, PipelineStage, StageParams from .base import ChainPipeline, PipelineStage, StageParams
from .blend_img2img import BlendImg2ImgStage from .blend_img2img import BlendImg2ImgStage
from .blend_grid import BlendGridStage
from .blend_linear import BlendLinearStage from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage from .correct_codeformer import CorrectCodeformerStage
@ -23,6 +24,7 @@ from .upscale_swinir import UpscaleSwinIRStage
CHAIN_STAGES = { CHAIN_STAGES = {
"blend-img2img": BlendImg2ImgStage, "blend-img2img": BlendImg2ImgStage,
"blend-inpaint": UpscaleOutpaintStage, "blend-inpaint": UpscaleOutpaintStage,
"blend-grid": BlendGridStage,
"blend-linear": BlendLinearStage, "blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage, "blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage, "correct-codeformer": CorrectCodeformerStage,

View File

@@ -0,0 +1,47 @@
from logging import getLogger
from typing import List, Optional
from PIL import Image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
logger = getLogger(__name__)
class BlendGridStage(BaseStage):
    """Combine a list of source images into a single grid image.

    The grid is `width` columns by `height` rows; every cell is assumed to be
    the size of the first source image. Sources are placed row-major, or in
    the explicit `order` if one is given.
    """

    def run(
        self,
        _worker: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: List[Image.Image],
        *,
        height: int,
        width: int,
        rows: Optional[List[str]] = None,
        columns: Optional[List[str]] = None,
        title: Optional[str] = None,
        # NOTE: iterated as a sequence of indices, so the correct annotation
        # is a list of ints, not a single int
        order: Optional[List[int]] = None,
        stage_source: Optional[Image.Image] = None,
        _callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> List[Image.Image]:
        """Return a single-element list containing the assembled grid.

        Parameters
        ----------
        sources: images to lay out; all cells use the size of sources[0].
        height, width: grid dimensions, in cells.
        rows, columns, title: label text — not rendered yet (see TODO).
        order: optional explicit ordering of source indices; each index is
            used both to pick the source and as the cell position, matching
            the default row-major behavior.
        """
        logger.info("combining source images using grid layout")

        size = sources[0].size
        output = Image.new("RGB", (size[0] * width, size[1] * height))

        # TODO: render rows/columns/title labels
        for i in order or range(len(sources)):
            x = i % width
            # floor division: Image.paste requires integer coordinates;
            # true division (`/`) yields a float and raises TypeError on
            # any grid with more than one row
            y = i // width
            output.paste(sources[i], (x * size[0], y * size[1]))

        return [output]

View File

@ -28,11 +28,13 @@ class SourceNoiseStage(BaseStage):
logger.info("generating image from noise source") logger.info("generating image from noise source")
if len(sources) > 0: if len(sources) > 0:
logger.warning( logger.info(
"source images were passed to a noise stage and will be discarded" "source images were passed to a source stage, new images will be appended"
) )
outputs = [] outputs = list(sources)
# TODO: looping over sources and ignoring params does not make much sense for a source stage
for source in sources: for source in sources:
output = noise_source(source, (size.width, size.height), (0, 0)) output = noise_source(source, (size.width, size.height), (0, 0))

View File

@ -20,7 +20,7 @@ class SourceS3Stage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
_sources: List[Image.Image], sources: List[Image.Image],
*, *,
source_keys: List[str], source_keys: List[str],
bucket: str, bucket: str,
@ -31,7 +31,12 @@ class SourceS3Stage(BaseStage):
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
outputs = [] if len(sources) > 0:
logger.info(
"source images were passed to a source stage, new images will be appended"
)
outputs = list(sources)
for key in source_keys: for key in source_keys:
try: try:
logger.info("loading image from s3://%s/%s", bucket, key) logger.info("loading image from s3://%s/%s", bucket, key)

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional, Tuple from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -30,7 +30,7 @@ class SourceTxt2ImgStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
_source: Image.Image, sources: List[Image.Image],
*, *,
dims: Tuple[int, int, int], dims: Tuple[int, int, int],
size: Size, size: Size,
@ -50,9 +50,9 @@ class SourceTxt2ImgStage(BaseStage):
"generating image using txt2img, %s steps: %s", params.steps, params.prompt "generating image using txt2img, %s steps: %s", params.steps, params.prompt
) )
if "stage_source" in kwargs: if len(sources):
logger.warning( logger.info(
"a source image was passed to a txt2img stage, and will be discarded" "source images were passed to a source stage, new images will be appended"
) )
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
@ -123,4 +123,6 @@ class SourceTxt2ImgStage(BaseStage):
callback=callback, callback=callback,
) )
return result.images output = list(sources)
output.extend(result.images)
return output

View File

@ -29,11 +29,11 @@ class SourceURLStage(BaseStage):
logger.info("loading image from URL source") logger.info("loading image from URL source")
if len(sources) > 0: if len(sources) > 0:
logger.warning( logger.info(
"a source image was passed to a source stage, and will be discarded" "source images were passed to a source stage, new images will be appended"
) )
outputs = [] outputs = list(sources)
for url in source_urls: for url in source_urls:
response = requests.get(url) response = requests.get(url)
output = Image.open(BytesIO(response.content)) output = Image.open(BytesIO(response.content))

View File

@ -368,16 +368,21 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
def chain(server: ServerContext, pool: DevicePoolExecutor): def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.debug( if request.is_json():
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys() logger.debug("chain pipeline request with JSON body")
) data = request.get_json()
body = request.form.get("chain") or request.files.get("chain") else:
if body is None: logger.debug(
return error_reply("chain pipeline must have a body") "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
)
body = request.form.get("chain") or request.files.get("chain")
if body is None:
return error_reply("chain pipeline must have a body")
data = load_config_str(body)
data = load_config_str(body)
schema = load_config("./schemas/chain.yaml") schema = load_config("./schemas/chain.yaml")
logger.debug("validating chain request: %s against %s", data, schema) logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema) validate(data, schema)
@ -515,61 +520,6 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size)) return jsonify(json_params(output, params, size))
def generate(server: ServerContext, pool: DevicePoolExecutor):
if not request.is_json():
return error_reply("generate endpoint requires JSON parameters")
# TODO: should this accept YAML as well?
data = request.get_json()
schema = load_config("./schemas/generate.yaml")
logger.debug("validating generate request: %s against %s", data, schema)
validate(data, schema)
jobs = []
if "txt2img" in data:
for job in data.get("txt2img"):
device, params, size = pipeline_from_json(server, job, "txt2img")
jobs.append((
f"generate-txt2img-{len(jobs)}",
run_txt2img_pipeline,
server,
params,
size,
make_output_name(server, "txt2img", params, size, offset=len(jobs)),
None,
None,
device,
))
if "img2img" in data:
for job in data.get("img2img"):
device, params, size = pipeline_from_json(server, job, "img2img")
jobs.append((
f"generate-img2img-{len(jobs)}",
run_img2img_pipeline,
server,
params,
size,
make_output_name(server, "img2img", params, size, offset=len(jobs))
None,
None,
device,
))
for job in jobs:
pool.submit(*job)
# TODO: collect results
# this is the hard part. once all of the jobs are done, the last job or some dedicated job
# needs to collect the previous outputs and put them on a grid. jobs write their own
# output to disk and do not return it, so that may need to read the images based on the
# output names assigned to each job. knowing when the jobs are done is the first problem.
# TODO: assemble grid
def cancel(server: ServerContext, pool: DevicePoolExecutor): def cancel(server: ServerContext, pool: DevicePoolExecutor):
output_file = request.args.get("output", None) output_file = request.args.get("output", None)
if output_file is None: if output_file is None:

117
api/schemas/generate.yaml Normal file
View File

@@ -0,0 +1,117 @@
# JSON Schema for the /generate batch endpoint request body.
# TODO: assign a real $id once the schema URL is decided
$id: TODO
$schema: https://json-schema.org/draft/2020-12/schema

$defs:
  # layout parameters for assembling job outputs into a grid
  grid:
    type: object
    additionalProperties: False
    required: [width, height]
    # NOTE: width/height/labels/order must live under `properties`;
    # placing them directly on the object schema makes them unknown
    # keywords, and with additionalProperties: False every instance
    # would then fail validation
    properties:
      width:
        type: number
      height:
        type: number
      labels:
        type: object
        additionalProperties: False
        properties:
          title:
            type: string
          rows:
            type: array
            items:
              type: string
          columns:
            type: array
            items:
              type: string
      order:
        type: array
        # `items` takes a subschema, not a bare type name
        items:
          type: number

  # fields shared by all generation jobs
  job_base:
    type: object
    additionalProperties: true
    required: [
      device,
      model,
      pipeline,
      scheduler,
      prompt,
      cfg,
      steps,
      seed,
    ]
    properties:
      batch:
        type: number
      device:
        type: string
      model:
        type: string
      control:
        type: string
      pipeline:
        type: string
      scheduler:
        type: string
      prompt:
        type: string
      negative_prompt:
        type: string
      cfg:
        type: number
      eta:
        type: number
      steps:
        type: number
      tiled_vae:
        type: boolean
      tiles:
        type: number
      overlap:
        type: number
      seed:
        type: number
      stride:
        type: number

  # txt2img jobs must also specify the output size
  job_txt2img:
    allOf:
      - $ref: "#/$defs/job_base"
      - type: object
        additionalProperties: False
        required: [
          height,
          width,
        ]
        properties:
          width:
            type: number
          height:
            type: number

  # img2img jobs take their size from the input image
  job_img2img:
    allOf:
      - $ref: "#/$defs/job_base"
      - type: object
        additionalProperties: False
        required: []
        properties:
          loopback:
            type: number

type: object
additionalProperties: False
properties:
  txt2img:
    type: array
    items:
      $ref: "#/$defs/job_txt2img"
  img2img:
    type: array
    items:
      $ref: "#/$defs/job_img2img"
  grid:
    type: array
    items:
      $ref: "#/$defs/grid"