1
0
Fork 0

add single class to store all request params

This commit is contained in:
Sean Sube 2024-02-17 15:12:18 -06:00
parent 1a4c31d077
commit bf1a88fac2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 276 additions and 139 deletions

View File

@ -8,6 +8,7 @@ from ..chain import (
BlendImg2ImgStage, BlendImg2ImgStage,
BlendMaskStage, BlendMaskStage,
ChainPipeline, ChainPipeline,
EditSafetyStage,
SourceTxt2ImgStage, SourceTxt2ImgStage,
TextPromptStage, TextPromptStage,
UpscaleOutpaintStage, UpscaleOutpaintStage,
@ -19,11 +20,12 @@ from ..image import expand_image
from ..output import make_output_names, read_metadata, save_image, save_result from ..output import make_output_names, read_metadata, save_image, save_result
from ..params import ( from ..params import (
Border, Border,
ExperimentalParams,
HighresParams, HighresParams,
ImageParams, ImageParams,
RequestParams,
Size, Size,
StageParams, StageParams,
UpscaleParams,
) )
from ..server import ServerContext from ..server import ServerContext
from ..server.load import get_source_filters from ..server.load import get_source_filters
@ -57,39 +59,49 @@ def add_safety_stage(
pipeline: ChainPipeline, pipeline: ChainPipeline,
) -> None: ) -> None:
if server.has_feature("horde-safety"): if server.has_feature("horde-safety"):
from ..chain.edit_safety import EditSafetyStage
pipeline.stage( pipeline.stage(
EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile) EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile)
) )
def add_prompt_filter(
server: ServerContext,
pipeline: ChainPipeline,
experimental: ExperimentalParams = None,
) -> None:
if experimental and experimental.prompt_editing.enabled:
if server.has_feature("prompt-filter"):
pipeline.stage(
TextPromptStage(),
StageParams(),
)
else:
logger.warning("prompt editing is not supported by the server")
def run_txt2img_pipeline( def run_txt2img_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, request: RequestParams,
size: Size,
upscale: UpscaleParams,
highres: HighresParams,
) -> None: ) -> None:
params = request.image
size = request.size
upscale = request.upscale
highres = request.highres
# if using panorama, the pipeline will tile itself (views) # if using panorama, the pipeline will tile itself (views)
tile_size = get_base_tile(params, size) tile_size = get_base_tile(params, size)
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
chain = ChainPipeline() chain = ChainPipeline()
add_prompt_filter(server, chain)
if server.has_feature("prompt-filter"):
chain.stage(
TextPromptStage(),
StageParams(),
)
chain.stage( chain.stage(
SourceTxt2ImgStage(), SourceTxt2ImgStage(),
StageParams( StageParams(
tile_size=tile_size, tile_size=tile_size,
), ),
size=size, size=request.size,
prompt_index=0, prompt_index=0,
overlap=params.vae_overlap, overlap=params.vae_overlap,
) )
@ -145,13 +157,15 @@ def run_txt2img_pipeline(
def run_img2img_pipeline( def run_img2img_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, request: RequestParams,
upscale: UpscaleParams,
highres: HighresParams,
source: Image.Image, source: Image.Image,
strength: float, strength: float,
source_filter: Optional[str] = None, source_filter: Optional[str] = None,
) -> None: ) -> None:
params = request.image
upscale = request.upscale
highres = request.highres
# run filter on the source image # run filter on the source image
if source_filter is not None: if source_filter is not None:
f = get_source_filters().get(source_filter, None) f = get_source_filters().get(source_filter, None)
@ -246,10 +260,7 @@ def run_img2img_pipeline(
def run_inpaint_pipeline( def run_inpaint_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, request: RequestParams,
size: Size,
upscale: UpscaleParams,
highres: HighresParams,
source: Image.Image, source: Image.Image,
mask: Image.Image, mask: Image.Image,
border: Border, border: Border,
@ -260,6 +271,11 @@ def run_inpaint_pipeline(
full_res_inpaint: bool, full_res_inpaint: bool,
full_res_inpaint_padding: float, full_res_inpaint_padding: float,
) -> None: ) -> None:
params = request.image
size = request.size
upscale = request.upscale
highres = request.highres
logger.debug("building inpaint pipeline") logger.debug("building inpaint pipeline")
tile_size = get_base_tile(params, size) tile_size = get_base_tile(params, size)
@ -453,12 +469,14 @@ def run_inpaint_pipeline(
def run_upscale_pipeline( def run_upscale_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, request: RequestParams,
size: Size,
upscale: UpscaleParams,
highres: HighresParams,
source: Image.Image, source: Image.Image,
) -> None: ) -> None:
params = request.image
size = request.size
upscale = request.upscale
highres = request.highres
# set up the chain pipeline, no base stage for upscaling # set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline() chain = ChainPipeline()
tile_size = get_base_tile(params, size) tile_size = get_base_tile(params, size)
@ -521,13 +539,14 @@ def run_upscale_pipeline(
def run_blend_pipeline( def run_blend_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, request: RequestParams,
size: Size,
upscale: UpscaleParams,
# highres: HighresParams,
sources: List[Image.Image], sources: List[Image.Image],
mask: Image.Image, mask: Image.Image,
) -> None: ) -> None:
params = request.image
size = request.size
upscale = request.upscale
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
tile_size = get_base_tile(params, size) tile_size = get_base_tile(params, size)

View File

@ -12,6 +12,8 @@ logger = getLogger(__name__)
Param = Union[str, int, float] Param = Union[str, int, float]
Point = Tuple[int, int] Point = Tuple[int, int]
UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"]
UpscaleMethod = Literal["bilinear", "lanczos", "upscale"]
class SizeChart(IntEnum): class SizeChart(IntEnum):
@ -425,9 +427,6 @@ class StageParams:
) )
UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"]
class UpscaleParams: class UpscaleParams:
def __init__( def __init__(
self, self,
@ -532,9 +531,6 @@ class UpscaleParams:
) )
UpscaleMethod = Literal["bilinear", "lanczos", "upscale"]
class HighresParams: class HighresParams:
def __init__( def __init__(
self, self,
@ -593,6 +589,85 @@ class HighresParams:
) )
class LatentSymmetryParams:
enabled: bool
gradient_start: float
gradient_end: float
line_of_symmetry: float
def __init__(
self,
enabled: bool,
gradient_start: float,
gradient_end: float,
line_of_symmetry: float,
) -> None:
self.enabled = enabled
self.gradient_start = gradient_start
self.gradient_end = gradient_end
self.line_of_symmetry = line_of_symmetry
class PromptEditingParams:
enabled: bool
filter: str
remove_tokens: str
add_suffix: str
def __init__(
self,
enabled: bool,
filter: str,
remove_tokens: str,
add_suffix: str,
) -> None:
self.enabled = enabled
self.filter = filter
self.remove_tokens = remove_tokens
self.add_suffix = add_suffix
class ExperimentalParams:
latent_symmetry: LatentSymmetryParams
prompt_editing: PromptEditingParams
def __init__(
self,
latent_symmetry: LatentSymmetryParams,
prompt_editing: PromptEditingParams,
) -> None:
self.latent_symmetry = latent_symmetry
self.prompt_editing = prompt_editing
class RequestParams:
device: DeviceParams
image: ImageParams
size: Size | None
border: Border | None
upscale: UpscaleParams | None
highres: HighresParams | None
experimental: ExperimentalParams | None
def __init__(
self,
device: DeviceParams,
image: ImageParams,
size: Optional[Size] = None,
border: Optional[Border] = None,
upscale: Optional[UpscaleParams] = None,
highres: Optional[HighresParams] = None,
experimental: Optional[ExperimentalParams] = None,
) -> None:
self.device = device
self.image = image
self.size = size
self.border = border
self.upscale = upscale
self.highres = highres
self.experimental = experimental
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
if val is None: if val is None:
return SizeChart.auto return SizeChart.auto

View File

@ -52,13 +52,7 @@ from .load import (
get_upscaling_models, get_upscaling_models,
get_wildcard_data, get_wildcard_data,
) )
from .params import ( from .params import build_border, build_upscale, get_request_params, pipeline_from_json
build_border,
build_highres,
build_upscale,
pipeline_from_json,
pipeline_from_request,
)
from .utils import wrap_route from .utils import wrap_route
logger = getLogger(__name__) logger = getLogger(__name__)
@ -261,15 +255,13 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
return error_reply("source image is required") return error_reply("source image is required")
source = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
size = Size(source.width, source.height)
device, params, _size = pipeline_from_request(server, "img2img") # TODO: look up the correct request field
upscale = build_upscale()
highres = build_highres()
source_filter = get_from_list( source_filter = get_from_list(
request.args, "sourceFilter", list(get_source_filters().keys()) request.args, "sourceFilter", list(get_source_filters().keys())
) )
# TODO: look up the correct request field
strength = get_and_clamp_float( strength = get_and_clamp_float(
request.args, request.args,
"strength", "strength",
@ -278,20 +270,22 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
get_config_value("strength", "min"), get_config_value("strength", "min"),
) )
replace_wildcards(params, get_wildcard_data()) params = get_request_params(server, JobType.IMG2IMG.value)
params.size = Size(source.width, source.height)
replace_wildcards(params.image, get_wildcard_data())
job_name = make_job_name("img2img", params, size, extras=[strength]) job_name = make_job_name(
JobType.IMG2IMG.value, params, params.size, extras=[strength]
)
queue = pool.submit( queue = pool.submit(
job_name, job_name,
JobType.IMG2IMG, JobType.IMG2IMG,
run_img2img_pipeline, run_img2img_pipeline,
server, server,
params, params,
upscale,
highres,
source, source,
strength, strength,
needs_device=device, needs_device=params.device,
source_filter=source_filter, source_filter=source_filter,
) )
@ -301,24 +295,17 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
def txt2img(server: ServerContext, pool: DevicePoolExecutor): def txt2img(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server, "txt2img") params = get_request_params()
upscale = build_upscale() replace_wildcards(params.image, get_wildcard_data())
highres = build_highres()
replace_wildcards(params, get_wildcard_data())
job_name = make_job_name("txt2img", params, size)
job_name = make_job_name(JobType.TXT2IMG.value, params.image, params.size)
queue = pool.submit( queue = pool.submit(
job_name, job_name,
JobType.TXT2IMG, JobType.TXT2IMG,
run_txt2img_pipeline, run_txt2img_pipeline,
server, server,
params, params,
size, needs_device=params.device,
upscale,
highres,
needs_device=device,
) )
logger.info("txt2img job queued for: %s", job_name) logger.info("txt2img job queued for: %s", job_name)
@ -343,6 +330,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
mask.alpha_composite(mask_top_layer) mask.alpha_composite(mask_top_layer)
mask.convert(mode="L") mask.convert(mode="L")
# TODO: look up the correct request field
full_res_inpaint = get_boolean( full_res_inpaint = get_boolean(
request.args, "fullresInpaint", get_config_value("fullresInpaint") request.args, "fullresInpaint", get_config_value("fullresInpaint")
) )
@ -354,10 +342,8 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
get_config_value("fullresInpaintPadding", "min"), get_config_value("fullresInpaintPadding", "min"),
) )
device, params, _size = pipeline_from_request(server, "inpaint") params = get_request_params(server, JobType.INPAINT.value)
expand = build_border() replace_wildcards(params.image, get_wildcard_data())
upscale = build_upscale()
highres = build_highres()
fill_color = get_not_empty(request.args, "fillColor", "white") fill_color = get_not_empty(request.args, "fillColor", "white")
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
@ -367,17 +353,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
) )
tile_order = TileOrder.spiral tile_order = TileOrder.spiral
replace_wildcards(params, get_wildcard_data())
job_name = make_job_name( job_name = make_job_name(
"inpaint", JobType.INPAINT.value,
params, params,
size, size,
extras=[ extras=[
expand.left, params.border.left,
expand.right, params.border.right,
expand.top, params.border.top,
expand.bottom, params.border.bottom,
mask_filter.__name__, mask_filter.__name__,
noise_source.__name__, noise_source.__name__,
fill_color, fill_color,
@ -391,19 +375,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
run_inpaint_pipeline, run_inpaint_pipeline,
server, server,
params, params,
size,
upscale,
highres,
source, source,
mask, mask,
expand,
noise_source, noise_source,
mask_filter, mask_filter,
fill_color, fill_color,
tile_order, tile_order,
full_res_inpaint, full_res_inpaint,
full_res_inpaint_padding, full_res_inpaint_padding,
needs_device=device, needs_device=params.device,
) )
logger.info("inpaint job queued for: %s", job_name) logger.info("inpaint job queued for: %s", job_name)
@ -418,24 +398,18 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request(server) params = get_request_params(server)
upscale = build_upscale() replace_wildcards(params.image, get_wildcard_data())
highres = build_highres()
replace_wildcards(params, get_wildcard_data()) job_name = make_job_name("upscale", params.image, params.size)
job_name = make_job_name("upscale", params, size)
queue = pool.submit( queue = pool.submit(
job_name, job_name,
JobType.UPSCALE, JobType.UPSCALE,
run_upscale_pipeline, run_upscale_pipeline,
server, server,
params, params,
size,
upscale,
highres,
source, source,
needs_device=device, needs_device=params.device,
) )
logger.info("upscale job queued for: %s", job_name) logger.info("upscale job queued for: %s", job_name)
@ -571,22 +545,18 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
source = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
sources.append(source) sources.append(source)
device, params, size = pipeline_from_request(server) params = get_request_params(server)
upscale = build_upscale()
job_name = make_job_name("blend", params, size) job_name = make_job_name("blend", params.image, params.size)
queue = pool.submit( queue = pool.submit(
job_name, job_name,
JobType.BLEND, JobType.BLEND,
run_blend_pipeline, run_blend_pipeline,
server, server,
params, params,
size,
upscale,
# TODO: highres
sources, sources,
mask, mask,
needs_device=device, needs_device=params.device,
) )
logger.info("upscale job queued for: %s", job_name) logger.info("upscale job queued for: %s", job_name)
@ -595,9 +565,9 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
def txt2txt(server: ServerContext, pool: DevicePoolExecutor): def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server) params = get_request_params(server)
job_name = make_job_name("txt2txt", params, size) job_name = make_job_name("txt2txt", params.image, params.size)
logger.info("upscale job queued for: %s", job_name) logger.info("upscale job queued for: %s", job_name)
queue = pool.submit( queue = pool.submit(
@ -606,8 +576,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
run_txt2txt_pipeline, run_txt2txt_pipeline,
server, server,
params, params,
size, needs_device=params.device,
needs_device=device,
) )
return job_reply(job_name, queue=queue) return job_reply(job_name, queue=queue)

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
from flask import request from flask import request
@ -8,8 +8,12 @@ from ..diffusers.utils import random_seed
from ..params import ( from ..params import (
Border, Border,
DeviceParams, DeviceParams,
ExperimentalParams,
HighresParams, HighresParams,
ImageParams, ImageParams,
LatentSymmetryParams,
PromptEditingParams,
RequestParams,
Size, Size,
UpscaleParams, UpscaleParams,
) )
@ -345,6 +349,79 @@ def build_highres(
) )
def build_latent_symmetry(
data: Dict[str, str] = None,
) -> LatentSymmetryParams:
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry"))
gradient_start = get_and_clamp_float(
data,
"gradientStart",
get_config_value("gradientStart"),
get_config_value("gradientStart", "max"),
get_config_value("gradientStart", "min"),
)
gradient_end = get_and_clamp_float(
data,
"gradientEnd",
get_config_value("gradientEnd"),
get_config_value("gradientEnd", "max"),
get_config_value("gradientEnd", "min"),
)
line_of_symmetry = get_and_clamp_float(
data,
"lineOfSymmetry",
get_config_value("lineOfSymmetry"),
get_config_value("lineOfSymmetry", "max"),
get_config_value("lineOfSymmetry", "min"),
)
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
def build_prompt_editing(
data: Dict[str, str] = None,
) -> Dict[str, str]:
if data is None:
data = request.args
enabled = get_boolean(data, "enabled", get_config_value("promptEditing"))
prompt_filter = data.get("promptFilter", "")
remove_tokens = data.get("removeTokens", "")
add_suffix = data.get("addSuffix", "")
return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix)
def build_experimental(
data: Dict[str, str] = None,
) -> ExperimentalParams:
if data is None:
data = request.args
latent_symmetry_data = data.get("latentSymmetry", {})
latent_symmetry = build_latent_symmetry(latent_symmetry_data)
prompt_editing_data = data.get("promptEditing", {})
prompt_editing = build_prompt_editing(prompt_editing_data)
return ExperimentalParams(latent_symmetry, prompt_editing)
def is_json_request() -> bool:
return request.mimetype == "application/json"
def is_json_form_request() -> bool:
return request.mimetype == "multipart/form-data" and "json" in request.form
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size] PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
@ -353,36 +430,11 @@ def pipeline_from_json(
data: Dict[str, Union[str, Dict[str, str]]], data: Dict[str, Union[str, Dict[str, str]]],
default_pipeline: str = "txt2img", default_pipeline: str = "txt2img",
) -> PipelineParams: ) -> PipelineParams:
"""
Like pipeline_from_request but expects a nested structure.
"""
device = build_device(server, data.get("device", data)) device = build_device(server, data.get("device", data))
params = build_params(server, default_pipeline, data.get("params", data)) params = build_params(server, default_pipeline, data.get("params", data))
size = build_size(server, data.get("params", data)) size = build_size(server, data.get("params", data))
return (device, params, size)
def pipeline_from_request(
server: ServerContext,
default_pipeline: str = "txt2img",
) -> PipelineParams:
user = request.remote_addr user = request.remote_addr
mime = request.mimetype
if mime == "application/json":
device, params, size = pipeline_from_json(
server, request.json, default_pipeline
)
elif mime == "multipart/form-data":
form_json = load_config_str(request.form.get("json"))
device, params, size = pipeline_from_json(server, form_json, default_pipeline)
else:
device = build_device(server, request.args)
params = build_params(server, default_pipeline, request.args)
size = build_size(server, request.args)
logger.info( logger.info(
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
user, user,
@ -399,3 +451,39 @@ def pipeline_from_request(
) )
return (device, params, size) return (device, params, size)
def get_request_data(key: str | None = None) -> Any:
if is_json_request():
json = request.json
elif is_json_form_request():
json = load_config_str(request.form.get("json"))
else:
json = None
if key is not None and json is not None:
json = json.get(key)
return json or request.args
def get_request_params(
server: ServerContext, default_pipeline: str = None
) -> RequestParams:
data = get_request_data()
device, params, size = pipeline_from_json(server, default_pipeline)
border = build_border(data["border"])
upscale = build_upscale(data["upscale"])
highres = build_highres(data["highres"])
experimental = build_experimental(data["experimental"])
return RequestParams(
device,
params,
size=size,
border=border,
upscale=upscale,
highres=highres,
experimental=experimental,
)

View File

@ -33,14 +33,6 @@ import {
import { range } from '../utils.js'; import { range } from '../utils.js';
import { ApiClient } from './base.js'; import { ApiClient } from './base.js';
const FORM_HEADERS = {
'Content-Type': 'multipart/form-data',
};
const JSON_HEADERS = {
'Content-Type': 'application/json',
};
export function equalResponse(a: JobResponse, b: JobResponse): boolean { export function equalResponse(a: JobResponse, b: JobResponse): boolean {
return a.name === b.name; return a.name === b.name;
} }
@ -348,7 +340,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, { const job = await parseRequest(url, {
body: form, body: form,
headers: FORM_HEADERS,
method: 'POST', method: 'POST',
}); });
return { return {
@ -376,7 +367,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, { const job = await parseRequest(url, {
body: form, body: form,
headers: FORM_HEADERS,
method: 'POST', method: 'POST',
}); });
return { return {
@ -407,7 +397,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, { const job = await parseRequest(url, {
body: form, body: form,
headers: FORM_HEADERS,
method: 'POST', method: 'POST',
}); });
return { return {
@ -437,7 +426,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, { const job = await parseRequest(url, {
body: form, body: form,
headers: FORM_HEADERS,
method: 'POST', method: 'POST',
}); });
return { return {
@ -465,7 +453,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, { const job = await parseRequest(url, {
body: form, body: form,
headers: FORM_HEADERS,
method: 'POST', method: 'POST',
}); });
return { return {
@ -497,7 +484,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
const job = await parseRequest(url, { const job = await parseRequest(url, {
body: form, body: form,
headers: FORM_HEADERS,
method: 'POST', method: 'POST',
}); });
return { return {