1
0
Fork 0
onnx-web/api/onnx_web/server/params.py

511 lines
13 KiB
Python
Raw Normal View History

from logging import getLogger
from typing import Any, Dict, Optional, Tuple, Union
from flask import request
2023-04-13 04:11:53 +00:00
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
2023-12-03 17:11:23 +00:00
from ..diffusers.utils import random_seed
2023-04-01 17:06:31 +00:00
from ..params import (
Border,
DeviceParams,
ExperimentalParams,
2023-04-01 17:06:31 +00:00
HighresParams,
ImageParams,
LatentSymmetryParams,
PromptEditingParams,
RequestParams,
2023-04-01 17:06:31 +00:00
Size,
UpscaleParams,
)
2023-04-28 18:56:36 +00:00
from ..utils import (
get_and_clamp_float,
get_and_clamp_int,
get_boolean,
get_from_list,
get_not_empty,
2024-02-17 19:51:42 +00:00
load_config_str,
2023-04-28 18:56:36 +00:00
)
2023-03-05 13:19:48 +00:00
from .context import ServerContext
from .load import (
2023-02-26 20:15:30 +00:00
get_available_platforms,
get_config_value,
get_correction_models,
get_highres_methods,
get_network_models,
2023-02-26 20:15:30 +00:00
get_upscaling_models,
)
2023-02-26 20:15:30 +00:00
from .utils import get_model_path
logger = getLogger(__name__)
def build_device(
2023-09-16 00:16:47 +00:00
_server: ServerContext,
data: Dict[str, str],
) -> Optional[DeviceParams]:
2023-09-11 02:21:57 +00:00
# platform stuff
device = None
device_name = data.get("platform")
if device_name is not None and device_name != "any":
for platform in get_available_platforms():
if platform.device == device_name:
device = platform
return device
def build_params(
server: ServerContext,
default_pipeline: str,
data: Dict[str, str],
) -> ImageParams:
# diffusion model
model = get_not_empty(data, "model", get_config_value("model"))
model_path = get_model_path(server, model)
2023-09-11 02:21:57 +00:00
control = None
control_name = data.get("control")
for network in get_network_models():
if network.name == control_name:
control = network
# pipeline stuff
pipeline = get_from_list(
data, "pipeline", get_available_pipelines(), default_pipeline
)
scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())
if scheduler is None:
scheduler = get_config_value("scheduler")
# prompt does not come from config
prompt = data.get("prompt", "")
negative_prompt = data.get("negativePrompt", None)
if negative_prompt is not None and negative_prompt.strip() == "":
negative_prompt = None
# image params
batch = get_and_clamp_int(
data,
"batch",
get_config_value("batch"),
get_config_value("batch", "max"),
get_config_value("batch", "min"),
)
cfg = get_and_clamp_float(
data,
"cfg",
get_config_value("cfg"),
get_config_value("cfg", "max"),
get_config_value("cfg", "min"),
)
eta = get_and_clamp_float(
data,
"eta",
get_config_value("eta"),
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
loopback = get_and_clamp_int(
data,
"loopback",
get_config_value("loopback"),
get_config_value("loopback", "max"),
get_config_value("loopback", "min"),
)
steps = get_and_clamp_int(
data,
"steps",
get_config_value("steps"),
get_config_value("steps", "max"),
get_config_value("steps", "min"),
)
tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae"))
thumbnail = get_boolean(data, "thumbnail", get_config_value("thumbnail"))
unet_overlap = get_and_clamp_float(
data,
"unet_overlap",
get_config_value("unet_overlap"),
get_config_value("unet_overlap", "max"),
get_config_value("unet_overlap", "min"),
)
unet_tile = get_and_clamp_int(
data,
"unet_tile",
get_config_value("unet_tile"),
get_config_value("unet_tile", "max"),
get_config_value("unet_tile", "min"),
)
vae_overlap = get_and_clamp_float(
data,
"vae_overlap",
get_config_value("vae_overlap"),
get_config_value("vae_overlap", "max"),
get_config_value("vae_overlap", "min"),
)
vae_tile = get_and_clamp_int(
data,
"vae_tile",
get_config_value("vae_tile"),
get_config_value("vae_tile", "max"),
get_config_value("vae_tile", "min"),
)
seed = int(data.get("seed", -1))
if seed == -1:
2023-12-03 17:11:23 +00:00
seed = random_seed()
params = ImageParams(
model_path,
pipeline,
scheduler,
prompt,
cfg,
steps,
seed,
eta=eta,
negative_prompt=negative_prompt,
batch=batch,
control=control,
2023-04-22 15:39:23 +00:00
loopback=loopback,
tiled_vae=tiled_vae,
unet_overlap=unet_overlap,
unet_tile=unet_tile,
vae_overlap=vae_overlap,
vae_tile=vae_tile,
thumbnail=thumbnail,
)
return params
def build_size(
2023-09-16 00:16:47 +00:00
_server: ServerContext,
data: Dict[str, str],
) -> Size:
height = get_and_clamp_int(
data,
"height",
get_config_value("height"),
get_config_value("height", "max"),
get_config_value("height", "min"),
)
width = get_and_clamp_int(
data,
"width",
get_config_value("width"),
get_config_value("width", "max"),
get_config_value("width", "min"),
)
return Size(width, height)
def build_border(
data: Dict[str, str] = None,
) -> Border:
if data is None:
data = request.args
left = get_and_clamp_int(
data,
"left",
get_config_value("left"),
get_config_value("left", "max"),
get_config_value("left", "min"),
)
right = get_and_clamp_int(
data,
"right",
get_config_value("right"),
get_config_value("right", "max"),
get_config_value("right", "min"),
)
top = get_and_clamp_int(
data,
"top",
get_config_value("top"),
get_config_value("top", "max"),
get_config_value("top", "min"),
)
bottom = get_and_clamp_int(
data,
"bottom",
get_config_value("bottom"),
get_config_value("bottom", "max"),
get_config_value("bottom", "min"),
)
return Border(left, right, top, bottom)
def build_upscale(
data: Dict[str, str] = None,
) -> UpscaleParams:
if data is None:
data = request.args
upscale = get_boolean(data, "upscale", False)
denoise = get_and_clamp_float(
data,
"denoise",
get_config_value("denoise"),
get_config_value("denoise", "max"),
get_config_value("denoise", "min"),
)
scale = get_and_clamp_int(
data,
"scale",
get_config_value("scale"),
get_config_value("scale", "max"),
get_config_value("scale", "min"),
)
outscale = get_and_clamp_int(
data,
"outscale",
get_config_value("outscale"),
get_config_value("outscale", "max"),
get_config_value("outscale", "min"),
)
upscaling = get_from_list(data, "upscaling", get_upscaling_models())
correction = get_from_list(data, "correction", get_correction_models())
faces = get_boolean(data, "faces", False)
face_outscale = get_and_clamp_int(
data,
"faceOutscale",
get_config_value("faceOutscale"),
get_config_value("faceOutscale", "max"),
get_config_value("faceOutscale", "min"),
)
face_strength = get_and_clamp_float(
data,
"faceStrength",
get_config_value("faceStrength"),
get_config_value("faceStrength", "max"),
get_config_value("faceStrength", "min"),
)
upscale_order = data.get("upscaleOrder", "correction-first")
return UpscaleParams(
upscaling,
correction_model=correction,
denoise=denoise,
upscale=upscale,
faces=faces,
face_outscale=face_outscale,
face_strength=face_strength,
outscale=outscale,
scale=scale,
upscale_order=upscale_order,
)
2023-04-01 16:26:10 +00:00
def build_highres(
data: Dict[str, str] = None,
) -> HighresParams:
if data is None:
data = request.args
enabled = get_boolean(data, "highres", get_config_value("highres"))
iterations = get_and_clamp_int(
data,
"highresIterations",
get_config_value("highresIterations"),
get_config_value("highresIterations", "max"),
get_config_value("highresIterations", "min"),
)
method = get_from_list(data, "highresMethod", get_highres_methods())
scale = get_and_clamp_int(
data,
"highresScale",
get_config_value("highresScale"),
get_config_value("highresScale", "max"),
get_config_value("highresScale", "min"),
)
steps = get_and_clamp_int(
data,
"highresSteps",
get_config_value("highresSteps"),
get_config_value("highresSteps", "max"),
get_config_value("highresSteps", "min"),
)
strength = get_and_clamp_float(
data,
"highresStrength",
get_config_value("highresStrength"),
get_config_value("highresStrength", "max"),
get_config_value("highresStrength", "min"),
)
2023-04-01 16:26:10 +00:00
return HighresParams(
enabled,
2023-04-01 16:26:10 +00:00
scale,
steps,
strength,
method=method,
iterations=iterations,
2023-04-01 16:26:10 +00:00
)
def build_latent_symmetry(
data: Dict[str, str] = None,
) -> LatentSymmetryParams:
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry"))
gradient_start = get_and_clamp_float(
data,
"gradientStart",
get_config_value("gradientStart"),
get_config_value("gradientStart", "max"),
get_config_value("gradientStart", "min"),
)
gradient_end = get_and_clamp_float(
data,
"gradientEnd",
get_config_value("gradientEnd"),
get_config_value("gradientEnd", "max"),
get_config_value("gradientEnd", "min"),
)
line_of_symmetry = get_and_clamp_float(
data,
"lineOfSymmetry",
get_config_value("lineOfSymmetry"),
get_config_value("lineOfSymmetry", "max"),
get_config_value("lineOfSymmetry", "min"),
)
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
def build_prompt_editing(
data: Dict[str, str] = None,
) -> Dict[str, str]:
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("promptEditing"))
prompt_filter = data.get("promptFilter", "")
remove_tokens = data.get("removeTokens", "")
add_suffix = data.get("addSuffix", "")
2024-02-24 17:06:32 +00:00
min_length = get_and_clamp_int(
data,
"minLength",
get_config_value("minLength"),
get_config_value("minLength", "max"),
get_config_value("minLength", "min"),
)
2024-02-24 17:06:32 +00:00
return PromptEditingParams(
enabled, prompt_filter, remove_tokens, add_suffix, min_length
)
def build_experimental(
data: Dict[str, str] = None,
) -> ExperimentalParams:
if data is None:
data = request.args
latent_symmetry_data = data.get("latentSymmetry", {})
latent_symmetry = build_latent_symmetry(latent_symmetry_data)
prompt_editing_data = data.get("promptEditing", {})
prompt_editing = build_prompt_editing(prompt_editing_data)
return ExperimentalParams(latent_symmetry, prompt_editing)
def is_json_request() -> bool:
return request.mimetype == "application/json"
def is_json_form_request() -> bool:
return request.mimetype == "multipart/form-data" and "json" in request.form
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
def pipeline_from_json(
server: ServerContext,
2023-12-26 14:06:16 +00:00
data: Dict[str, Union[str, Dict[str, str]]],
default_pipeline: str = "txt2img",
) -> PipelineParams:
device = build_device(server, data.get("device", data))
params = build_params(server, default_pipeline, data.get("params", data))
size = build_size(server, data.get("params", data))
user = request.remote_addr
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
params.steps,
params.scheduler,
2023-09-14 03:06:12 +00:00
params.model,
params.pipeline,
device or "any device",
2023-09-14 03:06:12 +00:00
size.width,
size.height,
params.cfg,
params.seed,
params.prompt,
)
return (device, params, size)
def get_request_data(key: str | None = None) -> Any:
if is_json_request():
json = request.json
elif is_json_form_request():
json = load_config_str(request.form.get("json"))
else:
json = None
if key is not None and json is not None:
json = json.get(key)
return json or request.args
def get_request_params(
server: ServerContext, default_pipeline: str = None
) -> RequestParams:
data = get_request_data()
2024-02-17 21:22:33 +00:00
device, params, size = pipeline_from_json(server, data, default_pipeline)
2024-02-17 21:24:47 +00:00
border = None
if "border" in data:
border = build_border(data["border"])
upscale = None
if "upscale" in data:
upscale = build_upscale(data["upscale"])
highres = None
if "highres" in data:
highres = build_highres(data["highres"])
experimental = None
if "experimental" in data:
experimental = build_experimental(data["experimental"])
return RequestParams(
device,
params,
size=size,
border=border,
upscale=upscale,
highres=highres,
experimental=experimental,
)