read chain pipeline from JSON, remove new endpoint
This commit is contained in:
parent
1a732d54b6
commit
1fb965633e
|
@ -1,5 +1,6 @@
|
|||
from .base import ChainPipeline, PipelineStage, StageParams
|
||||
from .blend_img2img import BlendImg2ImgStage
|
||||
from .blend_grid import BlendGridStage
|
||||
from .blend_linear import BlendLinearStage
|
||||
from .blend_mask import BlendMaskStage
|
||||
from .correct_codeformer import CorrectCodeformerStage
|
||||
|
@ -23,6 +24,7 @@ from .upscale_swinir import UpscaleSwinIRStage
|
|||
CHAIN_STAGES = {
|
||||
"blend-img2img": BlendImg2ImgStage,
|
||||
"blend-inpaint": UpscaleOutpaintStage,
|
||||
"blend-grid": BlendGridStage,
|
||||
"blend-linear": BlendLinearStage,
|
||||
"blend-mask": BlendMaskStage,
|
||||
"correct-codeformer": CorrectCodeformerStage,
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .stage import BaseStage
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class BlendGridStage(BaseStage):
|
||||
def run(
|
||||
self,
|
||||
_worker: WorkerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
height: int,
|
||||
width: int,
|
||||
rows: Optional[List[str]] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
title: Optional[str] = None,
|
||||
order: Optional[int] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
_callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
logger.info("combining source images using grid layout")
|
||||
|
||||
size = sources[0].size
|
||||
|
||||
output = Image.new("RGB", (size[0] * width, size[1] * height))
|
||||
|
||||
# TODO: labels
|
||||
for i in order or range(len(sources)):
|
||||
x = i % width
|
||||
y = i / width
|
||||
|
||||
output.paste(sources[i], (x * size[0], y * size[1]))
|
||||
|
||||
return [output]
|
||||
|
|
@ -28,11 +28,13 @@ class SourceNoiseStage(BaseStage):
|
|||
logger.info("generating image from noise source")
|
||||
|
||||
if len(sources) > 0:
|
||||
logger.warning(
|
||||
"source images were passed to a noise stage and will be discarded"
|
||||
logger.info(
|
||||
"source images were passed to a source stage, new images will be appended"
|
||||
)
|
||||
|
||||
outputs = []
|
||||
outputs = list(sources)
|
||||
|
||||
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
||||
for source in sources:
|
||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ class SourceS3Stage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
_sources: List[Image.Image],
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
source_keys: List[str],
|
||||
bucket: str,
|
||||
|
@ -31,7 +31,12 @@ class SourceS3Stage(BaseStage):
|
|||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||
|
||||
outputs = []
|
||||
if len(sources) > 0:
|
||||
logger.info(
|
||||
"source images were passed to a source stage, new images will be appended"
|
||||
)
|
||||
|
||||
outputs = list(sources)
|
||||
for key in source_keys:
|
||||
try:
|
||||
logger.info("loading image from s3://%s/%s", bucket, key)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -30,7 +30,7 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
_source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
dims: Tuple[int, int, int],
|
||||
size: Size,
|
||||
|
@ -50,9 +50,9 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
"generating image using txt2img, %s steps: %s", params.steps, params.prompt
|
||||
)
|
||||
|
||||
if "stage_source" in kwargs:
|
||||
logger.warning(
|
||||
"a source image was passed to a txt2img stage, and will be discarded"
|
||||
if len(sources):
|
||||
logger.info(
|
||||
"source images were passed to a source stage, new images will be appended"
|
||||
)
|
||||
|
||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
||||
|
@ -123,4 +123,6 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
callback=callback,
|
||||
)
|
||||
|
||||
return result.images
|
||||
output = list(sources)
|
||||
output.extend(result.images)
|
||||
return output
|
||||
|
|
|
@ -29,11 +29,11 @@ class SourceURLStage(BaseStage):
|
|||
logger.info("loading image from URL source")
|
||||
|
||||
if len(sources) > 0:
|
||||
logger.warning(
|
||||
"a source image was passed to a source stage, and will be discarded"
|
||||
logger.info(
|
||||
"source images were passed to a source stage, new images will be appended"
|
||||
)
|
||||
|
||||
outputs = []
|
||||
outputs = list(sources)
|
||||
for url in source_urls:
|
||||
response = requests.get(url)
|
||||
output = Image.open(BytesIO(response.content))
|
||||
|
|
|
@ -368,16 +368,21 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
|
||||
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||
logger.debug(
|
||||
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||
)
|
||||
body = request.form.get("chain") or request.files.get("chain")
|
||||
if body is None:
|
||||
return error_reply("chain pipeline must have a body")
|
||||
if request.is_json():
|
||||
logger.debug("chain pipeline request with JSON body")
|
||||
data = request.get_json()
|
||||
else:
|
||||
logger.debug(
|
||||
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||
)
|
||||
|
||||
body = request.form.get("chain") or request.files.get("chain")
|
||||
if body is None:
|
||||
return error_reply("chain pipeline must have a body")
|
||||
|
||||
data = load_config_str(body)
|
||||
|
||||
data = load_config_str(body)
|
||||
schema = load_config("./schemas/chain.yaml")
|
||||
|
||||
logger.debug("validating chain request: %s against %s", data, schema)
|
||||
validate(data, schema)
|
||||
|
||||
|
@ -515,61 +520,6 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
||||
def generate(server: ServerContext, pool: DevicePoolExecutor):
|
||||
if not request.is_json():
|
||||
return error_reply("generate endpoint requires JSON parameters")
|
||||
|
||||
# TODO: should this accept YAML as well?
|
||||
data = request.get_json()
|
||||
schema = load_config("./schemas/generate.yaml")
|
||||
|
||||
logger.debug("validating generate request: %s against %s", data, schema)
|
||||
validate(data, schema)
|
||||
|
||||
jobs = []
|
||||
|
||||
if "txt2img" in data:
|
||||
for job in data.get("txt2img"):
|
||||
device, params, size = pipeline_from_json(server, job, "txt2img")
|
||||
jobs.append((
|
||||
f"generate-txt2img-{len(jobs)}",
|
||||
run_txt2img_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
make_output_name(server, "txt2img", params, size, offset=len(jobs)),
|
||||
None,
|
||||
None,
|
||||
device,
|
||||
))
|
||||
|
||||
if "img2img" in data:
|
||||
for job in data.get("img2img"):
|
||||
device, params, size = pipeline_from_json(server, job, "img2img")
|
||||
jobs.append((
|
||||
f"generate-img2img-{len(jobs)}",
|
||||
run_img2img_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
make_output_name(server, "img2img", params, size, offset=len(jobs))
|
||||
None,
|
||||
None,
|
||||
device,
|
||||
))
|
||||
|
||||
for job in jobs:
|
||||
pool.submit(*job)
|
||||
|
||||
# TODO: collect results
|
||||
# this is the hard part. once all of the jobs are done, the last job or some dedicated job
|
||||
# needs to collect the previous outputs and put them on a grid. jobs write their own
|
||||
# output to disk and do not return it, so that may need to read the images based on the
|
||||
# output names assigned to each job. knowing when the jobs are done is the first problem.
|
||||
|
||||
# TODO: assemble grid
|
||||
|
||||
|
||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
output_file = request.args.get("output", None)
|
||||
if output_file is None:
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
$id: TODO
|
||||
$schema: https://json-schema.org/draft/2020-12/schema
|
||||
|
||||
$defs:
|
||||
grid:
|
||||
type: object
|
||||
additionalProperties: False
|
||||
required: [width, height]
|
||||
width:
|
||||
type: number
|
||||
height:
|
||||
type: number
|
||||
labels:
|
||||
type: object
|
||||
additionalProperties: False
|
||||
properties:
|
||||
title:
|
||||
type: string
|
||||
rows:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
order:
|
||||
type: array
|
||||
items: number
|
||||
|
||||
job_base:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
required: [
|
||||
device,
|
||||
model,
|
||||
pipeline,
|
||||
scheduler,
|
||||
prompt,
|
||||
cfg,
|
||||
steps,
|
||||
seed,
|
||||
]
|
||||
properties:
|
||||
batch:
|
||||
type: number
|
||||
device:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
control:
|
||||
type: string
|
||||
pipeline:
|
||||
type: string
|
||||
scheduler:
|
||||
type: string
|
||||
prompt:
|
||||
type: string
|
||||
negative_prompt:
|
||||
type: string
|
||||
cfg:
|
||||
type: number
|
||||
eta:
|
||||
type: number
|
||||
steps:
|
||||
type: number
|
||||
tiled_vae:
|
||||
type: boolean
|
||||
tiles:
|
||||
type: number
|
||||
overlap:
|
||||
type: number
|
||||
seed:
|
||||
type: number
|
||||
stride:
|
||||
type: number
|
||||
|
||||
job_txt2img:
|
||||
allOf:
|
||||
- $ref: "#/$defs/job_base"
|
||||
- type: object
|
||||
additionalProperties: False
|
||||
required: [
|
||||
height,
|
||||
width,
|
||||
]
|
||||
properties:
|
||||
width:
|
||||
type: number
|
||||
height:
|
||||
type: number
|
||||
|
||||
job_img2img:
|
||||
allOf:
|
||||
- $ref: "#/$defs/job_base"
|
||||
- type: object
|
||||
additionalProperties: False
|
||||
required: []
|
||||
properties:
|
||||
loopback:
|
||||
type: number
|
||||
|
||||
type: object
|
||||
additionalProperties: False
|
||||
properties:
|
||||
txt2img:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/job_txt2img"
|
||||
img2img:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/job_img2img"
|
||||
grid:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/grid"
|
Loading…
Reference in New Issue