diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 1405456e..f69421e5 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -285,7 +285,11 @@ def run_img2img_pipeline( **pipe_params, ) - for image, output in zip(result.images, outputs): + images = result.images + if source_filter is not None: + images.append(source) + + for image, output in zip(images, outputs): image = run_upscale_correction( job, server, diff --git a/api/onnx_web/image/source_filter.py b/api/onnx_web/image/source_filter.py index 56f02df4..3885bfdf 100644 --- a/api/onnx_web/image/source_filter.py +++ b/api/onnx_web/image/source_filter.py @@ -183,6 +183,7 @@ def source_filter_canny( image = cv2.Canny(pil_to_cv2(source), low_threshold, high_threshold) image = Image.fromarray(image) + image = image.convert("RGB") return image @@ -192,7 +193,7 @@ def source_filter_openpose(server: ServerContext, source: Image.Image) -> Image. model = OpenposeDetector.from_pretrained( "lllyasviel/ControlNet", - cache_dir=server.cache_dir, + cache_dir=server.cache_path, ) image = model(source) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 5a9a7ed8..892ffd6a 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -69,7 +69,9 @@ def make_output_name( params: ImageParams, size: Size, extras: Optional[List[Optional[Param]]] = None, + count: Optional[int] = None, ) -> List[str]: + count = count or params.batch now = int(time()) sha = sha256() @@ -93,7 +95,7 @@ def make_output_name( return [ f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}" - for i in range(params.batch) + for i in range(count) ] diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 29585582..3d5247af 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -176,8 +176,12 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): get_config_value("strength", "min"), ) - # TODO: add filtered source to outputs - output = make_output_name(server, "img2img", params, size, extras=[strength]) + output_count = params.batch + if source_filter is not None: + output_count += 1 + + output = make_output_name(server, "img2img", params, size, extras=[strength], count=output_count) + job_name = output[0] logger.info("img2img job queued for: %s", job_name)