fix(api): make request parsing consistent between JSON and forms
This commit is contained in:
parent
8a5e211172
commit
a33c88e670
|
@ -87,15 +87,15 @@ class ChainPipeline:
|
||||||
|
|
||||||
def steps(self, params: ImageParams, size: Size):
|
def steps(self, params: ImageParams, size: Size):
|
||||||
steps = 0
|
steps = 0
|
||||||
for callback, _params, _kwargs in self.stages:
|
for callback, _params, kwargs in self.stages:
|
||||||
steps += callback.steps(params, size)
|
steps += callback.steps(kwargs.get("params", params), size)
|
||||||
|
|
||||||
return steps
|
return steps
|
||||||
|
|
||||||
def outputs(self, params: ImageParams, sources: int):
|
def outputs(self, params: ImageParams, sources: int):
|
||||||
outputs = sources
|
outputs = sources
|
||||||
for callback, _params, _kwargs in self.stages:
|
for callback, _params, kwargs in self.stages:
|
||||||
outputs += callback.outputs(params, outputs)
|
outputs += callback.outputs(kwargs.get("params", params), outputs)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
|
@ -21,14 +21,14 @@ class BaseStage:
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> List[Image.Image]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError() # noqa
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
) -> int:
|
) -> int:
|
||||||
return 1
|
return 1 # noqa
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -18,7 +18,7 @@ from ..diffusers.run import (
|
||||||
)
|
)
|
||||||
from ..diffusers.utils import replace_wildcards
|
from ..diffusers.utils import replace_wildcards
|
||||||
from ..output import json_params, make_output_name
|
from ..output import json_params, make_output_name
|
||||||
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
|
from ..params import Size, StageParams, TileOrder
|
||||||
from ..transformers.run import run_txt2txt_pipeline
|
from ..transformers.run import run_txt2txt_pipeline
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
base_join,
|
base_join,
|
||||||
|
@ -50,10 +50,11 @@ from .load import (
|
||||||
get_wildcard_data,
|
get_wildcard_data,
|
||||||
)
|
)
|
||||||
from .params import (
|
from .params import (
|
||||||
border_from_request,
|
build_border,
|
||||||
highres_from_request,
|
build_highres,
|
||||||
|
build_upscale,
|
||||||
|
pipeline_from_json,
|
||||||
pipeline_from_request,
|
pipeline_from_request,
|
||||||
upscale_from_request,
|
|
||||||
)
|
)
|
||||||
from .utils import wrap_route
|
from .utils import wrap_route
|
||||||
|
|
||||||
|
@ -168,8 +169,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
size = Size(source.width, source.height)
|
size = Size(source.width, source.height)
|
||||||
|
|
||||||
device, params, _size = pipeline_from_request(server, "img2img")
|
device, params, _size = pipeline_from_request(server, "img2img")
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
source_filter = get_from_list(
|
source_filter = get_from_list(
|
||||||
request.args, "sourceFilter", list(get_source_filters().keys())
|
request.args, "sourceFilter", list(get_source_filters().keys())
|
||||||
)
|
)
|
||||||
|
@ -217,8 +218,8 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(server, "txt2img")
|
device, params, size = pipeline_from_request(server, "txt2img")
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
|
@ -271,9 +272,9 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
|
|
||||||
device, params, _size = pipeline_from_request(server, "inpaint")
|
device, params, _size = pipeline_from_request(server, "inpaint")
|
||||||
expand = border_from_request()
|
expand = build_border()
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
|
|
||||||
fill_color = get_not_empty(request.args, "fillColor", "white")
|
fill_color = get_not_empty(request.args, "fillColor", "white")
|
||||||
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
|
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
|
||||||
|
@ -341,8 +342,8 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
highres = highres_from_request()
|
highres = build_highres()
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
|
@ -367,6 +368,10 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
return jsonify(json_params(output, params, size, upscale=upscale, highres=highres))
|
||||||
|
|
||||||
|
|
||||||
|
# keys that are specially parsed by params and should not show up in with_args
|
||||||
|
CHAIN_POP_KEYS = ["model", "control"]
|
||||||
|
|
||||||
|
|
||||||
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
if request.is_json:
|
if request.is_json:
|
||||||
logger.debug("chain pipeline request with JSON body")
|
logger.debug("chain pipeline request with JSON body")
|
||||||
|
@ -386,9 +391,8 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
logger.debug("validating chain request: %s against %s", data, schema)
|
logger.debug("validating chain request: %s against %s", data, schema)
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
# get defaults from the regular parameters
|
device, base_params, base_size = pipeline_from_json(
|
||||||
device, base_params, base_size = pipeline_from_request(
|
server, data=data.get("defaults")
|
||||||
server, data=data.get("defaults", None)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# start building the pipeline
|
# start building the pipeline
|
||||||
|
@ -399,32 +403,32 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
||||||
|
|
||||||
# TODO: combine base params with stage params
|
# TODO: combine base params with stage params
|
||||||
_device, params, size = pipeline_from_request(server, data=kwargs)
|
_device, params, size = pipeline_from_json(server, data=kwargs)
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
if "model" in kwargs:
|
# remove parsed keys, like model names (which become paths)
|
||||||
kwargs.pop("model")
|
for pop_key in CHAIN_POP_KEYS:
|
||||||
|
if pop_key in kwargs:
|
||||||
if "control" in kwargs:
|
kwargs.pop(pop_key)
|
||||||
logger.warning("TODO: resolve controlnet model")
|
|
||||||
kwargs.pop("control")
|
|
||||||
|
|
||||||
|
# replace kwargs with parsed versions
|
||||||
kwargs["params"] = params
|
kwargs["params"] = params
|
||||||
|
kwargs["size"] = size
|
||||||
|
|
||||||
|
border = build_border(kwargs)
|
||||||
|
kwargs["border"] = border
|
||||||
|
|
||||||
|
upscale = build_upscale(kwargs)
|
||||||
|
kwargs["upscale"] = upscale
|
||||||
|
|
||||||
|
# prepare the stage metadata
|
||||||
stage = StageParams(
|
stage = StageParams(
|
||||||
stage_data.get("name", stage_class.__name__),
|
stage_data.get("name", stage_class.__name__),
|
||||||
tile_size=get_size(kwargs.get("tile_size")),
|
tile_size=get_size(kwargs.get("tiles")),
|
||||||
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "border" in kwargs:
|
# load any images related to this stage
|
||||||
border = Border.even(int(kwargs.get("border")))
|
|
||||||
kwargs["border"] = border
|
|
||||||
|
|
||||||
if "upscale" in kwargs:
|
|
||||||
upscale = UpscaleParams(kwargs.get("upscale"))
|
|
||||||
kwargs["upscale"] = upscale
|
|
||||||
|
|
||||||
stage_source_name = "source:%s" % (stage.name)
|
stage_source_name = "source:%s" % (stage.name)
|
||||||
stage_mask_name = "mask:%s" % (stage.name)
|
stage_mask_name = "mask:%s" % (stage.name)
|
||||||
|
|
||||||
|
@ -494,7 +498,7 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
sources.append(source)
|
sources.append(source)
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = upscale_from_request()
|
upscale = build_upscale()
|
||||||
|
|
||||||
output = make_output_name(server, "upscale", params, size)
|
output = make_output_name(server, "upscale", params, size)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
|
|
|
@ -51,7 +51,7 @@ class ModelCache:
|
||||||
return
|
return
|
||||||
|
|
||||||
for i in range(len(cache)):
|
for i in range(len(cache)):
|
||||||
t, k, v = cache[i]
|
t, k, _v = cache[i]
|
||||||
if tag == t and key != k:
|
if tag == t and key != k:
|
||||||
logger.debug("updating model cache: %s %s", tag, key)
|
logger.debug("updating model cache: %s %s", tag, key)
|
||||||
cache[i] = (tag, key, value)
|
cache[i] = (tag, key, value)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from flask import request
|
from flask import request
|
||||||
|
@ -34,16 +34,10 @@ from .utils import get_model_path
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request(
|
def build_device(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
default_pipeline: str = "txt2img",
|
data: Dict[str, str],
|
||||||
data: Dict[str, str] = None,
|
) -> Optional[DeviceParams]:
|
||||||
) -> Tuple[DeviceParams, ImageParams, Size]:
|
|
||||||
user = request.remote_addr
|
|
||||||
|
|
||||||
if data is None:
|
|
||||||
data = request.args
|
|
||||||
|
|
||||||
# platform stuff
|
# platform stuff
|
||||||
device = None
|
device = None
|
||||||
device_name = data.get("platform")
|
device_name = data.get("platform")
|
||||||
|
@ -53,6 +47,14 @@ def pipeline_from_request(
|
||||||
if platform.device == device_name:
|
if platform.device == device_name:
|
||||||
device = platform
|
device = platform
|
||||||
|
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def build_params(
|
||||||
|
server: ServerContext,
|
||||||
|
default_pipeline: str,
|
||||||
|
data: Dict[str, str],
|
||||||
|
) -> ImageParams:
|
||||||
# diffusion model
|
# diffusion model
|
||||||
model = get_not_empty(data, "model", get_config_value("model"))
|
model = get_not_empty(data, "model", get_config_value("model"))
|
||||||
model_path = get_model_path(server, model)
|
model_path = get_model_path(server, model)
|
||||||
|
@ -115,20 +117,6 @@ def pipeline_from_request(
|
||||||
get_config_value("steps", "max"),
|
get_config_value("steps", "max"),
|
||||||
get_config_value("steps", "min"),
|
get_config_value("steps", "min"),
|
||||||
)
|
)
|
||||||
height = get_and_clamp_int(
|
|
||||||
data,
|
|
||||||
"height",
|
|
||||||
get_config_value("height"),
|
|
||||||
get_config_value("height", "max"),
|
|
||||||
get_config_value("height", "min"),
|
|
||||||
)
|
|
||||||
width = get_and_clamp_int(
|
|
||||||
data,
|
|
||||||
"width",
|
|
||||||
get_config_value("width"),
|
|
||||||
get_config_value("width", "max"),
|
|
||||||
get_config_value("width", "min"),
|
|
||||||
)
|
|
||||||
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
|
tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
|
||||||
tiles = get_and_clamp_int(
|
tiles = get_and_clamp_int(
|
||||||
data,
|
data,
|
||||||
|
@ -161,21 +149,6 @@ def pipeline_from_request(
|
||||||
# this one can safely use np.random because it produces a single value
|
# this one can safely use np.random because it produces a single value
|
||||||
seed = np.random.randint(np.iinfo(np.int32).max)
|
seed = np.random.randint(np.iinfo(np.int32).max)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
|
||||||
user,
|
|
||||||
steps,
|
|
||||||
scheduler,
|
|
||||||
model_path,
|
|
||||||
pipeline,
|
|
||||||
device or "any device",
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
cfg,
|
|
||||||
seed,
|
|
||||||
prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
params = ImageParams(
|
params = ImageParams(
|
||||||
model_path,
|
model_path,
|
||||||
pipeline,
|
pipeline,
|
||||||
|
@ -194,34 +167,60 @@ def pipeline_from_request(
|
||||||
overlap=overlap,
|
overlap=overlap,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
)
|
)
|
||||||
size = Size(width, height)
|
|
||||||
return (device, params, size)
|
return params
|
||||||
|
|
||||||
|
|
||||||
def border_from_request() -> Border:
|
def build_size(
|
||||||
|
server: ServerContext,
|
||||||
|
data: Dict[str, str],
|
||||||
|
) -> Size:
|
||||||
|
height = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"height",
|
||||||
|
get_config_value("height"),
|
||||||
|
get_config_value("height", "max"),
|
||||||
|
get_config_value("height", "min"),
|
||||||
|
)
|
||||||
|
width = get_and_clamp_int(
|
||||||
|
data,
|
||||||
|
"width",
|
||||||
|
get_config_value("width"),
|
||||||
|
get_config_value("width", "max"),
|
||||||
|
get_config_value("width", "min"),
|
||||||
|
)
|
||||||
|
return Size(width, height)
|
||||||
|
|
||||||
|
|
||||||
|
def build_border(
|
||||||
|
data: Dict[str, str] = None,
|
||||||
|
) -> Border:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
left = get_and_clamp_int(
|
left = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"left",
|
"left",
|
||||||
get_config_value("left"),
|
get_config_value("left"),
|
||||||
get_config_value("left", "max"),
|
get_config_value("left", "max"),
|
||||||
get_config_value("left", "min"),
|
get_config_value("left", "min"),
|
||||||
)
|
)
|
||||||
right = get_and_clamp_int(
|
right = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"right",
|
"right",
|
||||||
get_config_value("right"),
|
get_config_value("right"),
|
||||||
get_config_value("right", "max"),
|
get_config_value("right", "max"),
|
||||||
get_config_value("right", "min"),
|
get_config_value("right", "min"),
|
||||||
)
|
)
|
||||||
top = get_and_clamp_int(
|
top = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"top",
|
"top",
|
||||||
get_config_value("top"),
|
get_config_value("top"),
|
||||||
get_config_value("top", "max"),
|
get_config_value("top", "max"),
|
||||||
get_config_value("top", "min"),
|
get_config_value("top", "min"),
|
||||||
)
|
)
|
||||||
bottom = get_and_clamp_int(
|
bottom = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"bottom",
|
"bottom",
|
||||||
get_config_value("bottom"),
|
get_config_value("bottom"),
|
||||||
get_config_value("bottom", "max"),
|
get_config_value("bottom", "max"),
|
||||||
|
@ -231,46 +230,51 @@ def border_from_request() -> Border:
|
||||||
return Border(left, right, top, bottom)
|
return Border(left, right, top, bottom)
|
||||||
|
|
||||||
|
|
||||||
def upscale_from_request() -> UpscaleParams:
|
def build_upscale(
|
||||||
|
data: Dict[str, str] = None,
|
||||||
|
) -> UpscaleParams:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
denoise = get_and_clamp_float(
|
denoise = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"denoise",
|
"denoise",
|
||||||
get_config_value("denoise"),
|
get_config_value("denoise"),
|
||||||
get_config_value("denoise", "max"),
|
get_config_value("denoise", "max"),
|
||||||
get_config_value("denoise", "min"),
|
get_config_value("denoise", "min"),
|
||||||
)
|
)
|
||||||
scale = get_and_clamp_int(
|
scale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"scale",
|
"scale",
|
||||||
get_config_value("scale"),
|
get_config_value("scale"),
|
||||||
get_config_value("scale", "max"),
|
get_config_value("scale", "max"),
|
||||||
get_config_value("scale", "min"),
|
get_config_value("scale", "min"),
|
||||||
)
|
)
|
||||||
outscale = get_and_clamp_int(
|
outscale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"outscale",
|
"outscale",
|
||||||
get_config_value("outscale"),
|
get_config_value("outscale"),
|
||||||
get_config_value("outscale", "max"),
|
get_config_value("outscale", "max"),
|
||||||
get_config_value("outscale", "min"),
|
get_config_value("outscale", "min"),
|
||||||
)
|
)
|
||||||
upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
|
upscaling = get_from_list(data, "upscaling", get_upscaling_models())
|
||||||
correction = get_from_list(request.args, "correction", get_correction_models())
|
correction = get_from_list(data, "correction", get_correction_models())
|
||||||
faces = get_not_empty(request.args, "faces", "false") == "true"
|
faces = get_not_empty(data, "faces", "false") == "true"
|
||||||
face_outscale = get_and_clamp_int(
|
face_outscale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"faceOutscale",
|
"faceOutscale",
|
||||||
get_config_value("faceOutscale"),
|
get_config_value("faceOutscale"),
|
||||||
get_config_value("faceOutscale", "max"),
|
get_config_value("faceOutscale", "max"),
|
||||||
get_config_value("faceOutscale", "min"),
|
get_config_value("faceOutscale", "min"),
|
||||||
)
|
)
|
||||||
face_strength = get_and_clamp_float(
|
face_strength = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"faceStrength",
|
"faceStrength",
|
||||||
get_config_value("faceStrength"),
|
get_config_value("faceStrength"),
|
||||||
get_config_value("faceStrength", "max"),
|
get_config_value("faceStrength", "max"),
|
||||||
get_config_value("faceStrength", "min"),
|
get_config_value("faceStrength", "min"),
|
||||||
)
|
)
|
||||||
upscale_order = request.args.get("upscaleOrder", "correction-first")
|
upscale_order = data.get("upscaleOrder", "correction-first")
|
||||||
|
|
||||||
return UpscaleParams(
|
return UpscaleParams(
|
||||||
upscaling,
|
upscaling,
|
||||||
|
@ -286,37 +290,43 @@ def upscale_from_request() -> UpscaleParams:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def highres_from_request() -> HighresParams:
|
def build_highres(
|
||||||
enabled = get_boolean(request.args, "highres", get_config_value("highres"))
|
data: Dict[str, str] = None,
|
||||||
|
) -> HighresParams:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
|
enabled = get_boolean(data, "highres", get_config_value("highres"))
|
||||||
iterations = get_and_clamp_int(
|
iterations = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"highresIterations",
|
"highresIterations",
|
||||||
get_config_value("highresIterations"),
|
get_config_value("highresIterations"),
|
||||||
get_config_value("highresIterations", "max"),
|
get_config_value("highresIterations", "max"),
|
||||||
get_config_value("highresIterations", "min"),
|
get_config_value("highresIterations", "min"),
|
||||||
)
|
)
|
||||||
method = get_from_list(request.args, "highresMethod", get_highres_methods())
|
method = get_from_list(data, "highresMethod", get_highres_methods())
|
||||||
scale = get_and_clamp_int(
|
scale = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"highresScale",
|
"highresScale",
|
||||||
get_config_value("highresScale"),
|
get_config_value("highresScale"),
|
||||||
get_config_value("highresScale", "max"),
|
get_config_value("highresScale", "max"),
|
||||||
get_config_value("highresScale", "min"),
|
get_config_value("highresScale", "min"),
|
||||||
)
|
)
|
||||||
steps = get_and_clamp_int(
|
steps = get_and_clamp_int(
|
||||||
request.args,
|
data,
|
||||||
"highresSteps",
|
"highresSteps",
|
||||||
get_config_value("highresSteps"),
|
get_config_value("highresSteps"),
|
||||||
get_config_value("highresSteps", "max"),
|
get_config_value("highresSteps", "max"),
|
||||||
get_config_value("highresSteps", "min"),
|
get_config_value("highresSteps", "min"),
|
||||||
)
|
)
|
||||||
strength = get_and_clamp_float(
|
strength = get_and_clamp_float(
|
||||||
request.args,
|
data,
|
||||||
"highresStrength",
|
"highresStrength",
|
||||||
get_config_value("highresStrength"),
|
get_config_value("highresStrength"),
|
||||||
get_config_value("highresStrength", "max"),
|
get_config_value("highresStrength", "max"),
|
||||||
get_config_value("highresStrength", "min"),
|
get_config_value("highresStrength", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return HighresParams(
|
return HighresParams(
|
||||||
enabled,
|
enabled,
|
||||||
scale,
|
scale,
|
||||||
|
@ -325,3 +335,50 @@ def highres_from_request() -> HighresParams:
|
||||||
method=method,
|
method=method,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_from_json(
|
||||||
|
server: ServerContext,
|
||||||
|
data: Dict[str, str],
|
||||||
|
default_pipeline: str = "txt2img",
|
||||||
|
) -> PipelineParams:
|
||||||
|
"""
|
||||||
|
Like pipeline_from_request but expects a nested structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device = build_device(server, data.get("device", data))
|
||||||
|
params = build_params(server, default_pipeline, data.get("params", data))
|
||||||
|
size = build_size(server, data.get("params", data))
|
||||||
|
|
||||||
|
return (device, params, size)
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_from_request(
|
||||||
|
server: ServerContext,
|
||||||
|
default_pipeline: str = "txt2img",
|
||||||
|
) -> PipelineParams:
|
||||||
|
user = request.remote_addr
|
||||||
|
|
||||||
|
device = build_device(server, request.args)
|
||||||
|
params = build_params(server, default_pipeline, request.args)
|
||||||
|
size = build_size(server, request.args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||||
|
user,
|
||||||
|
params.steps,
|
||||||
|
params.scheduler,
|
||||||
|
params.model_path,
|
||||||
|
params.pipeline,
|
||||||
|
device or "any device",
|
||||||
|
params.width,
|
||||||
|
params.height,
|
||||||
|
params.cfg,
|
||||||
|
params.seed,
|
||||||
|
params.prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (device, params, size)
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.blend_linear import BlendLinearStage
|
||||||
|
|
||||||
|
|
||||||
|
class BlendLinearStageTests(unittest.TestCase):
|
||||||
|
def test_stage(self):
|
||||||
|
stage = BlendLinearStage()
|
||||||
|
sources = [
|
||||||
|
Image.new("RGB", (64, 64), "black"),
|
||||||
|
]
|
||||||
|
stage_source = Image.new("RGB", (64, 64), "white")
|
||||||
|
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
|
|
@ -0,0 +1,42 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.tile import complete_tile
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompleteTile(unittest.TestCase):
|
||||||
|
def test_with_complete_tile(self):
|
||||||
|
partial = Image.new("RGB", (64, 64))
|
||||||
|
output = complete_tile(partial, 64)
|
||||||
|
|
||||||
|
self.assertEqual(output.size, (64, 64))
|
||||||
|
|
||||||
|
def test_with_partial_tile(self):
|
||||||
|
partial = Image.new("RGB", (64, 32))
|
||||||
|
output = complete_tile(partial, 64)
|
||||||
|
|
||||||
|
self.assertEqual(output.size, (64, 64))
|
||||||
|
|
||||||
|
def test_with_nothing(self):
|
||||||
|
output = complete_tile(None, 64)
|
||||||
|
self.assertIsNone(output)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNeedsTile(unittest.TestCase):
|
||||||
|
def test_with_undersized(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_with_oversized(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_with_mixed(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestTileGrads(unittest.TestCase):
|
||||||
|
def test_center_tile(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_edge_tile(self):
|
||||||
|
pass
|
|
@ -0,0 +1,13 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.upscale_highres import UpscaleHighresStage
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleHighresStageTests(unittest.TestCase):
|
||||||
|
def test_empty(self):
|
||||||
|
stage = UpscaleHighresStage()
|
||||||
|
sources = []
|
||||||
|
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,12 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.models.meta import NetworkModel
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModelTests(unittest.TestCase):
|
||||||
|
def test_json(self):
|
||||||
|
model = NetworkModel("test", "inversion")
|
||||||
|
json = model.tojson()
|
||||||
|
|
||||||
|
self.assertIn("name", json)
|
||||||
|
self.assertIn("type", json)
|
|
@ -1,7 +1,9 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from onnx_web.prompt.grammar import PromptPhrase
|
from onnx_web.prompt.grammar import PromptPhrase
|
||||||
from onnx_web.prompt.parser import parse_prompt_onnx
|
from onnx_web.prompt.parser import parse_prompt_onnx
|
||||||
|
|
||||||
|
|
||||||
class ParserTests(unittest.TestCase):
|
class ParserTests(unittest.TestCase):
|
||||||
def test_single_word_phrase(self):
|
def test_single_word_phrase(self):
|
||||||
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)
|
res = parse_prompt_onnx(None, "foo (bar) bin", debug=False)
|
||||||
|
|
|
@ -2,7 +2,8 @@ import unittest
|
||||||
|
|
||||||
from onnx_web.server.model_cache import ModelCache
|
from onnx_web.server.model_cache import ModelCache
|
||||||
|
|
||||||
class TestStringMethods(unittest.TestCase):
|
|
||||||
|
class TestModelCache(unittest.TestCase):
|
||||||
def test_drop_existing(self):
|
def test_drop_existing(self):
|
||||||
cache = ModelCache(10)
|
cache = ModelCache(10)
|
||||||
cache.clear()
|
cache.clear()
|
||||||
|
@ -32,3 +33,31 @@ class TestStringMethods(unittest.TestCase):
|
||||||
cache.set("foo", ("bar",), value)
|
cache.set("foo", ("bar",), value)
|
||||||
self.assertGreater(cache.size, 0)
|
self.assertGreater(cache.size, 0)
|
||||||
self.assertIs(cache.get("foo", ("bin",)), None)
|
self.assertIs(cache.get("foo", ("bin",)), None)
|
||||||
|
|
||||||
|
"""
|
||||||
|
def test_set_existing(self):
|
||||||
|
cache = ModelCache(10)
|
||||||
|
cache.clear()
|
||||||
|
cache.set("foo", ("bar",), {
|
||||||
|
"value": 1,
|
||||||
|
})
|
||||||
|
value = {
|
||||||
|
"value": 2,
|
||||||
|
}
|
||||||
|
cache.set("foo", ("bar",), value)
|
||||||
|
self.assertIs(cache.get("foo", ("bar",)), value)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_set_missing(self):
|
||||||
|
cache = ModelCache(10)
|
||||||
|
cache.clear()
|
||||||
|
value = {}
|
||||||
|
cache.set("foo", ("bar",), value)
|
||||||
|
self.assertIs(cache.get("foo", ("bar",)), value)
|
||||||
|
|
||||||
|
def test_set_zero(self):
|
||||||
|
cache = ModelCache(0)
|
||||||
|
cache.clear()
|
||||||
|
value = {}
|
||||||
|
cache.set("foo", ("bar",), value)
|
||||||
|
self.assertEqual(cache.size, 0)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import unittest
|
||||||
|
|
||||||
from onnx_web.params import Border, Size
|
from onnx_web.params import Border, Size
|
||||||
|
|
||||||
|
|
||||||
class BorderTests(unittest.TestCase):
|
class BorderTests(unittest.TestCase):
|
||||||
def test_json(self):
|
def test_json(self):
|
||||||
border = Border.even(0)
|
border = Border.even(0)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
# just to get CI happy
|
# just to get CI happy
|
||||||
class ErrorTest(unittest.TestCase):
|
class ErrorTest(unittest.TestCase):
|
||||||
def test(self):
|
def test(self):
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.pool import DevicePoolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkerPool(unittest.TestCase):
|
||||||
|
def test_no_devices(self):
|
||||||
|
server = ServerContext()
|
||||||
|
pool = DevicePoolExecutor(server, [])
|
||||||
|
pool.start()
|
||||||
|
pool.join()
|
Loading…
Reference in New Issue