diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 63f2cb95..675ae0ec 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -8,6 +8,7 @@ from ..chain import ( BlendImg2ImgStage, BlendMaskStage, ChainPipeline, + EditSafetyStage, SourceTxt2ImgStage, TextPromptStage, UpscaleOutpaintStage, @@ -19,11 +20,12 @@ from ..image import expand_image from ..output import make_output_names, read_metadata, save_image, save_result from ..params import ( Border, + ExperimentalParams, HighresParams, ImageParams, + RequestParams, Size, StageParams, - UpscaleParams, ) from ..server import ServerContext from ..server.load import get_source_filters @@ -57,39 +59,49 @@ def add_safety_stage( pipeline: ChainPipeline, ) -> None: if server.has_feature("horde-safety"): - from ..chain.edit_safety import EditSafetyStage - pipeline.stage( EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile) ) +def add_prompt_filter( + server: ServerContext, + pipeline: ChainPipeline, + experimental: Optional[ExperimentalParams] = None, +) -> None: + if experimental and experimental.prompt_editing.enabled: + if server.has_feature("prompt-filter"): + pipeline.stage( + TextPromptStage(), + StageParams(), + ) + else: + logger.warning("prompt editing is not supported by the server") + + def run_txt2img_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - size: Size, - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, ) -> None: + params = request.image + size = request.size + upscale = request.upscale + highres = request.highres + # if using panorama, the pipeline will tile itself (views) tile_size = get_base_tile(params, size) # prepare the chain pipeline and first stage chain = ChainPipeline() - - if server.has_feature("prompt-filter"): - chain.stage( - TextPromptStage(), - StageParams(), - ) + add_prompt_filter(server, chain, request.experimental) chain.stage( SourceTxt2ImgStage(), StageParams( tile_size=tile_size, ), - 
size=size, + size=request.size, prompt_index=0, overlap=params.vae_overlap, ) @@ -145,13 +157,15 @@ def run_txt2img_pipeline( def run_img2img_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, source: Image.Image, strength: float, source_filter: Optional[str] = None, ) -> None: + params = request.image + upscale = request.upscale + highres = request.highres + # run filter on the source image if source_filter is not None: f = get_source_filters().get(source_filter, None) @@ -246,10 +260,7 @@ def run_img2img_pipeline( def run_inpaint_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - size: Size, - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, source: Image.Image, mask: Image.Image, border: Border, @@ -260,6 +271,11 @@ def run_inpaint_pipeline( full_res_inpaint: bool, full_res_inpaint_padding: float, ) -> None: + params = request.image + size = request.size + upscale = request.upscale + highres = request.highres + logger.debug("building inpaint pipeline") tile_size = get_base_tile(params, size) @@ -453,12 +469,14 @@ def run_inpaint_pipeline( def run_upscale_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - size: Size, - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, source: Image.Image, ) -> None: + params = request.image + size = request.size + upscale = request.upscale + highres = request.highres + # set up the chain pipeline, no base stage for upscaling chain = ChainPipeline() tile_size = get_base_tile(params, size) @@ -521,13 +539,14 @@ def run_upscale_pipeline( def run_blend_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - size: Size, - upscale: UpscaleParams, - # highres: HighresParams, + request: RequestParams, sources: List[Image.Image], mask: Image.Image, ) -> None: + params = request.image + size = 
request.size + upscale = request.upscale + # set up the chain pipeline and base stage chain = ChainPipeline() tile_size = get_base_tile(params, size) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 0db05d26..4d5ef00b 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -12,6 +12,8 @@ logger = getLogger(__name__) Param = Union[str, int, float] Point = Tuple[int, int] +UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"] +UpscaleMethod = Literal["bilinear", "lanczos", "upscale"] class SizeChart(IntEnum): @@ -425,9 +427,6 @@ class StageParams: ) -UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"] - - class UpscaleParams: def __init__( self, @@ -532,9 +531,6 @@ class UpscaleParams: ) -UpscaleMethod = Literal["bilinear", "lanczos", "upscale"] - - class HighresParams: def __init__( self, @@ -593,6 +589,85 @@ class HighresParams: ) +class LatentSymmetryParams: + enabled: bool + gradient_start: float + gradient_end: float + line_of_symmetry: float + + def __init__( + self, + enabled: bool, + gradient_start: float, + gradient_end: float, + line_of_symmetry: float, + ) -> None: + self.enabled = enabled + self.gradient_start = gradient_start + self.gradient_end = gradient_end + self.line_of_symmetry = line_of_symmetry + + +class PromptEditingParams: + enabled: bool + filter: str + remove_tokens: str + add_suffix: str + + def __init__( + self, + enabled: bool, + filter: str, + remove_tokens: str, + add_suffix: str, + ) -> None: + self.enabled = enabled + self.filter = filter + self.remove_tokens = remove_tokens + self.add_suffix = add_suffix + + +class ExperimentalParams: + latent_symmetry: LatentSymmetryParams + prompt_editing: PromptEditingParams + + def __init__( + self, + latent_symmetry: LatentSymmetryParams, + prompt_editing: PromptEditingParams, + ) -> None: + self.latent_symmetry = latent_symmetry + self.prompt_editing = prompt_editing + + +class RequestParams: + device: 
DeviceParams + image: ImageParams + size: Size | None + border: Border | None + upscale: UpscaleParams | None + highres: HighresParams | None + experimental: ExperimentalParams | None + + def __init__( + self, + device: DeviceParams, + image: ImageParams, + size: Optional[Size] = None, + border: Optional[Border] = None, + upscale: Optional[UpscaleParams] = None, + highres: Optional[HighresParams] = None, + experimental: Optional[ExperimentalParams] = None, + ) -> None: + self.device = device + self.image = image + self.size = size + self.border = border + self.upscale = upscale + self.highres = highres + self.experimental = experimental + + def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: if val is None: return SizeChart.auto diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index c3d32634..e0732235 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -52,13 +52,7 @@ from .load import ( get_upscaling_models, get_wildcard_data, ) -from .params import ( - build_border, - build_highres, - build_upscale, - pipeline_from_json, - pipeline_from_request, -) +from .params import build_border, build_upscale, get_request_params, pipeline_from_json from .utils import wrap_route logger = getLogger(__name__) @@ -261,15 +255,13 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): return error_reply("source image is required") source = Image.open(BytesIO(source_file.read())).convert("RGB") - size = Size(source.width, source.height) - device, params, _size = pipeline_from_request(server, "img2img") - upscale = build_upscale() - highres = build_highres() + # TODO: look up the correct request field source_filter = get_from_list( request.args, "sourceFilter", list(get_source_filters().keys()) ) + # TODO: look up the correct request field strength = get_and_clamp_float( request.args, "strength", @@ -278,20 +270,22 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): get_config_value("strength", "min"), 
) - replace_wildcards(params, get_wildcard_data()) + params = get_request_params(server, JobType.IMG2IMG.value) + params.size = Size(source.width, source.height) + replace_wildcards(params.image, get_wildcard_data()) - job_name = make_job_name("img2img", params, size, extras=[strength]) + job_name = make_job_name( + JobType.IMG2IMG.value, params.image, params.size, extras=[strength] + ) queue = pool.submit( job_name, JobType.IMG2IMG, run_img2img_pipeline, server, params, - upscale, - highres, source, strength, - needs_device=device, + needs_device=params.device, source_filter=source_filter, ) @@ -301,24 +295,17 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor): - device, params, size = pipeline_from_request(server, "txt2img") - upscale = build_upscale() - highres = build_highres() - - replace_wildcards(params, get_wildcard_data()) - - job_name = make_job_name("txt2img", params, size) + params = get_request_params(server, JobType.TXT2IMG.value) + replace_wildcards(params.image, get_wildcard_data()) + job_name = make_job_name(JobType.TXT2IMG.value, params.image, params.size) queue = pool.submit( job_name, JobType.TXT2IMG, run_txt2img_pipeline, server, params, - size, - upscale, - highres, - needs_device=device, + needs_device=params.device, ) logger.info("txt2img job queued for: %s", job_name) @@ -343,6 +330,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): mask.alpha_composite(mask_top_layer) mask.convert(mode="L") + # TODO: look up the correct request field full_res_inpaint = get_boolean( request.args, "fullresInpaint", get_config_value("fullresInpaint") ) @@ -354,10 +342,8 @@ get_config_value("fullresInpaintPadding", "min"), ) - device, params, _size = pipeline_from_request(server, "inpaint") - expand = build_border() - upscale = build_upscale() - highres = build_highres() + params = get_request_params(server, JobType.INPAINT.value) + 
replace_wildcards(params.image, get_wildcard_data()) fill_color = get_not_empty(request.args, "fillColor", "white") mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") @@ -367,17 +353,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): ) tile_order = TileOrder.spiral - replace_wildcards(params, get_wildcard_data()) - job_name = make_job_name( - "inpaint", + JobType.INPAINT.value, params, size, extras=[ - expand.left, - expand.right, - expand.top, - expand.bottom, + params.border.left, + params.border.right, + params.border.top, + params.border.bottom, mask_filter.__name__, noise_source.__name__, fill_color, @@ -391,19 +375,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): run_inpaint_pipeline, server, params, - size, - upscale, - highres, source, mask, - expand, noise_source, mask_filter, fill_color, tile_order, full_res_inpaint, full_res_inpaint_padding, - needs_device=device, + needs_device=params.device, ) logger.info("inpaint job queued for: %s", job_name) @@ -418,24 +398,18 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): source = Image.open(BytesIO(source_file.read())).convert("RGB") - device, params, size = pipeline_from_request(server) - upscale = build_upscale() - highres = build_highres() + params = get_request_params(server) + replace_wildcards(params.image, get_wildcard_data()) - replace_wildcards(params, get_wildcard_data()) - - job_name = make_job_name("upscale", params, size) + job_name = make_job_name("upscale", params.image, params.size) queue = pool.submit( job_name, JobType.UPSCALE, run_upscale_pipeline, server, params, - size, - upscale, - highres, source, - needs_device=device, + needs_device=params.device, ) logger.info("upscale job queued for: %s", job_name) @@ -571,22 +545,18 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): source = Image.open(BytesIO(source_file.read())).convert("RGB") sources.append(source) - device, params, size = 
pipeline_from_request(server) - upscale = build_upscale() + params = get_request_params(server) - job_name = make_job_name("blend", params, size) + job_name = make_job_name("blend", params.image, params.size) queue = pool.submit( job_name, JobType.BLEND, run_blend_pipeline, server, params, - size, - upscale, - # TODO: highres sources, mask, - needs_device=device, + needs_device=params.device, ) logger.info("upscale job queued for: %s", job_name) @@ -595,9 +565,9 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): def txt2txt(server: ServerContext, pool: DevicePoolExecutor): - device, params, size = pipeline_from_request(server) + params = get_request_params(server) - job_name = make_job_name("txt2txt", params, size) + job_name = make_job_name("txt2txt", params.image, params.size) logger.info("upscale job queued for: %s", job_name) queue = pool.submit( @@ -606,8 +576,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor): run_txt2txt_pipeline, server, params, - size, - needs_device=device, + needs_device=params.device, ) return job_reply(job_name, queue=queue) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index df1648c3..3fd40e95 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from flask import request @@ -8,8 +8,12 @@ from ..diffusers.utils import random_seed from ..params import ( Border, DeviceParams, + ExperimentalParams, HighresParams, ImageParams, + LatentSymmetryParams, + PromptEditingParams, + RequestParams, Size, UpscaleParams, ) @@ -345,6 +349,79 @@ def build_highres( ) +def build_latent_symmetry( + data: Dict[str, str] = None, +) -> LatentSymmetryParams: + if data is None: + data = request.args + + enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry")) + + gradient_start = get_and_clamp_float( + data, + 
"gradientStart", + get_config_value("gradientStart"), + get_config_value("gradientStart", "max"), + get_config_value("gradientStart", "min"), + ) + + gradient_end = get_and_clamp_float( + data, + "gradientEnd", + get_config_value("gradientEnd"), + get_config_value("gradientEnd", "max"), + get_config_value("gradientEnd", "min"), + ) + + line_of_symmetry = get_and_clamp_float( + data, + "lineOfSymmetry", + get_config_value("lineOfSymmetry"), + get_config_value("lineOfSymmetry", "max"), + get_config_value("lineOfSymmetry", "min"), + ) + + return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry) + + +def build_prompt_editing( + data: Dict[str, str] = None, +) -> Dict[str, str]: + if data is None: + data = request.args + + enabled = get_boolean(data, "enabled", get_config_value("promptEditing")) + + prompt_filter = data.get("promptFilter", "") + remove_tokens = data.get("removeTokens", "") + add_suffix = data.get("addSuffix", "") + + return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix) + + +def build_experimental( + data: Dict[str, str] = None, +) -> ExperimentalParams: + if data is None: + data = request.args + + latent_symmetry_data = data.get("latentSymmetry", {}) + latent_symmetry = build_latent_symmetry(latent_symmetry_data) + + prompt_editing_data = data.get("promptEditing", {}) + prompt_editing = build_prompt_editing(prompt_editing_data) + + return ExperimentalParams(latent_symmetry, prompt_editing) + + +def is_json_request() -> bool: + return request.mimetype == "application/json" + + +def is_json_form_request() -> bool: + return request.mimetype == "multipart/form-data" and "json" in request.form + + PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size] @@ -353,36 +430,11 @@ def pipeline_from_json( data: Dict[str, Union[str, Dict[str, str]]], default_pipeline: str = "txt2img", ) -> PipelineParams: - """ - Like pipeline_from_request but expects a nested structure. 
- """ - device = build_device(server, data.get("device", data)) params = build_params(server, default_pipeline, data.get("params", data)) size = build_size(server, data.get("params", data)) - return (device, params, size) - - -def pipeline_from_request( - server: ServerContext, - default_pipeline: str = "txt2img", -) -> PipelineParams: user = request.remote_addr - mime = request.mimetype - - if mime == "application/json": - device, params, size = pipeline_from_json( - server, request.json, default_pipeline - ) - elif mime == "multipart/form-data": - form_json = load_config_str(request.form.get("json")) - device, params, size = pipeline_from_json(server, form_json, default_pipeline) - else: - device = build_device(server, request.args) - params = build_params(server, default_pipeline, request.args) - size = build_size(server, request.args) - logger.info( "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", user, @@ -399,3 +451,39 @@ ) return (device, params, size) + + +def get_request_data(key: str | None = None) -> Any: + if is_json_request(): + json = request.json + elif is_json_form_request(): + json = load_config_str(request.form.get("json")) + else: + json = None + + if key is not None and json is not None: + json = json.get(key) + + return json or request.args + + +def get_request_params( + server: ServerContext, default_pipeline: Optional[str] = None +) -> RequestParams: + data = get_request_data() + + device, params, size = pipeline_from_json(server, data, default_pipeline) + border = build_border(data.get("border")) + upscale = build_upscale(data.get("upscale")) + highres = build_highres(data.get("highres")) + experimental = build_experimental(data.get("experimental")) + + return RequestParams( + device, + params, + size=size, + border=border, + upscale=upscale, + highres=highres, + experimental=experimental, + ) diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 1b9cdbd7..e2e9e42d 100644 --- a/gui/src/client/api.ts +++ 
b/gui/src/client/api.ts @@ -33,14 +33,6 @@ import { import { range } from '../utils.js'; import { ApiClient } from './base.js'; -const FORM_HEADERS = { - 'Content-Type': 'multipart/form-data', -}; - -const JSON_HEADERS = { - 'Content-Type': 'application/json', -}; - export function equalResponse(a: JobResponse, b: JobResponse): boolean { return a.name === b.name; } @@ -348,7 +340,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe