fix(api): make request parsing consistent between JSON and forms
This commit is contained in:
parent
8a5e211172
commit
a33c88e670
|
@ -87,15 +87,15 @@ class ChainPipeline:
|
|||
|
||||
def steps(self, params: ImageParams, size: Size):
|
||||
steps = 0
|
||||
for callback, _params, _kwargs in self.stages:
|
||||
steps += callback.steps(params, size)
|
||||
for callback, _params, kwargs in self.stages:
|
||||
steps += callback.steps(kwargs.get("params", params), size)
|
||||
|
||||
return steps
|
||||
|
||||
def outputs(self, params: ImageParams, sources: int):
|
||||
outputs = sources
|
||||
for callback, _params, _kwargs in self.stages:
|
||||
outputs += callback.outputs(params, outputs)
|
||||
for callback, _params, kwargs in self.stages:
|
||||
outputs += callback.outputs(kwargs.get("params", params), outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
|
|
@ -21,14 +21,14 @@ class BaseStage:
|
|||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError() # noqa
|
||||
|
||||
def steps(
|
||||
self,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
) -> int:
|
||||
return 1
|
||||
return 1 # noqa
|
||||
|
||||
def outputs(
|
||||
self,
|
||||
|
|
|
@ -18,7 +18,7 @@ from ..diffusers.run import (
|
|||
)
|
||||
from ..diffusers.utils import replace_wildcards
|
||||
from ..output import json_params, make_output_name
|
||||
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
|
||||
from ..params import Size, StageParams, TileOrder
|
||||
from ..transformers.run import run_txt2txt_pipeline
|
||||
from ..utils import (
|
||||
base_join,
|
||||
|
@ -50,10 +50,11 @@ from .load import (
|
|||
get_wildcard_data,
|
||||
)
|
||||
from .params import (
|
||||
border_from_request,
|
||||
highres_from_request,
|
||||
build_border,
|
||||
build_highres,
|
||||
build_upscale,
|
||||
pipeline_from_json,
|
||||
pipeline_from_request,
|
||||
upscale_from_request,
|
||||
)
|
||||
from .utils import wrap_route
|
||||
|
||||
|
@ -168,8 +169,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
size = Size(source.width, source.height)
|
||||
|
||||
device, params, _size = pipeline_from_request(server, "img2img")
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
upscale = build_upscale()
|
||||
highres = build_highres()
|
||||
source_filter = get_from_list(
|
||||
request.args, "sourceFilter", list(get_source_filters().keys())
|
||||
)
|
||||
|
@ -217,8 +218,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(server, "txt2img")
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
upscale = build_upscale()
|
||||
highres = build_highres()
|
||||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
|
@ -271,9 +272,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
|
||||
device, params, _size = pipeline_from_request(server, "inpaint")
|
||||
expand = border_from_request()
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
expand = build_border()
|
||||
upscale = build_upscale()
|
||||
highres = build_highres()
|
||||
|
||||
fill_color = get_not_empty(request.args, "fillColor", "white")
|
||||
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
|
||||
|
@ -341,8 +342,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
|
||||
device, params, size = pipeline_from_request(server)
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
upscale = build_upscale()
|
||||
highres = build_highres()
|
||||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
|
@ -367,6 +368,10 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||
|
||||
|
||||
# keys that are specially parsed by params and should not show up in with_args
|
||||
CHAIN_POP_KEYS = ["model", "control"]
|
||||
|
||||
|
||||
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||
if request.is_json:
|
||||
logger.debug("chain pipeline request with JSON body")
|
||||
|
@ -386,9 +391,8 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
logger.debug("validating chain request: %s against %s", data, schema)
|
||||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
device, base_params, base_size = pipeline_from_request(
|
||||
server, data=data.get("defaults", None)
|
||||
device, base_params, base_size = pipeline_from_json(
|
||||
server, data=data.get("defaults")
|
||||
)
|
||||
|
||||
# start building the pipeline
|
||||
|
@ -399,32 +403,32 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
||||
|
||||
# TODO: combine base params with stage params
|
||||
_device, params, size = pipeline_from_request(server, data=kwargs)
|
||||
_device, params, size = pipeline_from_json(server, data=kwargs)
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
if "model" in kwargs:
|
||||
kwargs.pop("model")
|
||||
|
||||
if "control" in kwargs:
|
||||
logger.warning("TODO: resolve controlnet model")
|
||||
kwargs.pop("control")
|
||||
# remove parsed keys, like model names (which become paths)
|
||||
for pop_key in CHAIN_POP_KEYS:
|
||||
if pop_key in kwargs:
|
||||
kwargs.pop(pop_key)
|
||||
|
||||
# replace kwargs with parsed versions
|
||||
kwargs["params"] = params
|
||||
kwargs["size"] = size
|
||||
|
||||
border = build_border(kwargs)
|
||||
kwargs["border"] = border
|
||||
|
||||
upscale = build_upscale(kwargs)
|
||||
kwargs["upscale"] = upscale
|
||||
|
||||
# prepare the stage metadata
|
||||
stage = StageParams(
|
||||
stage_data.get("name", stage_class.__name__),
|
||||
tile_size=get_size(kwargs.get("tile_size")),
|
||||
tile_size=get_size(kwargs.get("tiles")),
|
||||
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
||||
)
|
||||
|
||||
if "border" in kwargs:
|
||||
border = Border.even(int(kwargs.get("border")))
|
||||
kwargs["border"] = border
|
||||
|
||||
if "upscale" in kwargs:
|
||||
upscale = UpscaleParams(kwargs.get("upscale"))
|
||||
kwargs["upscale"] = upscale
|
||||
|
||||
# load any images related to this stage
|
||||
stage_source_name = "source:%s" % (stage.name)
|
||||
stage_mask_name = "mask:%s" % (stage.name)
|
||||
|
||||
|
@ -494,7 +498,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
|||
sources.append(source)
|
||||
|
||||
device, params, size = pipeline_from_request(server)
|
||||
upscale = upscale_from_request()
|
||||
upscale = build_upscale()
|
||||
|
||||
output = make_output_name(server, "upscale", params, size)
|
||||
job_name = output[0]
|
||||
|
|
|
@ -51,7 +51,7 @@ class ModelCache:
|
|||
return
|
||||
|
||||
for i in range(len(cache)):
|
||||
t, k, v = cache[i]
|
||||
t, k, _v = cache[i]
|
||||
if tag == t and key != k:
|
||||
logger.debug("updating model cache: %s %s", tag, key)
|
||||
cache[i] = (tag, key, value)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Dict, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from flask import request
|
||||
|
@ -34,16 +34,10 @@ from .utils import get_model_path
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def pipeline_from_request(
|
||||
def build_device(
|
||||
server: ServerContext,
|
||||
default_pipeline: str = "txt2img",
|
||||
data: Dict[str, str] = None,
|
||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
if data is None:
|
||||
data = request.args
|
||||
|
||||
data: Dict[str, str],
|
||||
) -> Optional[DeviceParams]:
|
||||
# platform stuff
|
||||
device = None
|
||||
device_name = data.get("platform")
|
||||
|
@ -53,6 +47,14 @@ def pipeline_from_request(
|
|||
if platform.device == device_name:
|
||||
device = platform
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def build_params(
|
||||
server: ServerContext,
|
||||
default_pipeline: str,
|
||||
data: Dict[str, str],
|
||||
) -> ImageParams:
|
||||
# diffusion model
|
||||
model = get_not_empty(data, "model", get_config_value("model"))
|
||||
model_path = get_model_path(server, model)
|
||||
|
@ -115,20 +117,6 @@ def pipeline_from_request(
|
|||
get_config_value("steps", "max"),
|
||||
get_config_value("steps", "min"),
|
||||
)
|
||||
height = get_and_clamp_int(
|
||||
data,
|
||||
"height",
|
||||
get_config_value("height"),
|
||||
get_config_value("height", "max"),
|
||||
get_config_value("height", "min"),
|
||||
)
|
||||
width = get_and_clamp_int(
|
||||
data,
|
||||
"width",
|
||||
get_config_value("width"),
|
||||
get_config_value("width", "max"),
|
||||
get_config_value("width", "min"),
|
||||
)
|
||||
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
|
||||
tiles = get_and_clamp_int(
|
||||
data,
|
||||
|
@ -161,21 +149,6 @@ def pipeline_from_request(
|
|||
# this one can safely use np.random because it produces a single value
|
||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||
|
||||
logger.info(
|
||||
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||
user,
|
||||
steps,
|
||||
scheduler,
|
||||
model_path,
|
||||
pipeline,
|
||||
device or "any device",
|
||||
width,
|
||||
height,
|
||||
cfg,
|
||||
seed,
|
||||
prompt,
|
||||
)
|
||||
|
||||
params = ImageParams(
|
||||
model_path,
|
||||
pipeline,
|
||||
|
@ -194,34 +167,60 @@ def pipeline_from_request(
|
|||
overlap=overlap,
|
||||
stride=stride,
|
||||
)
|
||||
size = Size(width, height)
|
||||
return (device, params, size)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def border_from_request() -> Border:
|
||||
def build_size(
|
||||
server: ServerContext,
|
||||
data: Dict[str, str],
|
||||
) -> Size:
|
||||
height = get_and_clamp_int(
|
||||
data,
|
||||
"height",
|
||||
get_config_value("height"),
|
||||
get_config_value("height", "max"),
|
||||
get_config_value("height", "min"),
|
||||
)
|
||||
width = get_and_clamp_int(
|
||||
data,
|
||||
"width",
|
||||
get_config_value("width"),
|
||||
get_config_value("width", "max"),
|
||||
get_config_value("width", "min"),
|
||||
)
|
||||
return Size(width, height)
|
||||
|
||||
|
||||
def build_border(
|
||||
data: Dict[str, str] = None,
|
||||
) -> Border:
|
||||
if data is None:
|
||||
data = request.args
|
||||
|
||||
left = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"left",
|
||||
get_config_value("left"),
|
||||
get_config_value("left", "max"),
|
||||
get_config_value("left", "min"),
|
||||
)
|
||||
right = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"right",
|
||||
get_config_value("right"),
|
||||
get_config_value("right", "max"),
|
||||
get_config_value("right", "min"),
|
||||
)
|
||||
top = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"top",
|
||||
get_config_value("top"),
|
||||
get_config_value("top", "max"),
|
||||
get_config_value("top", "min"),
|
||||
)
|
||||
bottom = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"bottom",
|
||||
get_config_value("bottom"),
|
||||
get_config_value("bottom", "max"),
|
||||
|
@ -231,46 +230,51 @@ def border_from_request() -> Border:
|
|||
return Border(left, right, top, bottom)
|
||||
|
||||
|
||||
def upscale_from_request() -> UpscaleParams:
|
||||
def build_upscale(
|
||||
data: Dict[str, str] = None,
|
||||
) -> UpscaleParams:
|
||||
if data is None:
|
||||
data = request.args
|
||||
|
||||
denoise = get_and_clamp_float(
|
||||
request.args,
|
||||
data,
|
||||
"denoise",
|
||||
get_config_value("denoise"),
|
||||
get_config_value("denoise", "max"),
|
||||
get_config_value("denoise", "min"),
|
||||
)
|
||||
scale = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"scale",
|
||||
get_config_value("scale"),
|
||||
get_config_value("scale", "max"),
|
||||
get_config_value("scale", "min"),
|
||||
)
|
||||
outscale = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"outscale",
|
||||
get_config_value("outscale"),
|
||||
get_config_value("outscale", "max"),
|
||||
get_config_value("outscale", "min"),
|
||||
)
|
||||
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
|
||||
correction = get_from_list(request.args, "correction", get_correction_models())
|
||||
faces = get_not_empty(request.args, "faces", "false") == "true"
|
||||
upscaling = get_from_list(data, "upscaling", get_upscaling_models())
|
||||
correction = get_from_list(data, "correction", get_correction_models())
|
||||
faces = get_not_empty(data, "faces", "false") == "true"
|
||||
face_outscale = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"faceOutscale",
|
||||
get_config_value("faceOutscale"),
|
||||
get_config_value("faceOutscale", "max"),
|
||||
get_config_value("faceOutscale", "min"),
|
||||
)
|
||||
face_strength = get_and_clamp_float(
|
||||
request.args,
|
||||
data,
|
||||
"faceStrength",
|
||||
get_config_value("faceStrength"),
|
||||
get_config_value("faceStrength", "max"),
|
||||
get_config_value("faceStrength", "min"),
|
||||
)
|
||||
upscale_order = request.args.get("upscaleOrder", "correction-first")
|
||||
upscale_order = data.get("upscaleOrder", "correction-first")
|
||||
|
||||
return UpscaleParams(
|
||||
upscaling,
|
||||
|
@ -286,37 +290,43 @@ def upscale_from_request() -> UpscaleParams:
|
|||
)
|
||||
|
||||
|
||||
def highres_from_request() -> HighresParams:
|
||||
enabled = get_boolean(request.args, "highres", get_config_value("highres"))
|
||||
def build_highres(
|
||||
data: Dict[str, str] = None,
|
||||
) -> HighresParams:
|
||||
if data is None:
|
||||
data = request.args
|
||||
|
||||
enabled = get_boolean(data, "highres", get_config_value("highres"))
|
||||
iterations = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"highresIterations",
|
||||
get_config_value("highresIterations"),
|
||||
get_config_value("highresIterations", "max"),
|
||||
get_config_value("highresIterations", "min"),
|
||||
)
|
||||
method = get_from_list(request.args, "highresMethod", get_highres_methods())
|
||||
method = get_from_list(data, "highresMethod", get_highres_methods())
|
||||
scale = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"highresScale",
|
||||
get_config_value("highresScale"),
|
||||
get_config_value("highresScale", "max"),
|
||||
get_config_value("highresScale", "min"),
|
||||
)
|
||||
steps = get_and_clamp_int(
|
||||
request.args,
|
||||
data,
|
||||
"highresSteps",
|
||||
get_config_value("highresSteps"),
|
||||
get_config_value("highresSteps", "max"),
|
||||
get_config_value("highresSteps", "min"),
|
||||
)
|
||||
strength = get_and_clamp_float(
|
||||
request.args,
|
||||
data,
|
||||
"highresStrength",
|
||||
get_config_value("highresStrength"),
|
||||
get_config_value("highresStrength", "max"),
|
||||
get_config_value("highresStrength", "min"),
|
||||
)
|
||||
|
||||
return HighresParams(
|
||||
enabled,
|
||||
scale,
|
||||
|
@ -325,3 +335,50 @@ def highres_from_request() -> HighresParams:
|
|||
method=method,
|
||||
iterations=iterations,
|
||||
)
|
||||
|
||||
|
||||
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
|
||||
|
||||
|
||||
def pipeline_from_json(
|
||||
server: ServerContext,
|
||||
data: Dict[str, str],
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> PipelineParams:
|
||||
"""
|
||||
Like pipeline_from_request but expects a nested structure.
|
||||
"""
|
||||
|
||||
device = build_device(server, data.get("device", data))
|
||||
params = build_params(server, default_pipeline, data.get("params", data))
|
||||
size = build_size(server, data.get("params", data))
|
||||
|
||||
return (device, params, size)
|
||||
|
||||
|
||||
def pipeline_from_request(
|
||||
server: ServerContext,
|
||||
default_pipeline: str = "txt2img",
|
||||
) -> PipelineParams:
|
||||
user = request.remote_addr
|
||||
|
||||
device = build_device(server, request.args)
|
||||
params = build_params(server, default_pipeline, request.args)
|
||||
size = build_size(server, request.args)
|
||||
|
||||
logger.info(
|
||||
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||
user,
|
||||
params.steps,
|
||||
params.scheduler,
|
||||
params.model_path,
|
||||
params.pipeline,
|
||||
device or "any device",
|
||||
params.width,
|
||||
params.height,
|
||||
params.cfg,
|
||||
params.seed,
|
||||
params.prompt,
|
||||
)
|
||||
|
||||
return (device, params, size)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_linear import BlendLinearStage
|
||||
|
||||
|
||||
class BlendLinearStageTests(unittest.TestCase):
|
||||
def test_stage(self):
|
||||
stage = BlendLinearStage()
|
||||
sources = [
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
]
|
||||
stage_source = Image.new("RGB", (64, 64), "white")
|
||||
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
|
|
@ -0,0 +1,42 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.tile import complete_tile
|
||||
|
||||
|
||||
class TestCompleteTile(unittest.TestCase):
|
||||
def test_with_complete_tile(self):
|
||||
partial = Image.new("RGB", (64, 64))
|
||||
output = complete_tile(partial, 64)
|
||||
|
||||
self.assertEqual(output.size, (64, 64))
|
||||
|
||||
def test_with_partial_tile(self):
|
||||
partial = Image.new("RGB", (64, 32))
|
||||
output = complete_tile(partial, 64)
|
||||
|
||||
self.assertEqual(output.size, (64, 64))
|
||||
|
||||
def test_with_nothing(self):
|
||||
output = complete_tile(None, 64)
|
||||
self.assertIsNone(output)
|
||||
|
||||
|
||||
class TestNeedsTile(unittest.TestCase):
|
||||
def test_with_undersized(self):
|
||||
pass
|
||||
|
||||
def test_with_oversized(self):
|
||||
pass
|
||||
|
||||
def test_with_mixed(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestTileGrads(unittest.TestCase):
|
||||
def test_center_tile(self):
|
||||
pass
|
||||
|
||||
def test_edge_tile(self):
|
||||
pass
|
|
@ -0,0 +1,13 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.upscale_highres import UpscaleHighresStage
|
||||
from onnx_web.params import HighresParams, UpscaleParams
|
||||
|
||||
|
||||
class UpscaleHighresStageTests(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
stage = UpscaleHighresStage()
|
||||
sources = []
|
||||
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,12 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.models.meta import NetworkModel
|
||||
|
||||
|
||||
class NetworkModelTests(unittest.TestCase):
|
||||
def test_json(self):
|
||||
model = NetworkModel("test", "inversion")
|
||||
json = model.tojson()
|
||||
|
||||
self.assertIn("name", json)
|
||||
self.assertIn("type", json)
|
|
@ -1,7 +1,9 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.prompt.grammar import PromptPhrase
|
||||
from onnx_web.prompt.parser import parse_prompt_onnx
|
||||
|
||||
|
||||
class ParserTests(unittest.TestCase):
|
||||
def test_single_word_phrase(self):
|
||||
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)
|
||||
|
|
|
@ -2,7 +2,8 @@ import unittest
|
|||
|
||||
from onnx_web.server.model_cache import ModelCache
|
||||
|
||||
class TestStringMethods(unittest.TestCase):
|
||||
|
||||
class TestModelCache(unittest.TestCase):
|
||||
def test_drop_existing(self):
|
||||
cache = ModelCache(10)
|
||||
cache.clear()
|
||||
|
@ -32,3 +33,31 @@ class TestStringMethods(unittest.TestCase):
|
|||
cache.set("foo", ("bar",), value)
|
||||
self.assertGreater(cache.size, 0)
|
||||
self.assertIs(cache.get("foo", ("bin",)), None)
|
||||
|
||||
"""
|
||||
def test_set_existing(self):
|
||||
cache = ModelCache(10)
|
||||
cache.clear()
|
||||
cache.set("foo", ("bar",), {
|
||||
"value": 1,
|
||||
})
|
||||
value = {
|
||||
"value": 2,
|
||||
}
|
||||
cache.set("foo", ("bar",), value)
|
||||
self.assertIs(cache.get("foo", ("bar",)), value)
|
||||
"""
|
||||
|
||||
def test_set_missing(self):
|
||||
cache = ModelCache(10)
|
||||
cache.clear()
|
||||
value = {}
|
||||
cache.set("foo", ("bar",), value)
|
||||
self.assertIs(cache.get("foo", ("bar",)), value)
|
||||
|
||||
def test_set_zero(self):
|
||||
cache = ModelCache(0)
|
||||
cache.clear()
|
||||
value = {}
|
||||
cache.set("foo", ("bar",), value)
|
||||
self.assertEqual(cache.size, 0)
|
||||
|
|
|
@ -2,6 +2,7 @@ import unittest
|
|||
|
||||
from onnx_web.params import Border, Size
|
||||
|
||||
|
||||
class BorderTests(unittest.TestCase):
|
||||
def test_json(self):
|
||||
border = Border.even(0)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import unittest
|
||||
|
||||
|
||||
# just to get CI happy
|
||||
class ErrorTest(unittest.TestCase):
|
||||
def test(self):
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.pool import DevicePoolExecutor
|
||||
|
||||
|
||||
class TestWorkerPool(unittest.TestCase):
|
||||
def test_no_devices(self):
|
||||
server = ServerContext()
|
||||
pool = DevicePoolExecutor(server, [])
|
||||
pool.start()
|
||||
pool.join()
|
Loading…
Reference in New Issue