1
0
Fork 0

skip controlnet in chains for now, remove empty source image

This commit is contained in:
Sean Sube 2023-09-10 21:17:09 -05:00
parent 9d4272eb09
commit 93fe54577c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 12 additions and 6 deletions

View File

@ -69,9 +69,9 @@ class SourceTxt2ImgStage(BaseStage):
# generate new latents or slice existing # generate new latents or slice existing
if latents is None: if latents is None:
latents = get_latents_from_seed(params.seed, latent_size, params.batch) latents = get_latents_from_seed(int(params.seed), latent_size, params.batch)
else: else:
latents = get_tile_latents(latents, params.seed, latent_size, dims) latents = get_tile_latents(latents, int(params.seed), latent_size, dims)
pipe_type = params.get_valid_pipeline("txt2img") pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline( pipe = load_pipeline(

View File

@ -1,6 +1,7 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Any, Dict
from flask import Flask, jsonify, make_response, request, url_for from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate from jsonschema import validate
@ -368,7 +369,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
def chain(server: ServerContext, pool: DevicePoolExecutor): def chain(server: ServerContext, pool: DevicePoolExecutor):
if request.is_json(): if request.is_json:
logger.debug("chain pipeline request with JSON body") logger.debug("chain pipeline request with JSON body")
data = request.get_json() data = request.get_json()
else: else:
@ -396,9 +397,13 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
pipeline = ChainPipeline() pipeline = ChainPipeline()
for stage_data in data.get("stages", []): for stage_data in data.get("stages", []):
stage_class = CHAIN_STAGES[stage_data.get("type")] stage_class = CHAIN_STAGES[stage_data.get("type")]
kwargs = stage_data.get("params", {}) kwargs: Dict[str, Any] = stage_data.get("params", {})
logger.info("request stage: %s, %s", stage_class.__name__, kwargs) logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
if "control" in kwargs:
logger.warning("TODO: resolve controlnet model")
kwargs.pop("control")
stage = StageParams( stage = StageParams(
stage_data.get("name", stage_class.__name__), stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")), tile_size=get_size(kwargs.get("tile_size")),
@ -443,13 +448,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("running chain pipeline with %s stages", len(pipeline.stages)) logger.info("running chain pipeline with %s stages", len(pipeline.stages))
# build and run chain pipeline # build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height))
pool.submit( pool.submit(
job_name, job_name,
pipeline, pipeline,
server, server,
params, params,
empty_source, [],
output=output[0], output=output[0],
size=size, size=size,
needs_device=device, needs_device=device,

View File

@ -46,8 +46,10 @@ $defs:
patternProperties: patternProperties:
"^[-_A-Za-z]+$": "^[-_A-Za-z]+$":
oneOf: oneOf:
- type: boolean
- type: number - type: number
- type: string - type: string
- type: "null"
request_chain: request_chain:
type: array type: array