From 93fe54577ca22456f7bec6aaa9e53691e7fa6d1a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Sep 2023 21:17:09 -0500 Subject: [PATCH] skip controlnet in chains for now, remove empty source image --- api/onnx_web/chain/source_txt2img.py | 4 ++-- api/onnx_web/server/api.py | 12 ++++++++---- api/schemas/chain.yaml | 2 ++ 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 82d9aebe..c69443a1 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -69,9 +69,9 @@ class SourceTxt2ImgStage(BaseStage): # generate new latents or slice existing if latents is None: - latents = get_latents_from_seed(params.seed, latent_size, params.batch) + latents = get_latents_from_seed(int(params.seed), latent_size, params.batch) else: - latents = get_tile_latents(latents, params.seed, latent_size, dims) + latents = get_tile_latents(latents, int(params.seed), latent_size, dims) pipe_type = params.get_valid_pipeline("txt2img") pipe = load_pipeline( diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index b97ab36a..ae33901a 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -1,6 +1,7 @@ from io import BytesIO from logging import getLogger from os import path +from typing import Any, Dict from flask import Flask, jsonify, make_response, request, url_for from jsonschema import validate @@ -368,7 +369,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): def chain(server: ServerContext, pool: DevicePoolExecutor): - if request.is_json(): + if request.is_json: logger.debug("chain pipeline request with JSON body") data = request.get_json() else: @@ -396,9 +397,13 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): pipeline = ChainPipeline() for stage_data in data.get("stages", []): stage_class = CHAIN_STAGES[stage_data.get("type")] - kwargs = stage_data.get("params", {}) + kwargs: Dict[str, Any] = stage_data.get("params", {}) logger.info("request stage: %s, %s", stage_class.__name__, kwargs) + if "control" in kwargs: + logger.warning("TODO: resolve controlnet model") + kwargs.pop("control") + stage = StageParams( stage_data.get("name", stage_class.__name__), tile_size=get_size(kwargs.get("tile_size")), @@ -443,13 +448,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): logger.info("running chain pipeline with %s stages", len(pipeline.stages)) # build and run chain pipeline - empty_source = Image.new("RGB", (size.width, size.height)) pool.submit( job_name, pipeline, server, params, - empty_source, + [], output=output[0], size=size, needs_device=device, diff --git a/api/schemas/chain.yaml b/api/schemas/chain.yaml index 1b41c503..96d0c73b 100644 --- a/api/schemas/chain.yaml +++ b/api/schemas/chain.yaml @@ -46,8 +46,10 @@ $defs: patternProperties: "^[-_A-Za-z]+$": oneOf: + - type: boolean - type: number - type: string + - type: "null" request_chain: type: array