diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index e558ca61..9f46400f 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -16,6 +16,8 @@ from diffusers import ( OnnxStableDiffusionPipeline, OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, + # types + DiffusionPipeline, ) from flask import Flask, jsonify, request, send_from_directory, url_for from hashlib import sha256 @@ -79,15 +81,15 @@ pipeline_schedulers = { } -def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0): +def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0) -> float: return min(max(float(args.get(key, default_value)), min_value), max_value) -def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1): +def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1) -> int: return min(max(int(args.get(key, default_value)), min_value), max_value) -def get_from_map(args, key: str, values, default): +def get_from_map(args, key: str, values: dict[str, Any], default: Any): selected = args.get(key, default) if selected in values: return values[selected] @@ -109,7 +111,7 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray: return image_latents -def load_pipeline(pipeline, model: str, provider: str, scheduler): +def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler): global last_pipeline_instance global last_pipeline_scheduler global last_pipeline_options @@ -146,7 +148,7 @@ def json_with_cors(data, origin='*'): return res -def make_output_path(type: str, params: Tuple[Union[str, int, float]]): +def make_output_path(type: str, seed: int, params: Tuple[Union[str, int, float]]): sha = sha256() sha.update(type.encode('utf-8')) for param in params: @@ -159,7 +161,7 @@ def make_output_path(type: str, params: Tuple[Union[str, int, float]]): else: print('cannot hash param: %s, %s' % (param, type(param))) - output_file = '%s_%s.png' % (type, sha.hexdigest()) + output_file = '%s_%s_%s.png' % (type, seed, sha.hexdigest()) output_full = safer_join(output_path, output_file) return (output_file, output_full) @@ -238,7 +240,7 @@ def list_schedulers(): return json_with_cors(list(pipeline_schedulers.keys())) -def pipeline_from_request(pipeline): +def pipeline_from_request(pipeline: DiffusionPipeline): user = request.remote_addr # pipeline stuff @@ -250,12 +252,13 @@ def pipeline_from_request(pipeline): # image params prompt = request.args.get('prompt', default_prompt) - negative_prompt = request.args.get('negative', None); + negative_prompt = request.args.get('negative', None) - cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0) - steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps) - height = get_and_clamp_int(request.args, 'height', default_height, max_height) - width = get_and_clamp_int(request.args, 'width', default_width, max_width) + cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, config_params.get('cfg').get('max'), 0) + steps = get_and_clamp_int(request.args, 'steps', default_steps, config_params.get('steps').get('max')) + height = get_and_clamp_int( + request.args, 'height', default_height, config_params.get('height').get('max')) + width = get_and_clamp_int(request.args, 'width', default_width, config_params.get('width').get('max')) seed = int(request.args.get('seed', -1)) if seed == -1: @@ -405,5 +408,5 @@ def inpaint(): @app.route('/output/') -def output(filename): +def output(filename: str): return send_from_directory(path.join('..', output_path), filename, as_attachment=False) diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index 2dbb0be4..3cd3df78 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -13,13 +13,20 @@ ], "settings": { "cSpell.words": [ + "astype", + "CUDA", "ddim", "ddpm", "directml", "ftfy", + "Heun", "huggingface", "Inpaint", + "jsonify", + "Karras", + "KDPM", "Multistep", + "ndarray", "numpy", "Onnx", "onnxruntime", @@ -27,7 +34,9 @@ "pretrained", "protobuf", "runwayml", + "scandir", "scipy", + "Singlestep", "spacy", "spinalcase", "stringcase",