1
0
Fork 0

fix(api): make request parsing consistent between JSON and forms

This commit is contained in:
Sean Sube 2023-09-13 17:27:44 -05:00
parent 8a5e211172
commit a33c88e670
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
17 changed files with 295 additions and 104 deletions

View File

@ -87,15 +87,15 @@ class ChainPipeline:
def steps(self, params: ImageParams, size: Size): def steps(self, params: ImageParams, size: Size):
steps = 0 steps = 0
for callback, _params, _kwargs in self.stages: for callback, _params, kwargs in self.stages:
steps += callback.steps(params, size) steps += callback.steps(kwargs.get("params", params), size)
return steps return steps
def outputs(self, params: ImageParams, sources: int): def outputs(self, params: ImageParams, sources: int):
outputs = sources outputs = sources
for callback, _params, _kwargs in self.stages: for callback, _params, kwargs in self.stages:
outputs += callback.outputs(params, outputs) outputs += callback.outputs(kwargs.get("params", params), outputs)
return outputs return outputs

View File

@ -21,14 +21,14 @@ class BaseStage:
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> List[Image.Image]: ) -> List[Image.Image]:
raise NotImplementedError() raise NotImplementedError() # noqa
def steps( def steps(
self, self,
params: ImageParams, params: ImageParams,
size: Size, size: Size,
) -> int: ) -> int:
return 1 return 1 # noqa
def outputs( def outputs(
self, self,

View File

@ -18,7 +18,7 @@ from ..diffusers.run import (
) )
from ..diffusers.utils import replace_wildcards from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name from ..output import json_params, make_output_name
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams from ..params import Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline from ..transformers.run import run_txt2txt_pipeline
from ..utils import ( from ..utils import (
base_join, base_join,
@ -50,10 +50,11 @@ from .load import (
get_wildcard_data, get_wildcard_data,
) )
from .params import ( from .params import (
border_from_request, build_border,
highres_from_request, build_highres,
build_upscale,
pipeline_from_json,
pipeline_from_request, pipeline_from_request,
upscale_from_request,
) )
from .utils import wrap_route from .utils import wrap_route
@ -168,8 +169,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
size = Size(source.width, source.height) size = Size(source.width, source.height)
device, params, _size = pipeline_from_request(server, "img2img") device, params, _size = pipeline_from_request(server, "img2img")
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
source_filter = get_from_list( source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys()) request.args, "sourceFilter", list(get_source_filters().keys())
) )
@ -217,8 +218,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
def txt2img(server: ServerContext, pool: DevicePoolExecutor): def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "txt2img") device, params, size = pipeline_from_request(server, "txt2img")
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
@ -271,9 +272,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
) )
device, params, _size = pipeline_from_request(server, "inpaint") device, params, _size = pipeline_from_request(server, "inpaint")
expand = border_from_request() expand = build_border()
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
fill_color = get_not_empty(request.args, "fillColor", "white") fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
@ -341,8 +342,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server) device, params, size = pipeline_from_request(server)
upscale = upscale_from_request() upscale = build_upscale()
highres = highres_from_request() highres = build_highres()
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
@ -367,6 +368,10 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
# keys that are specially parsed by params and should not show up in with_args
CHAIN_POP_KEYS = ["model", "control"]
def chain(server: ServerContext, pool: DevicePoolExecutor): def chain(server: ServerContext, pool: DevicePoolExecutor):
if request.is_json: if request.is_json:
logger.debug("chain pipeline request with JSON body") logger.debug("chain pipeline request with JSON body")
@ -386,9 +391,8 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.debug("validating chain request: %s against %s", data, schema) logger.debug("validating chain request: %s against %s", data, schema)
validate(data, schema) validate(data, schema)
# get defaults from the regular parameters device, base_params, base_size = pipeline_from_json(
device, base_params, base_size = pipeline_from_request( server, data=data.get("defaults")
server, data=data.get("defaults", None)
) )
# start building the pipeline # start building the pipeline
@ -399,32 +403,32 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("request stage: %s, %s", stage_class.__name__, kwargs) logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
# TODO: combine base params with stage params # TODO: combine base params with stage params
_device, params, size = pipeline_from_request(server, data=kwargs) _device, params, size = pipeline_from_json(server, data=kwargs)
replace_wildcards(params, get_wildcard_data()) replace_wildcards(params, get_wildcard_data())
if "model" in kwargs: # remove parsed keys, like model names (which become paths)
kwargs.pop("model") for pop_key in CHAIN_POP_KEYS:
if pop_key in kwargs:
if "control" in kwargs: kwargs.pop(pop_key)
logger.warning("TODO: resolve controlnet model")
kwargs.pop("control")
# replace kwargs with parsed versions
kwargs["params"] = params kwargs["params"] = params
kwargs["size"] = size
border = build_border(kwargs)
kwargs["border"] = border
upscale = build_upscale(kwargs)
kwargs["upscale"] = upscale
# prepare the stage metadata
stage = StageParams( stage = StageParams(
stage_data.get("name", stage_class.__name__), stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")), tile_size=get_size(kwargs.get("tiles")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
) )
if "border" in kwargs: # load any images related to this stage
border = Border.even(int(kwargs.get("border")))
kwargs["border"] = border
if "upscale" in kwargs:
upscale = UpscaleParams(kwargs.get("upscale"))
kwargs["upscale"] = upscale
stage_source_name = "source:%s" % (stage.name) stage_source_name = "source:%s" % (stage.name)
stage_mask_name = "mask:%s" % (stage.name) stage_mask_name = "mask:%s" % (stage.name)
@ -494,7 +498,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
sources.append(source) sources.append(source)
device, params, size = pipeline_from_request(server) device, params, size = pipeline_from_request(server)
upscale = upscale_from_request() upscale = build_upscale()
output = make_output_name(server, "upscale", params, size) output = make_output_name(server, "upscale", params, size)
job_name = output[0] job_name = output[0]

View File

@ -51,7 +51,7 @@ class ModelCache:
return return
for i in range(len(cache)): for i in range(len(cache)):
t, k, v = cache[i] t, k, _v = cache[i]
if tag == t and key != k: if tag == t and key != k:
logger.debug("updating model cache: %s %s", tag, key) logger.debug("updating model cache: %s %s", tag, key)
cache[i] = (tag, key, value) cache[i] = (tag, key, value)

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Dict, Tuple from typing import Dict, Optional, Tuple
import numpy as np import numpy as np
from flask import request from flask import request
@ -34,16 +34,10 @@ from .utils import get_model_path
logger = getLogger(__name__) logger = getLogger(__name__)
def pipeline_from_request( def build_device(
server: ServerContext, server: ServerContext,
default_pipeline: str = "txt2img", data: Dict[str, str],
data: Dict[str, str] = None, ) -> Optional[DeviceParams]:
) -> Tuple[DeviceParams, ImageParams, Size]:
user = request.remote_addr
if data is None:
data = request.args
# platform stuff # platform stuff
device = None device = None
device_name = data.get("platform") device_name = data.get("platform")
@ -53,6 +47,14 @@ def pipeline_from_request(
if platform.device == device_name: if platform.device == device_name:
device = platform device = platform
return device
def build_params(
server: ServerContext,
default_pipeline: str,
data: Dict[str, str],
) -> ImageParams:
# diffusion model # diffusion model
model = get_not_empty(data, "model", get_config_value("model")) model = get_not_empty(data, "model", get_config_value("model"))
model_path = get_model_path(server, model) model_path = get_model_path(server, model)
@ -115,20 +117,6 @@ def pipeline_from_request(
get_config_value("steps", "max"), get_config_value("steps", "max"),
get_config_value("steps", "min"), get_config_value("steps", "min"),
) )
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE")) tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
tiles = get_and_clamp_int( tiles = get_and_clamp_int(
data, data,
@ -161,21 +149,6 @@ def pipeline_from_request(
# this one can safely use np.random because it produces a single value # this one can safely use np.random because it produces a single value
seed = np.random.randint(np.iinfo(np.int32).max) seed = np.random.randint(np.iinfo(np.int32).max)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
steps,
scheduler,
model_path,
pipeline,
device or "any device",
width,
height,
cfg,
seed,
prompt,
)
params = ImageParams( params = ImageParams(
model_path, model_path,
pipeline, pipeline,
@ -194,34 +167,60 @@ def pipeline_from_request(
overlap=overlap, overlap=overlap,
stride=stride, stride=stride,
) )
size = Size(width, height)
return (device, params, size) return params
def border_from_request() -> Border: def build_size(
server: ServerContext,
data: Dict[str, str],
) -> Size:
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
return Size(width, height)
def build_border(
data: Dict[str, str] = None,
) -> Border:
if data is None:
data = request.args
left = get_and_clamp_int( left = get_and_clamp_int(
request.args, data,
"left", "left",
get_config_value("left"), get_config_value("left"),
get_config_value("left", "max"), get_config_value("left", "max"),
get_config_value("left", "min"), get_config_value("left", "min"),
) )
right = get_and_clamp_int( right = get_and_clamp_int(
request.args, data,
"right", "right",
get_config_value("right"), get_config_value("right"),
get_config_value("right", "max"), get_config_value("right", "max"),
get_config_value("right", "min"), get_config_value("right", "min"),
) )
top = get_and_clamp_int( top = get_and_clamp_int(
request.args, data,
"top", "top",
get_config_value("top"), get_config_value("top"),
get_config_value("top", "max"), get_config_value("top", "max"),
get_config_value("top", "min"), get_config_value("top", "min"),
) )
bottom = get_and_clamp_int( bottom = get_and_clamp_int(
request.args, data,
"bottom", "bottom",
get_config_value("bottom"), get_config_value("bottom"),
get_config_value("bottom", "max"), get_config_value("bottom", "max"),
@ -231,46 +230,51 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom) return Border(left, right, top, bottom)
def upscale_from_request() -> UpscaleParams: def build_upscale(
data: Dict[str, str] = None,
) -> UpscaleParams:
if data is None:
data = request.args
denoise = get_and_clamp_float( denoise = get_and_clamp_float(
request.args, data,
"denoise", "denoise",
get_config_value("denoise"), get_config_value("denoise"),
get_config_value("denoise", "max"), get_config_value("denoise", "max"),
get_config_value("denoise", "min"), get_config_value("denoise", "min"),
) )
scale = get_and_clamp_int( scale = get_and_clamp_int(
request.args, data,
"scale", "scale",
get_config_value("scale"), get_config_value("scale"),
get_config_value("scale", "max"), get_config_value("scale", "max"),
get_config_value("scale", "min"), get_config_value("scale", "min"),
) )
outscale = get_and_clamp_int( outscale = get_and_clamp_int(
request.args, data,
"outscale", "outscale",
get_config_value("outscale"), get_config_value("outscale"),
get_config_value("outscale", "max"), get_config_value("outscale", "max"),
get_config_value("outscale", "min"), get_config_value("outscale", "min"),
) )
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models()) upscaling = get_from_list(data, "upscaling", get_upscaling_models())
correction = get_from_list(request.args, "correction", get_correction_models()) correction = get_from_list(data, "correction", get_correction_models())
faces = get_not_empty(request.args, "faces", "false") == "true" faces = get_not_empty(data, "faces", "false") == "true"
face_outscale = get_and_clamp_int( face_outscale = get_and_clamp_int(
request.args, data,
"faceOutscale", "faceOutscale",
get_config_value("faceOutscale"), get_config_value("faceOutscale"),
get_config_value("faceOutscale", "max"), get_config_value("faceOutscale", "max"),
get_config_value("faceOutscale", "min"), get_config_value("faceOutscale", "min"),
) )
face_strength = get_and_clamp_float( face_strength = get_and_clamp_float(
request.args, data,
"faceStrength", "faceStrength",
get_config_value("faceStrength"), get_config_value("faceStrength"),
get_config_value("faceStrength", "max"), get_config_value("faceStrength", "max"),
get_config_value("faceStrength", "min"), get_config_value("faceStrength", "min"),
) )
upscale_order = request.args.get("upscaleOrder", "correction-first") upscale_order = data.get("upscaleOrder", "correction-first")
return UpscaleParams( return UpscaleParams(
upscaling, upscaling,
@ -286,37 +290,43 @@ def upscale_from_request() -> UpscaleParams:
) )
def highres_from_request() -> HighresParams: def build_highres(
enabled = get_boolean(request.args, "highres", get_config_value("highres")) data: Dict[str, str] = None,
) -> HighresParams:
if data is None:
data = request.args
enabled = get_boolean(data, "highres", get_config_value("highres"))
iterations = get_and_clamp_int( iterations = get_and_clamp_int(
request.args, data,
"highresIterations", "highresIterations",
get_config_value("highresIterations"), get_config_value("highresIterations"),
get_config_value("highresIterations", "max"), get_config_value("highresIterations", "max"),
get_config_value("highresIterations", "min"), get_config_value("highresIterations", "min"),
) )
method = get_from_list(request.args, "highresMethod", get_highres_methods()) method = get_from_list(data, "highresMethod", get_highres_methods())
scale = get_and_clamp_int( scale = get_and_clamp_int(
request.args, data,
"highresScale", "highresScale",
get_config_value("highresScale"), get_config_value("highresScale"),
get_config_value("highresScale", "max"), get_config_value("highresScale", "max"),
get_config_value("highresScale", "min"), get_config_value("highresScale", "min"),
) )
steps = get_and_clamp_int( steps = get_and_clamp_int(
request.args, data,
"highresSteps", "highresSteps",
get_config_value("highresSteps"), get_config_value("highresSteps"),
get_config_value("highresSteps", "max"), get_config_value("highresSteps", "max"),
get_config_value("highresSteps", "min"), get_config_value("highresSteps", "min"),
) )
strength = get_and_clamp_float( strength = get_and_clamp_float(
request.args, data,
"highresStrength", "highresStrength",
get_config_value("highresStrength"), get_config_value("highresStrength"),
get_config_value("highresStrength", "max"), get_config_value("highresStrength", "max"),
get_config_value("highresStrength", "min"), get_config_value("highresStrength", "min"),
) )
return HighresParams( return HighresParams(
enabled, enabled,
scale, scale,
@ -325,3 +335,50 @@ def highres_from_request() -> HighresParams:
method=method, method=method,
iterations=iterations, iterations=iterations,
) )
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
def pipeline_from_json(
server: ServerContext,
data: Dict[str, str],
default_pipeline: str = "txt2img",
) -> PipelineParams:
"""
Like pipeline_from_request but expects a nested structure.
"""
device = build_device(server, data.get("device", data))
params = build_params(server, default_pipeline, data.get("params", data))
size = build_size(server, data.get("params", data))
return (device, params, size)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> PipelineParams:
user = request.remote_addr
device = build_device(server, request.args)
params = build_params(server, default_pipeline, request.args)
size = build_size(server, request.args)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
params.steps,
params.scheduler,
params.model_path,
params.pipeline,
device or "any device",
params.width,
params.height,
params.cfg,
params.seed,
params.prompt,
)
return (device, params, size)

View File

View File

@ -0,0 +1,18 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_linear import BlendLinearStage
class BlendLinearStageTests(unittest.TestCase):
def test_stage(self):
stage = BlendLinearStage()
sources = [
Image.new("RGB", (64, 64), "black"),
]
stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))

View File

@ -0,0 +1,42 @@
import unittest
from PIL import Image
from onnx_web.chain.tile import complete_tile
class TestCompleteTile(unittest.TestCase):
def test_with_complete_tile(self):
partial = Image.new("RGB", (64, 64))
output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64))
def test_with_partial_tile(self):
partial = Image.new("RGB", (64, 32))
output = complete_tile(partial, 64)
self.assertEqual(output.size, (64, 64))
def test_with_nothing(self):
output = complete_tile(None, 64)
self.assertIsNone(output)
class TestNeedsTile(unittest.TestCase):
def test_with_undersized(self):
pass
def test_with_oversized(self):
pass
def test_with_mixed(self):
pass
class TestTileGrads(unittest.TestCase):
def test_center_tile(self):
pass
def test_edge_tile(self):
pass

View File

@ -0,0 +1,13 @@
import unittest
from onnx_web.chain.upscale_highres import UpscaleHighresStage
from onnx_web.params import HighresParams, UpscaleParams
class UpscaleHighresStageTests(unittest.TestCase):
def test_empty(self):
stage = UpscaleHighresStage()
sources = []
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
self.assertEqual(len(result), 0)

View File

View File

@ -0,0 +1,12 @@
import unittest
from onnx_web.models.meta import NetworkModel
class NetworkModelTests(unittest.TestCase):
def test_json(self):
model = NetworkModel("test", "inversion")
json = model.tojson()
self.assertIn("name", json)
self.assertIn("type", json)

View File

@ -1,7 +1,9 @@
import unittest import unittest
from onnx_web.prompt.grammar import PromptPhrase from onnx_web.prompt.grammar import PromptPhrase
from onnx_web.prompt.parser import parse_prompt_onnx from onnx_web.prompt.parser import parse_prompt_onnx
class ParserTests(unittest.TestCase): class ParserTests(unittest.TestCase):
def test_single_word_phrase(self): def test_single_word_phrase(self):
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False) res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)

View File

@ -2,7 +2,8 @@ import unittest
from onnx_web.server.model_cache import ModelCache from onnx_web.server.model_cache import ModelCache
class TestStringMethods(unittest.TestCase):
class TestModelCache(unittest.TestCase):
def test_drop_existing(self): def test_drop_existing(self):
cache = ModelCache(10) cache = ModelCache(10)
cache.clear() cache.clear()
@ -32,3 +33,31 @@ class TestStringMethods(unittest.TestCase):
cache.set("foo", ("bar",), value) cache.set("foo", ("bar",), value)
self.assertGreater(cache.size, 0) self.assertGreater(cache.size, 0)
self.assertIs(cache.get("foo", ("bin",)), None) self.assertIs(cache.get("foo", ("bin",)), None)
"""
def test_set_existing(self):
cache = ModelCache(10)
cache.clear()
cache.set("foo", ("bar",), {
"value": 1,
})
value = {
"value": 2,
}
cache.set("foo", ("bar",), value)
self.assertIs(cache.get("foo", ("bar",)), value)
"""
def test_set_missing(self):
cache = ModelCache(10)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertIs(cache.get("foo", ("bar",)), value)
def test_set_zero(self):
cache = ModelCache(0)
cache.clear()
value = {}
cache.set("foo", ("bar",), value)
self.assertEqual(cache.size, 0)

View File

@ -2,6 +2,7 @@ import unittest
from onnx_web.params import Border, Size from onnx_web.params import Border, Size
class BorderTests(unittest.TestCase): class BorderTests(unittest.TestCase):
def test_json(self): def test_json(self):
border = Border.even(0) border = Border.even(0)

View File

@ -1,5 +1,6 @@
import unittest import unittest
# just to get CI happy # just to get CI happy
class ErrorTest(unittest.TestCase): class ErrorTest(unittest.TestCase):
def test(self): def test(self):

View File

View File

@ -0,0 +1,12 @@
import unittest
from onnx_web.server.context import ServerContext
from onnx_web.worker.pool import DevicePoolExecutor
class TestWorkerPool(unittest.TestCase):
def test_no_devices(self):
server = ServerContext()
pool = DevicePoolExecutor(server, [])
pool.start()
pool.join()