read chain pipeline from JSON, remove new endpoint
This commit is contained in:
parent
1a732d54b6
commit
1fb965633e
|
@ -1,5 +1,6 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageParams
|
from .base import ChainPipeline, PipelineStage, StageParams
|
||||||
from .blend_img2img import BlendImg2ImgStage
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
|
from .blend_grid import BlendGridStage
|
||||||
from .blend_linear import BlendLinearStage
|
from .blend_linear import BlendLinearStage
|
||||||
from .blend_mask import BlendMaskStage
|
from .blend_mask import BlendMaskStage
|
||||||
from .correct_codeformer import CorrectCodeformerStage
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
|
@ -23,6 +24,7 @@ from .upscale_swinir import UpscaleSwinIRStage
|
||||||
CHAIN_STAGES = {
|
CHAIN_STAGES = {
|
||||||
"blend-img2img": BlendImg2ImgStage,
|
"blend-img2img": BlendImg2ImgStage,
|
||||||
"blend-inpaint": UpscaleOutpaintStage,
|
"blend-inpaint": UpscaleOutpaintStage,
|
||||||
|
"blend-grid": BlendGridStage,
|
||||||
"blend-linear": BlendLinearStage,
|
"blend-linear": BlendLinearStage,
|
||||||
"blend-mask": BlendMaskStage,
|
"blend-mask": BlendMaskStage,
|
||||||
"correct-codeformer": CorrectCodeformerStage,
|
"correct-codeformer": CorrectCodeformerStage,
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import ImageParams, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .stage import BaseStage
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BlendGridStage(BaseStage):
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
_worker: WorkerContext,
|
||||||
|
_server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
sources: List[Image.Image],
|
||||||
|
*,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
rows: Optional[List[str]] = None,
|
||||||
|
columns: Optional[List[str]] = None,
|
||||||
|
title: Optional[str] = None,
|
||||||
|
order: Optional[int] = None,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
_callback: Optional[ProgressCallback] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
logger.info("combining source images using grid layout")
|
||||||
|
|
||||||
|
size = sources[0].size
|
||||||
|
|
||||||
|
output = Image.new("RGB", (size[0] * width, size[1] * height))
|
||||||
|
|
||||||
|
# TODO: labels
|
||||||
|
for i in order or range(len(sources)):
|
||||||
|
x = i % width
|
||||||
|
y = i / width
|
||||||
|
|
||||||
|
output.paste(sources[i], (x * size[0], y * size[1]))
|
||||||
|
|
||||||
|
return [output]
|
||||||
|
|
|
@ -28,11 +28,13 @@ class SourceNoiseStage(BaseStage):
|
||||||
logger.info("generating image from noise source")
|
logger.info("generating image from noise source")
|
||||||
|
|
||||||
if len(sources) > 0:
|
if len(sources) > 0:
|
||||||
logger.warning(
|
logger.info(
|
||||||
"source images were passed to a noise stage and will be discarded"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = list(sources)
|
||||||
|
|
||||||
|
# TODO: looping over sources and ignoring params does not make much sense for a source stage
|
||||||
for source in sources:
|
for source in sources:
|
||||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class SourceS3Stage(BaseStage):
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
_sources: List[Image.Image],
|
sources: List[Image.Image],
|
||||||
*,
|
*,
|
||||||
source_keys: List[str],
|
source_keys: List[str],
|
||||||
bucket: str,
|
bucket: str,
|
||||||
|
@ -31,7 +31,12 @@ class SourceS3Stage(BaseStage):
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
outputs = []
|
if len(sources) > 0:
|
||||||
|
logger.info(
|
||||||
|
"source images were passed to a source stage, new images will be appended"
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = list(sources)
|
||||||
for key in source_keys:
|
for key in source_keys:
|
||||||
try:
|
try:
|
||||||
logger.info("loading image from s3://%s/%s", bucket, key)
|
logger.info("loading image from s3://%s/%s", bucket, key)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -30,7 +30,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
_source: Image.Image,
|
sources: List[Image.Image],
|
||||||
*,
|
*,
|
||||||
dims: Tuple[int, int, int],
|
dims: Tuple[int, int, int],
|
||||||
size: Size,
|
size: Size,
|
||||||
|
@ -50,9 +50,9 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
"generating image using txt2img, %s steps: %s", params.steps, params.prompt
|
"generating image using txt2img, %s steps: %s", params.steps, params.prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
if "stage_source" in kwargs:
|
if len(sources):
|
||||||
logger.warning(
|
logger.info(
|
||||||
"a source image was passed to a txt2img stage, and will be discarded"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
||||||
|
@ -123,4 +123,6 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result.images
|
output = list(sources)
|
||||||
|
output.extend(result.images)
|
||||||
|
return output
|
||||||
|
|
|
@ -29,11 +29,11 @@ class SourceURLStage(BaseStage):
|
||||||
logger.info("loading image from URL source")
|
logger.info("loading image from URL source")
|
||||||
|
|
||||||
if len(sources) > 0:
|
if len(sources) > 0:
|
||||||
logger.warning(
|
logger.info(
|
||||||
"a source image was passed to a source stage, and will be discarded"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = []
|
outputs = list(sources)
|
||||||
for url in source_urls:
|
for url in source_urls:
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
output = Image.open(BytesIO(response.content))
|
output = Image.open(BytesIO(response.content))
|
||||||
|
|
|
@ -368,16 +368,21 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
|
|
||||||
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
logger.debug(
|
if request.is_json():
|
||||||
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
logger.debug("chain pipeline request with JSON body")
|
||||||
)
|
data = request.get_json()
|
||||||
body = request.form.get("chain") or request.files.get("chain")
|
else:
|
||||||
if body is None:
|
logger.debug(
|
||||||
return error_reply("chain pipeline must have a body")
|
"chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
body = request.form.get("chain") or request.files.get("chain")
|
||||||
|
if body is None:
|
||||||
|
return error_reply("chain pipeline must have a body")
|
||||||
|
|
||||||
|
data = load_config_str(body)
|
||||||
|
|
||||||
data = load_config_str(body)
|
|
||||||
schema = load_config("./schemas/chain.yaml")
|
schema = load_config("./schemas/chain.yaml")
|
||||||
|
|
||||||
logger.debug("validating chain request: %s against %s", data, schema)
|
logger.debug("validating chain request: %s against %s", data, schema)
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
|
@ -515,61 +520,6 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size))
|
return jsonify(json_params(output, params, size))
|
||||||
|
|
||||||
|
|
||||||
def generate(server: ServerContext, pool: DevicePoolExecutor):
|
|
||||||
if not request.is_json():
|
|
||||||
return error_reply("generate endpoint requires JSON parameters")
|
|
||||||
|
|
||||||
# TODO: should this accept YAML as well?
|
|
||||||
data = request.get_json()
|
|
||||||
schema = load_config("./schemas/generate.yaml")
|
|
||||||
|
|
||||||
logger.debug("validating generate request: %s against %s", data, schema)
|
|
||||||
validate(data, schema)
|
|
||||||
|
|
||||||
jobs = []
|
|
||||||
|
|
||||||
if "txt2img" in data:
|
|
||||||
for job in data.get("txt2img"):
|
|
||||||
device, params, size = pipeline_from_json(server, job, "txt2img")
|
|
||||||
jobs.append((
|
|
||||||
f"generate-txt2img-{len(jobs)}",
|
|
||||||
run_txt2img_pipeline,
|
|
||||||
server,
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
make_output_name(server, "txt2img", params, size, offset=len(jobs)),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
device,
|
|
||||||
))
|
|
||||||
|
|
||||||
if "img2img" in data:
|
|
||||||
for job in data.get("img2img"):
|
|
||||||
device, params, size = pipeline_from_json(server, job, "img2img")
|
|
||||||
jobs.append((
|
|
||||||
f"generate-img2img-{len(jobs)}",
|
|
||||||
run_img2img_pipeline,
|
|
||||||
server,
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
make_output_name(server, "img2img", params, size, offset=len(jobs))
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
device,
|
|
||||||
))
|
|
||||||
|
|
||||||
for job in jobs:
|
|
||||||
pool.submit(*job)
|
|
||||||
|
|
||||||
# TODO: collect results
|
|
||||||
# this is the hard part. once all of the jobs are done, the last job or some dedicated job
|
|
||||||
# needs to collect the previous outputs and put them on a grid. jobs write their own
|
|
||||||
# output to disk and do not return it, so that may need to read the images based on the
|
|
||||||
# output names assigned to each job. knowing when the jobs are done is the first problem.
|
|
||||||
|
|
||||||
# TODO: assemble grid
|
|
||||||
|
|
||||||
|
|
||||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
output_file = request.args.get("output", None)
|
output_file = request.args.get("output", None)
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
$id: TODO
|
||||||
|
$schema: https://json-schema.org/draft/2020-12/schema
|
||||||
|
|
||||||
|
$defs:
|
||||||
|
grid:
|
||||||
|
type: object
|
||||||
|
additionalProperties: False
|
||||||
|
required: [width, height]
|
||||||
|
width:
|
||||||
|
type: number
|
||||||
|
height:
|
||||||
|
type: number
|
||||||
|
labels:
|
||||||
|
type: object
|
||||||
|
additionalProperties: False
|
||||||
|
properties:
|
||||||
|
title:
|
||||||
|
type: string
|
||||||
|
rows:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
columns:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
order:
|
||||||
|
type: array
|
||||||
|
items: number
|
||||||
|
|
||||||
|
job_base:
|
||||||
|
type: object
|
||||||
|
additionalProperties: true
|
||||||
|
required: [
|
||||||
|
device,
|
||||||
|
model,
|
||||||
|
pipeline,
|
||||||
|
scheduler,
|
||||||
|
prompt,
|
||||||
|
cfg,
|
||||||
|
steps,
|
||||||
|
seed,
|
||||||
|
]
|
||||||
|
properties:
|
||||||
|
batch:
|
||||||
|
type: number
|
||||||
|
device:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
control:
|
||||||
|
type: string
|
||||||
|
pipeline:
|
||||||
|
type: string
|
||||||
|
scheduler:
|
||||||
|
type: string
|
||||||
|
prompt:
|
||||||
|
type: string
|
||||||
|
negative_prompt:
|
||||||
|
type: string
|
||||||
|
cfg:
|
||||||
|
type: number
|
||||||
|
eta:
|
||||||
|
type: number
|
||||||
|
steps:
|
||||||
|
type: number
|
||||||
|
tiled_vae:
|
||||||
|
type: boolean
|
||||||
|
tiles:
|
||||||
|
type: number
|
||||||
|
overlap:
|
||||||
|
type: number
|
||||||
|
seed:
|
||||||
|
type: number
|
||||||
|
stride:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
job_txt2img:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/$defs/job_base"
|
||||||
|
- type: object
|
||||||
|
additionalProperties: False
|
||||||
|
required: [
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
]
|
||||||
|
properties:
|
||||||
|
width:
|
||||||
|
type: number
|
||||||
|
height:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
job_img2img:
|
||||||
|
allOf:
|
||||||
|
- $ref: "#/$defs/job_base"
|
||||||
|
- type: object
|
||||||
|
additionalProperties: False
|
||||||
|
required: []
|
||||||
|
properties:
|
||||||
|
loopback:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
type: object
|
||||||
|
additionalProperties: False
|
||||||
|
properties:
|
||||||
|
txt2img:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/$defs/job_txt2img"
|
||||||
|
img2img:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/$defs/job_img2img"
|
||||||
|
grid:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/$defs/grid"
|
Loading…
Reference in New Issue