diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 3bd24933..8a00dfe7 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -27,16 +27,19 @@ def blend_img2img( source_image: Image.Image, *, strength: float, + prompt: str = None, + **kwargs, ) -> Image.Image: logger.info('generating image using img2img', params.prompt) pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, params.model, params.provider, params.scheduler) + prompt = prompt or params.prompt rng = np.random.RandomState(params.seed) result = pipe( - params.prompt, + prompt, generator=rng, guidance_scale=params.cfg, image=source_image, diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index bdf42780..5ed461db 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -46,6 +46,7 @@ def blend_inpaint( fill_color: str = 'white', mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, + **kwargs, ) -> Image.Image: logger.info('upscaling image by expanding borders', expand) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 3190747c..b24970cb 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -60,6 +60,7 @@ def correct_gfpgan( *, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None, + **kwargs, ) -> Image.Image: if upscale.correction_model is None: logger.warn('no face model given, skipping') diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index ca877157..6e74429b 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -21,6 +21,7 @@ def persist_disk( source_image: Image.Image, *, output: str, + **kwargs, ) -> Image.Image: dest = base_join(ctx.output_path, output) source_image.save(dest) diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 6698fe11..2653a7f1 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -26,6 +26,7 @@ def persist_s3( bucket: str, endpoint_url: str = None, profile_name: str = None, + **kwargs, ) -> Image.Image: session = Session(profile_name=profile_name) s3 = session.client('s3', endpoint_url=endpoint_url) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 88d49590..4ae06580 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -29,8 +29,11 @@ def source_txt2img( source_image: Image.Image, *, size: Size, + prompt: str = None, + **kwargs, ) -> Image.Image: - logger.info('generating image using txt2img, %s steps', params.steps) + prompt = prompt or params.prompt + logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt) if source_image is not None: logger.warn('a source image was passed to a txt2img stage, but will be discarded') @@ -42,7 +45,7 @@ def source_txt2img( rng = np.random.RandomState(params.seed) result = pipe( - params.prompt, + prompt, height=size.height, width=size.width, generator=rng, diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index bbb4ae30..f6957aa8 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -46,6 +46,7 @@ def upscale_outpaint( fill_color: str = 'white', mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, + **kwargs, ) -> Image.Image: logger.info('upscaling image by expanding borders: %s', expand) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 3d83083e..bc9ec56d 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -77,6 +77,7 @@ def upscale_resrgan( source_image: Image.Image, *, upscale: UpscaleParams, + **kwargs, ) -> Image.Image: logger.info('upscaling image with Real ESRGAN', upscale.scale) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 4543cccd..a189663d 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -67,6 +67,7 @@ def upscale_stable_diffusion( source: Image.Image, *, upscale: UpscaleParams, + **kwargs, ) -> Image.Image: logger.info('upscaling with Stable Diffusion') diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 34833ce4..9fcf414a 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -18,6 +18,7 @@ from flask_cors import CORS from flask_executor import Executor from glob import glob from io import BytesIO +from jsonschema import validate from logging import getLogger from PIL import Image from onnxruntime import get_available_providers @@ -26,10 +27,12 @@ from typing import Tuple from .chain import ( + blend_img2img, + blend_inpaint, correct_gfpgan, - source_txt2img, persist_disk, persist_s3, + source_txt2img, upscale_outpaint, upscale_resrgan, upscale_stable_diffusion, @@ -120,10 +123,15 @@ mask_filters = { 'gaussian-screen': mask_filter_gaussian_screen, } chain_stages = { - 'correction-gfpgan': correct_gfpgan, - 'upscaling-outpaint': upscale_outpaint, - 'upscaling-resrgan': upscale_resrgan, - 'upscaling-stable-diffusion': upscale_stable_diffusion, + 'blend-img2img': blend_img2img, + 'blend-inpaint': blend_inpaint, + 'correct-gfpgan': correct_gfpgan, + 'persist-disk': persist_disk, + 'persist-s3': persist_s3, + 'source-txt2img': source_txt2img, + 'upscale-outpaint': upscale_outpaint, + 'upscale-resrgan': upscale_resrgan, + 'upscale-stable-diffusion': upscale_stable_diffusion, } # Available ORT providers @@ -547,37 +555,38 @@ def upscale(): @app.route('/api/chain', methods=['POST']) def chain(): + data = request.json + + with open('./schema.yaml', 'r') as f: + schema = yaml.safe_load(f.read()) + + logger.info('validating chain request: %s against %s', data, schema) + validate(data, schema) + + # get defaults from the regular parameters params, size = pipeline_from_request() output = make_output_name('chain', params, size) - # parse body as json, list of stages - example = ChainPipeline(stages=[ - (source_txt2img, StageParams(), { - 'size': size, - }), - (upscale_outpaint, StageParams(), { - 'expand': Border.even(SizeChart.half), - }), - (persist_disk, StageParams(tile_size=SizeChart.hd8k), { - 'output': output, - }), - (upscale_stable_diffusion, StageParams(tile_size=SizeChart.mini,outscale=4), { - 'upscale': UpscaleParams('stable-diffusion-x4-upscaler', params.provider, scale=4, outscale=4) - }), - (persist_disk, StageParams(tile_size=SizeChart.hd8k), { - 'output': output, - }), - (persist_s3, StageParams(tile_size=SizeChart.hd8k), { - 'bucket': 'storage-stable-diffusion', - 'endpoint_url': 'http://scylla.home.holdmyran.ch:8000', - 'output': output, - 'profile_name': 'ceph', - }), - ]) + pipeline = ChainPipeline() + for stage_data in data.get('stages', []): + logger.info('request stage: %s', stage_data) + + callback = chain_stages[stage_data.get('type')] + kwargs = stage_data.get('params', {}) + + stage = StageParams( + stage_data.get('name', callback.__name__), + tile_size=int(kwargs.get('tile_size', SizeChart.auto)), + outscale=int(kwargs.get('outscale', 1)), + ) + # TODO: create Border from border + # TODO: create Upscale from upscale + pipeline.append((callback, stage, kwargs)) # build and run chain pipeline - executor.submit_stored(output, example, context, - params, Image.new('RGB', (1, 1))) + fake_source = Image.new('RGB', (1, 1)) + executor.submit_stored(output, pipeline, context, + params, fake_source, output=output, size=size) return jsonify({ 'output': output, diff --git a/api/requirements.txt b/api/requirements.txt index e368009b..9b07a6eb 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -18,4 +18,5 @@ boto3 flask flask-cors flask_executor +jsonschema pyyaml \ No newline at end of file diff --git a/api/schema.yaml b/api/schema.yaml new file mode 100644 index 00000000..4b031b09 --- /dev/null +++ b/api/schema.yaml @@ -0,0 +1,69 @@ +$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml +$schema: https://json-schema.org/draft/2020-12/schema + +$defs: + border_params: + type: object + properties: + bottom: + type: number + left: + type: number + right: + type: number + top: + type: number + + image_params: + type: object + required: [prompt] + properties: + prompt: + type: string + + upscale_params: + type: object + required: [outscale, scale] + properties: + outscale: + type: number + scale: + type: number + + request_stage: + type: object + required: [name, type, params] + properties: + name: + type: string + type: + type: string + params: + type: object + properties: + args: + type: object + additionalProperties: False + patternProperties: + "^[-_A-Za-z]+$": + oneOf: + - type: number + - type: string + border: + $ref: "#/$defs/border_params" + image: + $ref: "#/$defs/image_params" + upscale: + $ref: "#/$defs/upscale_params" + + request_chain: + type: array + items: + $ref: "#/$defs/request_stage" + +type: object +additionalProperties: False +required: [stages] +properties: + stages: + $ref: "#/$defs/request_chain" diff --git a/common/pipelines/example.json b/common/pipelines/example.json new file mode 100644 index 00000000..f7b284ff --- /dev/null +++ b/common/pipelines/example.json @@ -0,0 +1,43 @@ +{ + "stages": [ + { + "name": "start", + "type": "source-txt2img", + "params": { + "prompt": "a magical wizard" + } + }, + { + "name": "refine", + "type": "blend-img2img", + "params": { + "prompt": "a magical wizard in a robe fighting a dragon" + } + }, + { + "name": "upscale", + "type": "upscale-stable-diffusion", + "params": { + "model": "stable-diffusion-x4-upscaler", + "prompt": "a magical wizard in a robe fighting a dragon", + "scale": 4, + "outscale": 4, + "tile_size": 128 + } + }, + { + "name": "save-local", + "type": "persist-disk", + "params": {} + }, + { + "name": "save-ceph", + "type": "persist-s3", + "params": { + "bucket": "storage-stable-diffusion", + "endpoint_url": "http://scylla.home.holdmyran.ch:8000", + "profile_name": "ceph" + } + } + ] +} \ No newline at end of file