diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 2766bdf1..9078e63c 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -30,12 +30,12 @@ def blend_img2img( prompt: str = None, **kwargs, ) -> Image.Image: - logger.info('generating image using img2img, %s steps: %s', params.steps, params.prompt) + prompt = prompt or params.prompt + logger.info('generating image using img2img, %s steps: %s', params.steps, prompt) pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, params.model, params.provider, params.scheduler) - prompt = prompt or params.prompt rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index a189663d..ddbd156e 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -67,9 +67,11 @@ def upscale_stable_diffusion( source: Image.Image, *, upscale: UpscaleParams, + prompt: str = None, **kwargs, ) -> Image.Image: - logger.info('upscaling with Stable Diffusion') + prompt = prompt or params.prompt + logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt) pipeline = load_stable_diffusion(ctx, upscale) generator = torch.manual_seed(params.seed) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index ac9ab487..de8b82b6 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -72,6 +72,7 @@ from .utils import ( get_from_list, get_from_map, get_not_empty, + get_size, make_output_name, base_join, ServerContext, @@ -577,8 +578,8 @@ def chain(): stage = StageParams( stage_data.get('name', callback.__name__), - tile_size=int(kwargs.get('tile_size', SizeChart.auto)), - outscale=int(kwargs.get('outscale', 1)), + tile_size=get_size(kwargs.get('tile_size')), + outscale=get_and_clamp_int(kwargs,'outscale', 1, 4), ) # TODO: create Border from border diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index c35bf735..64c0be9d 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -3,12 +3,13 @@ from logging import getLogger from os import environ, path from struct import pack from time import time -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Union, Tuple from .params import ( ImageParams, Param, Size, + SizeChart, ) logger = getLogger(__name__) @@ -92,6 +93,22 @@ def get_not_empty(args: Any, key: str, default: Any) -> Any: return val +def get_size(val: Union[int, str, None]) -> SizeChart: + if val is None: + return SizeChart.auto + + if type(val) is str: + if val in SizeChart: + return SizeChart[val] + else: + return int(val) + + if type(val) is int: + return val + + raise Exception('invalid size') + + def hash_value(sha, param: Param): if param is None: return diff --git a/common/pipelines/example.json b/common/pipelines/example.json index 85dc6475..f059c599 100644 --- a/common/pipelines/example.json +++ b/common/pipelines/example.json @@ -23,13 +23,15 @@ "prompt": "a magical wizard in a robe fighting a dragon", "scale": 4, "outscale": 4, - "tile_size": 128 + "tile_size": "mini" } }, { "name": "save-local", "type": "persist-disk", - "params": {} + "params": { + "tile_size": "8k" + } }, { "name": "save-ceph", @@ -37,7 +39,8 @@ "params": { "bucket": "storage-stable-diffusion", "endpoint_url": "http://scylla.home.holdmyran.ch:8000", - "profile_name": "ceph" + "profile_name": "ceph", + "tile_size": "8k" } } ]