diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index ff01125a..9b554de3 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -87,15 +87,15 @@ class ChainPipeline: def steps(self, params: ImageParams, size: Size): steps = 0 - for callback, _params, _kwargs in self.stages: - steps += callback.steps(params, size) + for callback, _params, kwargs in self.stages: + steps += callback.steps(kwargs.get("params", params), size) return steps def outputs(self, params: ImageParams, sources: int): outputs = sources - for callback, _params, _kwargs in self.stages: - outputs += callback.outputs(params, outputs) + for callback, _params, kwargs in self.stages: + outputs += callback.outputs(kwargs.get("params", params), outputs) return outputs diff --git a/api/onnx_web/chain/stage.py b/api/onnx_web/chain/stage.py index fff56ba7..3942460b 100644 --- a/api/onnx_web/chain/stage.py +++ b/api/onnx_web/chain/stage.py @@ -21,14 +21,14 @@ class BaseStage: stage_source: Optional[Image.Image] = None, **kwargs, ) -> List[Image.Image]: - raise NotImplementedError() + raise NotImplementedError() # noqa def steps( self, params: ImageParams, size: Size, ) -> int: - return 1 + return 1 # noqa def outputs( self, diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index fa4cc449..3865cea4 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -18,7 +18,7 @@ from ..diffusers.run import ( ) from ..diffusers.utils import replace_wildcards from ..output import json_params, make_output_name -from ..params import Border, Size, StageParams, TileOrder, UpscaleParams +from ..params import Size, StageParams, TileOrder from ..transformers.run import run_txt2txt_pipeline from ..utils import ( base_join, @@ -50,10 +50,11 @@ from .load import ( get_wildcard_data, ) from .params import ( - border_from_request, - highres_from_request, + build_border, + build_highres, + build_upscale, + pipeline_from_json, pipeline_from_request, - upscale_from_request, ) from .utils import wrap_route @@ -168,8 +169,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): size = Size(source.width, source.height) device, params, _size = pipeline_from_request(server, "img2img") - upscale = upscale_from_request() - highres = highres_from_request() + upscale = build_upscale() + highres = build_highres() source_filter = get_from_list( request.args, "sourceFilter", list(get_source_filters().keys()) ) @@ -217,8 +218,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): def txt2img(server: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(server, "txt2img") - upscale = upscale_from_request() - highres = highres_from_request() + upscale = build_upscale() + highres = build_highres() replace_wildcards(params, get_wildcard_data()) @@ -271,9 +272,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): ) device, params, _size = pipeline_from_request(server, "inpaint") - expand = border_from_request() - upscale = upscale_from_request() - highres = highres_from_request() + expand = build_border() + upscale = build_upscale() + highres = build_highres() fill_color = get_not_empty(request.args, "fillColor", "white") mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") @@ -341,8 +342,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): source = Image.open(BytesIO(source_file.read())).convert("RGB") device, params, size = pipeline_from_request(server) - upscale = upscale_from_request() - highres = highres_from_request() + upscale = build_upscale() + highres = build_highres() replace_wildcards(params, get_wildcard_data()) @@ -367,6 +368,10 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) +# keys that are specially parsed by params and should not show up in with_args +CHAIN_POP_KEYS = ["model", "control"] + + def chain(server: ServerContext, pool: DevicePoolExecutor): if request.is_json: logger.debug("chain pipeline request with JSON body") @@ -386,9 +391,8 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): logger.debug("validating chain request: %s against %s", data, schema) validate(data, schema) - # get defaults from the regular parameters - device, base_params, base_size = pipeline_from_request( - server, data=data.get("defaults", None) + device, base_params, base_size = pipeline_from_json( + server, data=data.get("defaults") ) # start building the pipeline @@ -399,32 +403,32 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): logger.info("request stage: %s, %s", stage_class.__name__, kwargs) # TODO: combine base params with stage params - _device, params, size = pipeline_from_request(server, data=kwargs) + _device, params, size = pipeline_from_json(server, data=kwargs) replace_wildcards(params, get_wildcard_data()) - if "model" in kwargs: - kwargs.pop("model") - - if "control" in kwargs: - logger.warning("TODO: resolve controlnet model") - kwargs.pop("control") + # remove parsed keys, like model names (which become paths) + for pop_key in CHAIN_POP_KEYS: + if pop_key in kwargs: + kwargs.pop(pop_key) + # replace kwargs with parsed versions kwargs["params"] = params + kwargs["size"] = size + border = build_border(kwargs) + kwargs["border"] = border + + upscale = build_upscale(kwargs) + kwargs["upscale"] = upscale + + # prepare the stage metadata stage = StageParams( stage_data.get("name", stage_class.__name__), - tile_size=get_size(kwargs.get("tile_size")), + tile_size=get_size(kwargs.get("tiles")), outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), ) - if "border" in kwargs: - border = Border.even(int(kwargs.get("border"))) - kwargs["border"] = border - - if "upscale" in kwargs: - upscale = UpscaleParams(kwargs.get("upscale")) - kwargs["upscale"] = upscale - + # load any images related to this stage stage_source_name = "source:%s" % (stage.name) stage_mask_name = "mask:%s" % (stage.name) @@ -494,7 +498,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): sources.append(source) device, params, size = pipeline_from_request(server) - upscale = upscale_from_request() + upscale = build_upscale() output = make_output_name(server, "upscale", params, size) job_name = output[0] diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py index 21da25f4..6525d4ae 100644 --- a/api/onnx_web/server/model_cache.py +++ b/api/onnx_web/server/model_cache.py @@ -51,7 +51,7 @@ class ModelCache: return for i in range(len(cache)): - t, k, v = cache[i] + t, k, _v = cache[i] if tag == t and key != k: logger.debug("updating model cache: %s %s", tag, key) cache[i] = (tag, key, value) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index c32fdbf2..bc6ef14d 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import numpy as np from flask import request @@ -34,16 +34,10 @@ from .utils import get_model_path logger = getLogger(__name__) -def pipeline_from_request( +def build_device( server: ServerContext, - default_pipeline: str = "txt2img", - data: Dict[str, str] = None, -) -> Tuple[DeviceParams, ImageParams, Size]: - user = request.remote_addr - - if data is None: - data = request.args - + data: Dict[str, str], +) -> Optional[DeviceParams]: # platform stuff device = None device_name = data.get("platform") @@ -53,6 +47,14 @@ def pipeline_from_request( if platform.device == device_name: device = platform + return device + + +def build_params( + server: ServerContext, + default_pipeline: str, + data: Dict[str, str], +) -> ImageParams: # diffusion model model = get_not_empty(data, "model", get_config_value("model")) model_path = get_model_path(server, model) @@ -115,20 +117,6 @@ def pipeline_from_request( get_config_value("steps", "max"), get_config_value("steps", "min"), ) - height = get_and_clamp_int( - data, - "height", - get_config_value("height"), - get_config_value("height", "max"), - get_config_value("height", "min"), - ) - width = get_and_clamp_int( - data, - "width", - get_config_value("width"), - get_config_value("width", "max"), - get_config_value("width", "min"), - ) tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE")) tiles = get_and_clamp_int( data, @@ -161,21 +149,6 @@ def pipeline_from_request( # this one can safely use np.random because it produces a single value seed = np.random.randint(np.iinfo(np.int32).max) - logger.info( - "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", - user, - steps, - scheduler, - model_path, - pipeline, - device or "any device", - width, - height, - cfg, - seed, - prompt, - ) - params = ImageParams( model_path, pipeline, @@ -194,34 +167,60 @@ def pipeline_from_request( overlap=overlap, stride=stride, ) - size = Size(width, height) - return (device, params, size) + + return params -def border_from_request() -> Border: +def build_size( + server: ServerContext, + data: Dict[str, str], +) -> Size: + height = get_and_clamp_int( + data, + "height", + get_config_value("height"), + get_config_value("height", "max"), + get_config_value("height", "min"), + ) + width = get_and_clamp_int( + data, + "width", + get_config_value("width"), + get_config_value("width", "max"), + get_config_value("width", "min"), + ) + return Size(width, height) + + +def build_border( + data: Dict[str, str] = None, +) -> Border: + if data is None: + data = request.args + left = get_and_clamp_int( - request.args, + data, "left", get_config_value("left"), get_config_value("left", "max"), get_config_value("left", "min"), ) right = get_and_clamp_int( - request.args, + data, "right", get_config_value("right"), get_config_value("right", "max"), get_config_value("right", "min"), ) top = get_and_clamp_int( - request.args, + data, "top", get_config_value("top"), get_config_value("top", "max"), get_config_value("top", "min"), ) bottom = get_and_clamp_int( - request.args, + data, "bottom", get_config_value("bottom"), get_config_value("bottom", "max"), @@ -231,46 +230,51 @@ def border_from_request() -> Border: return Border(left, right, top, bottom) -def upscale_from_request() -> UpscaleParams: +def build_upscale( + data: Dict[str, str] = None, +) -> UpscaleParams: + if data is None: + data = request.args + denoise = get_and_clamp_float( - request.args, + data, "denoise", get_config_value("denoise"), get_config_value("denoise", "max"), get_config_value("denoise", "min"), ) scale = get_and_clamp_int( - request.args, + data, "scale", get_config_value("scale"), get_config_value("scale", "max"), get_config_value("scale", "min"), ) outscale = get_and_clamp_int( - request.args, + data, "outscale", get_config_value("outscale"), get_config_value("outscale", "max"), get_config_value("outscale", "min"), ) - upscaling = get_from_list(request.args, "upscaling", get_upscaling_models()) - correction = get_from_list(request.args, "correction", get_correction_models()) - faces = get_not_empty(request.args, "faces", "false") == "true" + upscaling = get_from_list(data, "upscaling", get_upscaling_models()) + correction = get_from_list(data, "correction", get_correction_models()) + faces = get_not_empty(data, "faces", "false") == "true" face_outscale = get_and_clamp_int( - request.args, + data, "faceOutscale", get_config_value("faceOutscale"), get_config_value("faceOutscale", "max"), get_config_value("faceOutscale", "min"), ) face_strength = get_and_clamp_float( - request.args, + data, "faceStrength", get_config_value("faceStrength"), get_config_value("faceStrength", "max"), get_config_value("faceStrength", "min"), ) - upscale_order = request.args.get("upscaleOrder", "correction-first") + upscale_order = data.get("upscaleOrder", "correction-first") return UpscaleParams( upscaling, @@ -286,37 +290,43 @@ def upscale_from_request() -> UpscaleParams: ) -def highres_from_request() -> HighresParams: - enabled = get_boolean(request.args, "highres", get_config_value("highres")) +def build_highres( + data: Dict[str, str] = None, +) -> HighresParams: + if data is None: + data = request.args + + enabled = get_boolean(data, "highres", get_config_value("highres")) iterations = get_and_clamp_int( - request.args, + data, "highresIterations", get_config_value("highresIterations"), get_config_value("highresIterations", "max"), get_config_value("highresIterations", "min"), ) - method = get_from_list(request.args, "highresMethod", get_highres_methods()) + method = get_from_list(data, "highresMethod", get_highres_methods()) scale = get_and_clamp_int( - request.args, + data, "highresScale", get_config_value("highresScale"), get_config_value("highresScale", "max"), get_config_value("highresScale", "min"), ) steps = get_and_clamp_int( - request.args, + data, "highresSteps", get_config_value("highresSteps"), get_config_value("highresSteps", "max"), get_config_value("highresSteps", "min"), ) strength = get_and_clamp_float( - request.args, + data, "highresStrength", get_config_value("highresStrength"), get_config_value("highresStrength", "max"), get_config_value("highresStrength", "min"), ) + return HighresParams( enabled, scale, @@ -325,3 +335,50 @@ def highres_from_request() -> HighresParams: method=method, iterations=iterations, ) + + +PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size] + + +def pipeline_from_json( + server: ServerContext, + data: Dict[str, str], + default_pipeline: str = "txt2img", +) -> PipelineParams: + """ + Like pipeline_from_request but expects a nested structure. + """ + + device = build_device(server, data.get("device", data)) + params = build_params(server, default_pipeline, data.get("params", data)) + size = build_size(server, data.get("params", data)) + + return (device, params, size) + + +def pipeline_from_request( + server: ServerContext, + default_pipeline: str = "txt2img", +) -> PipelineParams: + user = request.remote_addr + + device = build_device(server, request.args) + params = build_params(server, default_pipeline, request.args) + size = build_size(server, request.args) + + logger.info( + "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", + user, + params.steps, + params.scheduler, + params.model_path, + params.pipeline, + device or "any device", + params.width, + params.height, + params.cfg, + params.seed, + params.prompt, + ) + + return (device, params, size) diff --git a/api/tests/chain/__init__.py b/api/tests/chain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/chain/test_blend_linear.py b/api/tests/chain/test_blend_linear.py new file mode 100644 index 00000000..9d20fe55 --- /dev/null +++ b/api/tests/chain/test_blend_linear.py @@ -0,0 +1,18 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_linear import BlendLinearStage + + +class BlendLinearStageTests(unittest.TestCase): + def test_stage(self): + stage = BlendLinearStage() + sources = [ + Image.new("RGB", (64, 64), "black"), + ] + stage_source = Image.new("RGB", (64, 64), "white") + result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127)) \ No newline at end of file diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py new file mode 100644 index 00000000..71b7d2e5 --- /dev/null +++ b/api/tests/chain/test_tile.py @@ -0,0 +1,42 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.tile import complete_tile + + +class TestCompleteTile(unittest.TestCase): + def test_with_complete_tile(self): + partial = Image.new("RGB", (64, 64)) + output = complete_tile(partial, 64) + + self.assertEqual(output.size, (64, 64)) + + def test_with_partial_tile(self): + partial = Image.new("RGB", (64, 32)) + output = complete_tile(partial, 64) + + self.assertEqual(output.size, (64, 64)) + + def test_with_nothing(self): + output = complete_tile(None, 64) + self.assertIsNone(output) + + +class TestNeedsTile(unittest.TestCase): + def test_with_undersized(self): + pass + + def test_with_oversized(self): + pass + + def test_with_mixed(self): + pass + + +class TestTileGrads(unittest.TestCase): + def test_center_tile(self): + pass + + def test_edge_tile(self): + pass diff --git a/api/tests/chain/test_upscale_highres.py b/api/tests/chain/test_upscale_highres.py new file mode 100644 index 00000000..f5e17a84 --- /dev/null +++ b/api/tests/chain/test_upscale_highres.py @@ -0,0 +1,13 @@ +import unittest + +from onnx_web.chain.upscale_highres import UpscaleHighresStage +from onnx_web.params import HighresParams, UpscaleParams + + +class UpscaleHighresStageTests(unittest.TestCase): + def test_empty(self): + stage = UpscaleHighresStage() + sources = [] + result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) + + self.assertEqual(len(result), 0) \ No newline at end of file diff --git a/api/tests/models/__init__.py b/api/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/models/test_meta.py b/api/tests/models/test_meta.py new file mode 100644 index 00000000..458c8c37 --- /dev/null +++ b/api/tests/models/test_meta.py @@ -0,0 +1,12 @@ +import unittest + +from onnx_web.models.meta import NetworkModel + + +class NetworkModelTests(unittest.TestCase): + def test_json(self): + model = NetworkModel("test", "inversion") + json = model.tojson() + + self.assertIn("name", json) + self.assertIn("type", json) diff --git a/api/tests/prompt/test_parser.py b/api/tests/prompt/test_parser.py index 20c03341..b6b13a23 100644 --- a/api/tests/prompt/test_parser.py +++ b/api/tests/prompt/test_parser.py @@ -1,7 +1,9 @@ import unittest + from onnx_web.prompt.grammar import PromptPhrase from onnx_web.prompt.parser import parse_prompt_onnx + class ParserTests(unittest.TestCase): def test_single_word_phrase(self): res = parse_prompt_onnx(None, "foo (bar) bin", debug=False) diff --git a/api/tests/server/test_model_cache.py b/api/tests/server/test_model_cache.py index 000065d0..0e4839c9 100644 --- a/api/tests/server/test_model_cache.py +++ b/api/tests/server/test_model_cache.py @@ -2,7 +2,8 @@ import unittest from onnx_web.server.model_cache import ModelCache -class TestStringMethods(unittest.TestCase): + +class TestModelCache(unittest.TestCase): def test_drop_existing(self): cache = ModelCache(10) cache.clear() @@ -32,3 +33,31 @@ class TestStringMethods(unittest.TestCase): cache.set("foo", ("bar",), value) self.assertGreater(cache.size, 0) self.assertIs(cache.get("foo", ("bin",)), None) + + """ + def test_set_existing(self): + cache = ModelCache(10) + cache.clear() + cache.set("foo", ("bar",), { + "value": 1, + }) + value = { + "value": 2, + } + cache.set("foo", ("bar",), value) + self.assertIs(cache.get("foo", ("bar",)), value) + """ + + def test_set_missing(self): + cache = ModelCache(10) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertIs(cache.get("foo", ("bar",)), value) + + def test_set_zero(self): + cache = ModelCache(0) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertEqual(cache.size, 0) diff --git a/api/tests/test_params.py b/api/tests/test_params.py index 0f84cfab..09eb7576 100644 --- a/api/tests/test_params.py +++ b/api/tests/test_params.py @@ -2,6 +2,7 @@ import unittest from onnx_web.params import Border, Size + class BorderTests(unittest.TestCase): def test_json(self): border = Border.even(0) diff --git a/api/tests/test_test.py b/api/tests/test_test.py index c6ee310c..ae70622f 100644 --- a/api/tests/test_test.py +++ b/api/tests/test_test.py @@ -1,5 +1,6 @@ import unittest + # just to get CI happy class ErrorTest(unittest.TestCase): def test(self): diff --git a/api/tests/worker/__init__.py b/api/tests/worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py new file mode 100644 index 00000000..cbb4c42c --- /dev/null +++ b/api/tests/worker/test_pool.py @@ -0,0 +1,12 @@ +import unittest + +from onnx_web.server.context import ServerContext +from onnx_web.worker.pool import DevicePoolExecutor + + +class TestWorkerPool(unittest.TestCase): + def test_no_devices(self): + server = ServerContext() + pool = DevicePoolExecutor(server, []) + pool.start() + pool.join() \ No newline at end of file