1
0
Fork 0

fix(api): make request parsing consistent between JSON and forms

This commit is contained in:
Sean Sube 2023-09-13 17:27:44 -05:00
parent 8a5e211172
commit a33c88e670
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
17 changed files with 295 additions and 104 deletions

View File

@ -87,15 +87,15 @@ class ChainPipeline:
def steps(self, params: ImageParams, size: Size):
steps = 0
for callback, _params, _kwargs in self.stages:
steps += callback.steps(params, size)
for callback, _params, kwargs in self.stages:
steps += callback.steps(kwargs.get("params", params), size)
return steps
def outputs(self, params: ImageParams, sources: int):
outputs = sources
for callback, _params, _kwargs in self.stages:
outputs += callback.outputs(params, outputs)
for callback, _params, kwargs in self.stages:
outputs += callback.outputs(kwargs.get("params", params), outputs)
return outputs

View File

@ -21,14 +21,14 @@ class BaseStage:
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> List[Image.Image]:
raise NotImplementedError()
raise NotImplementedError() # noqa
def steps(
self,
params: ImageParams,
size: Size,
) -> int:
return 1
return 1 # noqa
def outputs(
self,

View File

@ -18,7 +18,7 @@ from ..diffusers.run import (
)
from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
from ..params import Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline
from ..utils import (
base_join,
@ -50,10 +50,11 @@ from .load import (
get_wildcard_data,
)
from .params import (
border_from_request,
highres_from_request,
build_border,
build_highres,
build_upscale,
pipeline_from_json,
pipeline_from_request,
upscale_from_request,
)
from .utils import wrap_route
@ -168,8 +169,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
size = Size(source.width, source.height)
device, params, _size = pipeline_from_request(server, "img2img")
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()
source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys())
)
@ -217,8 +218,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "txt2img")
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()
replace_wildcards(params, get_wildcard_data())
@ -271,9 +272,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
)
device, params, _size = pipeline_from_request(server, "inpaint")
expand = border_from_request()
upscale = upscale_from_request()
highres = highres_from_request()
expand = build_border()
upscale = build_upscale()
highres = build_highres()
fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
@ -341,8 +342,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
highres = highres_from_request()
upscale = build_upscale()
highres = build_highres()
replace_wildcards(params, get_wildcard_data())
@ -367,6 +368,10 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
# keys that are specially parsed by params and should not show up in with_args
CHAIN_POP_KEYS = ["model", "control"]
def chain(server: ServerContext, pool: DevicePoolExecutor):
if request.is_json:
logger.debug("chain pipeline request with JSON body")
@ -386,9 +391,8 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema)
# get defaults from the regular parameters
device, base_params, base_size = pipeline_from_request(
server, data=data.get("defaults", None)
device, base_params, base_size = pipeline_from_json(
server, data=data.get("defaults")
)
# start building the pipeline
@ -399,32 +403,32 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
# TODO: combine base params with stage params
_device, params, size = pipeline_from_request(server, data=kwargs)
_device, params, size = pipeline_from_json(server, data=kwargs)
replace_wildcards(params, get_wildcard_data())
if "model" in kwargs:
kwargs.pop("model")
if "control" in kwargs:
logger.warning("TODO: resolve controlnet model")
kwargs.pop("control")
# remove parsed keys, like model names (which become paths)
for pop_key in CHAIN_POP_KEYS:
if pop_key in kwargs:
kwargs.pop(pop_key)
# replace kwargs with parsed versions
kwargs["params"] = params
kwargs["size"] = size
border = build_border(kwargs)
kwargs["border"] = border
upscale = build_upscale(kwargs)
kwargs["upscale"] = upscale
# prepare the stage metadata
stage = StageParams(
stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")),
tile_size=get_size(kwargs.get("tiles")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
if "border" in kwargs:
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border
if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale
# load any images related to this stage
stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name)
@ -494,7 +498,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
sources.append(source)
device, params, size = pipeline_from_request(server)
upscale = upscale_from_request()
upscale = build_upscale()
output = make_output_name(server, "upscale", params, size)
job_name = output[0]

View File

@ -51,7 +51,7 @@ class ModelCache:
return
for i in range(len(cache)):
t, k, v = cache[i]
t, k, _v = cache[i]
if tag == t and key != k:
logger.debug("updating model cache: %s %s", tag, key)
cache[i] = (tag, key, value)

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple
import numpy as np
from flask import request
@ -34,16 +34,10 @@ from .utils import get_model_path
logger = getLogger(__name__)
def pipeline_from_request(
def build_device(
server: ServerContext,
default_pipeline: str = "txt2img",
data: Dict[str, str] = None,
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
if data is None:
data = request.args
data: Dict[str, str],
) -> Optional[DeviceParams]:
# platform stuff
device = None
device_name = data.get("platform")
@ -53,6 +47,14 @@ def pipeline_from_request(
if platform.device == device_name:
device = platform
return device
def build_params(
server: ServerContext,
default_pipeline: str,
data: Dict[str, str],
) -> ImageParams:
# diffusion model
model = get_not_empty(data, "model", get_config_value("model"))
model_path = get_model_path(server, model)
@ -115,20 +117,6 @@ def pipeline_from_request(
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
tiles = get_and_clamp_int(
data,
@ -161,21 +149,6 @@ def pipeline_from_request(
# this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler,
model_path,
pipeline,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams(
model_path,
pipeline,
@ -194,34 +167,60 @@ def pipeline_from_request(
overlap=overlap,
stride=stride,
)
size = Size(width, height)
return (device, params, size)
return params
def border_from_request() -> Border:
def build_size(
server: ServerContext,
data: Dict[str, str],
) -> Size:
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
return Size(width, height)
def build_border(
data: Dict[str, str] = None,
) -> Border:
if data is None:
data = request.args
left = get_and_clamp_int(
request.args,
data,
"left",
get_config_value("left"),
get_config_value("left", "max"),
get_config_value("left", "min"),
)
right = get_and_clamp_int(
request.args,
data,
"right",
get_config_value("right"),
get_config_value("right", "max"),
get_config_value("right", "min"),
)
top = get_and_clamp_int(
request.args,
data,
"top",
get_config_value("top"),
get_config_value("top", "max"),
get_config_value("top", "min"),
)
bottom = get_and_clamp_int(
request.args,
data,
"bottom",
get_config_value("bottom"),
get_config_value("bottom", "max"),
@ -231,46 +230,51 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams:
def build_upscale(
data: Dict[str, str] = None,
) -> UpscaleParams:
if data is None:
data = request.args
denoise = get_and_clamp_float(
request.args,
data,
"denoise",
get_config_value("denoise"),
get_config_value("denoise", "max"),
get_config_value("denoise", "min"),
)
scale = get_and_clamp_int(
request.args,
data,
"scale",
get_config_value("scale"),
get_config_value("scale", "max"),
get_config_value("scale", "min"),
)
outscale = get_and_clamp_int(
request.args,
data,
"outscale",
get_config_value("outscale"),
get_config_value("outscale", "max"),
get_config_value("outscale", "min"),
)
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
correction = get_from_list(request.args, "correction", get_correction_models())
faces = get_not_empty(request.args, "faces", "false") == "true"
upscaling = get_from_list(data, "upscaling", get_upscaling_models())
correction = get_from_list(data, "correction", get_correction_models())
faces = get_not_empty(data, "faces", "false") == "true"
face_outscale = get_and_clamp_int(
request.args,
data,
"faceOutscale",
get_config_value("faceOutscale"),
get_config_value("faceOutscale", "max"),
get_config_value("faceOutscale", "min"),
)
face_strength = get_and_clamp_float(
request.args,
data,
"faceStrength",
get_config_value("faceStrength"),
get_config_value("faceStrength", "max"),
get_config_value("faceStrength", "min"),
)
upscale_order = request.args.get("upscaleOrder", "correction-first")
upscale_order = data.get("upscaleOrder", "correction-first")
return UpscaleParams(
upscaling,
@ -286,37 +290,43 @@ def upscale_from_request() -> UpscaleParams:
)
def highres_from_request() -> HighresParams:
enabled = get_boolean(request.args, "highres", get_config_value("highres"))
def build_highres(
data: Dict[str, str] = None,
) -> HighresParams:
if data is None:
data = request.args
enabled = get_boolean(data, "highres", get_config_value("highres"))
iterations = get_and_clamp_int(
request.args,
data,
"highresIterations",
get_config_value("highresIterations"),
get_config_value("highresIterations", "max"),
get_config_value("highresIterations", "min"),
)
method = get_from_list(request.args, "highresMethod", get_highres_methods())
method = get_from_list(data, "highresMethod", get_highres_methods())
scale = get_and_clamp_int(
request.args,
data,
"highresScale",
get_config_value("highresScale"),
get_config_value("highresScale", "max"),
get_config_value("highresScale", "min"),
)
steps = get_and_clamp_int(
request.args,
data,
"highresSteps",
get_config_value("highresSteps"),
get_config_value("highresSteps", "max"),
get_config_value("highresSteps", "min"),
)
strength = get_and_clamp_float(
request.args,
data,
"highresStrength",
get_config_value("highresStrength"),
get_config_value("highresStrength", "max"),
get_config_value("highresStrength", "min"),
)
return HighresParams(
enabled,
scale,
@ -325,3 +335,50 @@ def highres_from_request() -> HighresParams:
method=method,
iterations=iterations,
)
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
def pipeline_from_json(
server: ServerContext,
data: Dict[str, str],
default_pipeline: str = "txt2img",
) -> PipelineParams:
"""
Like pipeline_from_request but expects a nested structure.
"""
device = build_device(server, data.get("device", data))
params = build_params(server, default_pipeline, data.get("params", data))
size = build_size(server, data.get("params", data))
return (device, params, size)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> PipelineParams:
user = request.remote_addr
device = build_device(server, request.args)
params = build_params(server, default_pipeline, request.args)
size = build_size(server, request.args)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
params.steps,
params.scheduler,
params.model_path,
params.pipeline,
device or "any device",
params.width,
params.height,
params.cfg,
params.seed,
params.prompt,
)
return (device, params, size)

View File

View File

@ -0,0 +1,18 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_linear import BlendLinearStage
class BlendLinearStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendLinearStage()
sources = [
Image.new("RGB", (64, 64), "black"),
]
stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))

View File

@ -0,0 +1,42 @@
import unittest
from PIL import Image
from onnx_web.chain.tile import complete_tile
class TestCompleteTile(unittest.TestCase):
def test_with_complete_tile(self):
partial = Image.new("RGB", (64, 64))
output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64))
def test_with_partial_tile(self):
partial = Image.new("RGB", (64, 32))
output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64))
def test_with_nothing(self):
output = complete_tile(None, 64)
self.assertIsNone(output)
class TestNeedsTile(unittest.TestCase):
def test_with_undersized(self):
pass
def test_with_oversized(self):
pass
def test_with_mixed(self):
pass
class TestTileGrads(unittest.TestCase):
def test_center_tile(self):
pass
def test_edge_tile(self):
pass

View File

@ -0,0 +1,13 @@
import unittest
from onnx_web.chain.upscale_highres import UpscaleHighresStage
from onnx_web.params import HighresParams, UpscaleParams
class UpscaleHighresStageTests(unittest.TestCase):
def test_empty(self):
stage = UpscaleHighresStage()
sources = []
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
self.assertEqual(len(result), 0)

View File

View File

@ -0,0 +1,12 @@
import unittest
from onnx_web.models.meta import NetworkModel
class NetworkModelTests(unittest.TestCase):
def test_json(self):
model = NetworkModel("test", "inversion")
json = model.tojson()
self.assertIn("name", json)
self.assertIn("type", json)

View File

@ -1,7 +1,9 @@
import unittest
from onnx_web.prompt.grammar import PromptPhrase
from onnx_web.prompt.parser import parse_prompt_onnx
class ParserTests(unittest.TestCase):
def test_single_word_phrase(self):
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)

View File

@ -2,7 +2,8 @@ import unittest
from onnx_web.server.model_cache import ModelCache
class TestStringMethods(unittest.TestCase):
class TestModelCache(unittest.TestCase):
def test_drop_existing(self):
cache = ModelCache(10)
cache.clear()
@ -32,3 +33,31 @@ class TestStringMethods(unittest.TestCase):
cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bin",)), None)
"""
def test_set_existing(self):
cache = ModelCache(10)
cache.clear()
cache.set("foo", ("bar",), {
"value": 1,
})
value = {
"value": 2,
}
cache.set("foo", ("bar",), value)
self.assertIs(cache.get("foo", ("bar",)), value)
"""
def test_set_missing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertIs(cache.get("foo", ("bar",)), value)
def test_set_zero(self):
cache = ModelCache(0)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertEqual(cache.size, 0)

View File

@ -2,6 +2,7 @@ import unittest
from onnx_web.params import Border, Size
class BorderTests(unittest.TestCase):
def test_json(self):
border = Border.even(0)

View File

@ -1,5 +1,6 @@
import unittest
# just to get CI happy
class ErrorTest(unittest.TestCase):
def test(self):

View File

View File

@ -0,0 +1,12 @@
import unittest
from onnx_web.server.context import ServerContext
from onnx_web.worker.pool import DevicePoolExecutor
class TestWorkerPool(unittest.TestCase):
def test_no_devices(self):
server = ServerContext()
pool = DevicePoolExecutor(server, [])
pool.start()
pool.join()