
add single class to store all request params

Sean Sube 2024-02-17 15:12:18 -06:00
parent 1a4c31d077
commit bf1a88fac2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 276 additions and 139 deletions
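
The change below collapses the separate params, size, upscale, and highres arguments that every pipeline entry point took into a single RequestParams object. As a rough sketch of the pattern (simplified, illustrative signatures, not the exact ones in the diff, with the worker and server plumbing elided):

# before: each pipeline signature repeated the same parameter groups
def run_pipeline_before(worker, server, params, size, upscale, highres):
    ...

# after: one RequestParams container carries them, and the pipeline unpacks what it needs
def run_pipeline_after(worker, server, request):
    params = request.image
    size = request.size
    upscale = request.upscale
    highres = request.highres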

View File

@ -8,6 +8,7 @@ from ..chain import (
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
EditSafetyStage,
SourceTxt2ImgStage,
TextPromptStage,
UpscaleOutpaintStage,
@ -19,11 +20,12 @@ from ..image import expand_image
from ..output import make_output_names, read_metadata, save_image, save_result
from ..params import (
Border,
ExperimentalParams,
HighresParams,
ImageParams,
RequestParams,
Size,
StageParams,
UpscaleParams,
)
from ..server import ServerContext
from ..server.load import get_source_filters
@ -57,39 +59,49 @@ def add_safety_stage(
pipeline: ChainPipeline,
) -> None:
if server.has_feature("horde-safety"):
from ..chain.edit_safety import EditSafetyStage
pipeline.stage(
EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile)
)
def add_prompt_filter(
server: ServerContext,
pipeline: ChainPipeline,
experimental: ExperimentalParams = None,
) -> None:
if experimental and experimental.prompt_editing.enabled:
if server.has_feature("prompt-filter"):
pipeline.stage(
TextPromptStage(),
StageParams(),
)
else:
logger.warning("prompt editing is not supported by the server")
def run_txt2img_pipeline(
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
upscale: UpscaleParams,
highres: HighresParams,
request: RequestParams,
) -> None:
params = request.image
size = request.size
upscale = request.upscale
highres = request.highres
# if using panorama, the pipeline will tile itself (views)
tile_size = get_base_tile(params, size)
# prepare the chain pipeline and first stage
chain = ChainPipeline()
if server.has_feature("prompt-filter"):
chain.stage(
TextPromptStage(),
StageParams(),
)
add_prompt_filter(server, chain, request.experimental)
chain.stage(
SourceTxt2ImgStage(),
StageParams(
tile_size=tile_size,
),
size=size,
size=request.size,
prompt_index=0,
overlap=params.vae_overlap,
)
@ -145,13 +157,15 @@ def run_txt2img_pipeline(
def run_img2img_pipeline(
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
upscale: UpscaleParams,
highres: HighresParams,
request: RequestParams,
source: Image.Image,
strength: float,
source_filter: Optional[str] = None,
) -> None:
params = request.image
upscale = request.upscale
highres = request.highres
# run filter on the source image
if source_filter is not None:
f = get_source_filters().get(source_filter, None)
@ -246,10 +260,7 @@ def run_img2img_pipeline(
def run_inpaint_pipeline(
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
upscale: UpscaleParams,
highres: HighresParams,
request: RequestParams,
source: Image.Image,
mask: Image.Image,
border: Border,
@ -260,6 +271,11 @@ def run_inpaint_pipeline(
full_res_inpaint: bool,
full_res_inpaint_padding: float,
) -> None:
params = request.image
size = request.size
upscale = request.upscale
highres = request.highres
logger.debug("building inpaint pipeline")
tile_size = get_base_tile(params, size)
@ -453,12 +469,14 @@ def run_inpaint_pipeline(
def run_upscale_pipeline(
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
upscale: UpscaleParams,
highres: HighresParams,
request: RequestParams,
source: Image.Image,
) -> None:
params = request.image
size = request.size
upscale = request.upscale
highres = request.highres
# set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline()
tile_size = get_base_tile(params, size)
@ -521,13 +539,14 @@ def run_upscale_pipeline(
def run_blend_pipeline(
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
upscale: UpscaleParams,
# highres: HighresParams,
request: RequestParams,
sources: List[Image.Image],
mask: Image.Image,
) -> None:
params = request.image
size = request.size
upscale = request.upscale
# set up the chain pipeline and base stage
chain = ChainPipeline()
tile_size = get_base_tile(params, size)
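
As a rough sketch of how the new helpers in this file compose (assuming the module imports above, an in-scope server context, and a RequestParams named request; add_safety_stage is assumed to take the server as its first argument, matching its body):

chain = ChainPipeline()
add_prompt_filter(server, chain, request.experimental)  # no-op unless the experimental flag is on and the server has the prompt-filter feature
chain.stage(
    SourceTxt2ImgStage(),
    StageParams(tile_size=get_base_tile(request.image, request.size)),
    size=request.size,
    prompt_index=0,
)
add_safety_stage(server, chain)  # appends EditSafetyStage only when the horde-safety feature is available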

View File

@ -12,6 +12,8 @@ logger = getLogger(__name__)
Param = Union[str, int, float]
Point = Tuple[int, int]
UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"]
UpscaleMethod = Literal["bilinear", "lanczos", "upscale"]
class SizeChart(IntEnum):
@ -425,9 +427,6 @@ class StageParams:
)
UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"]
class UpscaleParams:
def __init__(
self,
@ -532,9 +531,6 @@ class UpscaleParams:
)
UpscaleMethod = Literal["bilinear", "lanczos", "upscale"]
class HighresParams:
def __init__(
self,
@ -593,6 +589,85 @@ class HighresParams:
)
class LatentSymmetryParams:
enabled: bool
gradient_start: float
gradient_end: float
line_of_symmetry: float
def __init__(
self,
enabled: bool,
gradient_start: float,
gradient_end: float,
line_of_symmetry: float,
) -> None:
self.enabled = enabled
self.gradient_start = gradient_start
self.gradient_end = gradient_end
self.line_of_symmetry = line_of_symmetry
class PromptEditingParams:
enabled: bool
filter: str
remove_tokens: str
add_suffix: str
def __init__(
self,
enabled: bool,
filter: str,
remove_tokens: str,
add_suffix: str,
) -> None:
self.enabled = enabled
self.filter = filter
self.remove_tokens = remove_tokens
self.add_suffix = add_suffix
class ExperimentalParams:
latent_symmetry: LatentSymmetryParams
prompt_editing: PromptEditingParams
def __init__(
self,
latent_symmetry: LatentSymmetryParams,
prompt_editing: PromptEditingParams,
) -> None:
self.latent_symmetry = latent_symmetry
self.prompt_editing = prompt_editing
class RequestParams:
device: DeviceParams
image: ImageParams
size: Size | None
border: Border | None
upscale: UpscaleParams | None
highres: HighresParams | None
experimental: ExperimentalParams | None
def __init__(
self,
device: DeviceParams,
image: ImageParams,
size: Optional[Size] = None,
border: Optional[Border] = None,
upscale: Optional[UpscaleParams] = None,
highres: Optional[HighresParams] = None,
experimental: Optional[ExperimentalParams] = None,
) -> None:
self.device = device
self.image = image
self.size = size
self.border = border
self.upscale = upscale
self.highres = highres
self.experimental = experimental
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
if val is None:
return SizeChart.auto
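
For reference, a minimal sketch of building the new container directly, assuming device (DeviceParams) and image (ImageParams) were constructed elsewhere (their constructors are not part of this diff) and using purely illustrative values; only the first two arguments are required, every other group defaults to None:

experimental = ExperimentalParams(
    LatentSymmetryParams(False, 0.0, 0.5, 0.5),  # enabled, gradient_start, gradient_end, line_of_symmetry
    PromptEditingParams(True, "", "", ""),       # enabled, filter, remove_tokens, add_suffix
)
request = RequestParams(
    device,
    image,
    size=Size(512, 512),
    experimental=experimental,
)
# consumers read request.image, request.size, request.upscale, request.highres, and so on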

View File

@ -52,13 +52,7 @@ from .load import (
get_upscaling_models,
get_wildcard_data,
)
from .params import (
build_border,
build_highres,
build_upscale,
pipeline_from_json,
pipeline_from_request,
)
from .params import build_border, build_upscale, get_request_params, pipeline_from_json
from .utils import wrap_route
logger = getLogger(__name__)
@ -261,15 +255,13 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
return error_reply("source image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
size = Size(source.width, source.height)
device, params, _size = pipeline_from_request(server, "img2img")
upscale = build_upscale()
highres = build_highres()
# TODO: look up the correct request field
source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys())
)
# TODO: look up the correct request field
strength = get_and_clamp_float(
request.args,
"strength",
@ -278,20 +270,22 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
get_config_value("strength", "min"),
)
replace_wildcards(params, get_wildcard_data())
params = get_request_params(server, JobType.IMG2IMG.value)
params.size = Size(source.width, source.height)
replace_wildcards(params.image, get_wildcard_data())
job_name = make_job_name("img2img", params, size, extras=[strength])
job_name = make_job_name(
JobType.IMG2IMG.value, params, params.size, extras=[strength]
)
queue = pool.submit(
job_name,
JobType.IMG2IMG,
run_img2img_pipeline,
server,
params,
upscale,
highres,
source,
strength,
needs_device=device,
needs_device=params.device,
source_filter=source_filter,
)
@ -301,24 +295,17 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "txt2img")
upscale = build_upscale()
highres = build_highres()
replace_wildcards(params, get_wildcard_data())
job_name = make_job_name("txt2img", params, size)
params = get_request_params(server, JobType.TXT2IMG.value)
replace_wildcards(params.image, get_wildcard_data())
job_name = make_job_name(JobType.TXT2IMG.value, params.image, params.size)
queue = pool.submit(
job_name,
JobType.TXT2IMG,
run_txt2img_pipeline,
server,
params,
size,
upscale,
highres,
needs_device=device,
needs_device=params.device,
)
logger.info("txt2img job queued for: %s", job_name)
@ -343,6 +330,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
mask.alpha_composite(mask_top_layer)
mask.convert(mode="L")
# TODO: look up the correct request field
full_res_inpaint = get_boolean(
request.args, "fullresInpaint", get_config_value("fullresInpaint")
)
@ -354,10 +342,8 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
get_config_value("fullresInpaintPadding", "min"),
)
device, params, _size = pipeline_from_request(server, "inpaint")
expand = build_border()
upscale = build_upscale()
highres = build_highres()
params = get_request_params(server, JobType.INPAINT.value)
replace_wildcards(params.image, get_wildcard_data())
fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
@ -367,17 +353,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
)
tile_order = TileOrder.spiral
replace_wildcards(params, get_wildcard_data())
job_name = make_job_name(
"inpaint",
JobType.INPAINT.value,
params,
size,
extras=[
expand.left,
expand.right,
expand.top,
expand.bottom,
params.border.left,
params.border.right,
params.border.top,
params.border.bottom,
mask_filter.__name__,
noise_source.__name__,
fill_color,
@ -391,19 +375,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
run_inpaint_pipeline,
server,
params,
size,
upscale,
highres,
source,
mask,
expand,
noise_source,
mask_filter,
fill_color,
tile_order,
full_res_inpaint,
full_res_inpaint_padding,
needs_device=device,
needs_device=params.device,
)
logger.info("inpaint job queued for: %s", job_name)
@ -418,24 +398,18 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server)
upscale = build_upscale()
highres = build_highres()
params = get_request_params(server)
replace_wildcards(params.image, get_wildcard_data())
replace_wildcards(params, get_wildcard_data())
job_name = make_job_name("upscale", params, size)
job_name = make_job_name("upscale", params.image, params.size)
queue = pool.submit(
job_name,
JobType.UPSCALE,
run_upscale_pipeline,
server,
params,
size,
upscale,
highres,
source,
needs_device=device,
needs_device=params.device,
)
logger.info("upscale job queued for: %s", job_name)
@ -571,22 +545,18 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB")
sources.append(source)
device, params, size = pipeline_from_request(server)
upscale = build_upscale()
params = get_request_params(server)
job_name = make_job_name("blend", params, size)
job_name = make_job_name("blend", params.image, params.size)
queue = pool.submit(
job_name,
JobType.BLEND,
run_blend_pipeline,
server,
params,
size,
upscale,
# TODO: highres
sources,
mask,
needs_device=device,
needs_device=params.device,
)
logger.info("upscale job queued for: %s", job_name)
@ -595,9 +565,9 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server)
params = get_request_params(server)
job_name = make_job_name("txt2txt", params, size)
job_name = make_job_name("txt2txt", params.image, params.size)
logger.info("upscale job queued for: %s", job_name)
queue = pool.submit(
@ -606,8 +576,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
run_txt2txt_pipeline,
server,
params,
size,
needs_device=device,
needs_device=params.device,
)
return job_reply(job_name, queue=queue)

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
from flask import request
@ -8,8 +8,12 @@ from ..diffusers.utils import random_seed
from ..params import (
Border,
DeviceParams,
ExperimentalParams,
HighresParams,
ImageParams,
LatentSymmetryParams,
PromptEditingParams,
RequestParams,
Size,
UpscaleParams,
)
@ -345,6 +349,79 @@ def build_highres(
)
def build_latent_symmetry(
data: Dict[str, str] = None,
) -> LatentSymmetryParams:
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry"))
gradient_start = get_and_clamp_float(
data,
"gradientStart",
get_config_value("gradientStart"),
get_config_value("gradientStart", "max"),
get_config_value("gradientStart", "min"),
)
gradient_end = get_and_clamp_float(
data,
"gradientEnd",
get_config_value("gradientEnd"),
get_config_value("gradientEnd", "max"),
get_config_value("gradientEnd", "min"),
)
line_of_symmetry = get_and_clamp_float(
data,
"lineOfSymmetry",
get_config_value("lineOfSymmetry"),
get_config_value("lineOfSymmetry", "max"),
get_config_value("lineOfSymmetry", "min"),
)
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
def build_prompt_editing(
data: Dict[str, str] = None,
) -> PromptEditingParams:
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("promptEditing"))
prompt_filter = data.get("promptFilter", "")
remove_tokens = data.get("removeTokens", "")
add_suffix = data.get("addSuffix", "")
return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix)
def build_experimental(
data: Dict[str, str] = None,
) -> ExperimentalParams:
if data is None:
data = request.args
latent_symmetry_data = data.get("latentSymmetry", {})
latent_symmetry = build_latent_symmetry(latent_symmetry_data)
prompt_editing_data = data.get("promptEditing", {})
prompt_editing = build_prompt_editing(prompt_editing_data)
return ExperimentalParams(latent_symmetry, prompt_editing)
def is_json_request() -> bool:
return request.mimetype == "application/json"
def is_json_form_request() -> bool:
return request.mimetype == "multipart/form-data" and "json" in request.form
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
@ -353,36 +430,11 @@ def pipeline_from_json(
data: Dict[str, Union[str, Dict[str, str]]],
default_pipeline: str = "txt2img",
) -> PipelineParams:
"""
Like pipeline_from_request but expects a nested structure.
"""
device = build_device(server, data.get("device", data))
params = build_params(server, default_pipeline, data.get("params", data))
size = build_size(server, data.get("params", data))
return (device, params, size)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> PipelineParams:
user = request.remote_addr
mime = request.mimetype
if mime == "application/json":
device, params, size = pipeline_from_json(
server, request.json, default_pipeline
)
elif mime == "multipart/form-data":
form_json = load_config_str(request.form.get("json"))
device, params, size = pipeline_from_json(server, form_json, default_pipeline)
else:
device = build_device(server, request.args)
params = build_params(server, default_pipeline, request.args)
size = build_size(server, request.args)
logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user,
@ -399,3 +451,39 @@ def pipeline_from_request(
)
return (device, params, size)
def get_request_data(key: str | None = None) -> Any:
if is_json_request():
json = request.json
elif is_json_form_request():
json = load_config_str(request.form.get("json"))
else:
json = None
if key is not None and json is not None:
json = json.get(key)
return json or request.args
def get_request_params(
server: ServerContext, default_pipeline: str = None
) -> RequestParams:
data = get_request_data()
device, params, size = pipeline_from_json(server, data, default_pipeline)
border = build_border(data["border"])
upscale = build_upscale(data["upscale"])
highres = build_highres(data["highres"])
experimental = build_experimental(data["experimental"])
return RequestParams(
device,
params,
size=size,
border=border,
upscale=upscale,
highres=highres,
experimental=experimental,
)
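
get_request_data accepts a JSON body, a multipart form with a json field, or plain query args, and get_request_params builds each parameter group from the nested keys. A hedged sketch of the nested body shape implied by the code above, showing only the keys visible in this diff; the fields inside the device, params, border, upscale, and highres groups come from builders outside this change, so they are left empty, and the experimental values are illustrative:

body = {
    "device": {},    # consumed by build_device
    "params": {},    # consumed by build_params
    "border": {},    # consumed by build_border
    "upscale": {},   # consumed by build_upscale
    "highres": {},   # consumed by build_highres
    "experimental": {
        "latentSymmetry": {
            "enabled": True,
            "gradientStart": 0.1,  # clamped against the configured min/max
            "gradientEnd": 0.3,
            "lineOfSymmetry": 0.5,
        },
        "promptEditing": {
            "enabled": True,
            "promptFilter": "",
            "removeTokens": "",
            "addSuffix": "",
        },
    },
}

A route handler can then call get_request_params(server, "txt2img") and hand the result straight to a pipeline, as the route changes above do.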

View File

@ -33,14 +33,6 @@ import {
import { range } from '../utils.js';
import { ApiClient } from './base.js';
const FORM_HEADERS = {
'Content-Type': 'multipart/form-data',
};
const JSON_HEADERS = {
'Content-Type': 'application/json',
};
export function equalResponse(a: JobResponse, b: JobResponse): boolean {
return a.name === b.name;
}
@ -348,7 +340,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, {
body: form,
headers: FORM_HEADERS,
method: 'POST',
});
return {
@ -376,7 +367,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, {
body: form,
headers: FORM_HEADERS,
method: 'POST',
});
return {
@ -407,7 +397,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, {
body: form,
headers: FORM_HEADERS,
method: 'POST',
});
return {
@ -437,7 +426,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, {
body: form,
headers: FORM_HEADERS,
method: 'POST',
});
return {
@ -465,7 +453,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, {
body: form,
headers: FORM_HEADERS,
method: 'POST',
});
return {
@ -497,7 +484,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, {
body: form,
headers: FORM_HEADERS,
method: 'POST',
});
return {