add single class to store all request params
This commit is contained in:
parent
1a4c31d077
commit
bf1a88fac2
|
@ -8,6 +8,7 @@ from ..chain import (
|
||||||
BlendImg2ImgStage,
|
BlendImg2ImgStage,
|
||||||
BlendMaskStage,
|
BlendMaskStage,
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
|
EditSafetyStage,
|
||||||
SourceTxt2ImgStage,
|
SourceTxt2ImgStage,
|
||||||
TextPromptStage,
|
TextPromptStage,
|
||||||
UpscaleOutpaintStage,
|
UpscaleOutpaintStage,
|
||||||
|
@ -19,11 +20,12 @@ from ..image import expand_image
|
||||||
from ..output import make_output_names, read_metadata, save_image, save_result
|
from ..output import make_output_names, read_metadata, save_image, save_result
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
|
ExperimentalParams,
|
||||||
HighresParams,
|
HighresParams,
|
||||||
ImageParams,
|
ImageParams,
|
||||||
|
RequestParams,
|
||||||
Size,
|
Size,
|
||||||
StageParams,
|
StageParams,
|
||||||
UpscaleParams,
|
|
||||||
)
|
)
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..server.load import get_source_filters
|
from ..server.load import get_source_filters
|
||||||
|
@ -57,39 +59,49 @@ def add_safety_stage(
|
||||||
pipeline: ChainPipeline,
|
pipeline: ChainPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
if server.has_feature("horde-safety"):
|
if server.has_feature("horde-safety"):
|
||||||
from ..chain.edit_safety import EditSafetyStage
|
|
||||||
|
|
||||||
pipeline.stage(
|
pipeline.stage(
|
||||||
EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile)
|
EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_prompt_filter(
|
||||||
|
server: ServerContext,
|
||||||
|
pipeline: ChainPipeline,
|
||||||
|
experimental: ExperimentalParams = None,
|
||||||
|
) -> None:
|
||||||
|
if experimental and experimental.prompt_editing.enabled:
|
||||||
|
if server.has_feature("prompt-filter"):
|
||||||
|
pipeline.stage(
|
||||||
|
TextPromptStage(),
|
||||||
|
StageParams(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("prompt editing is not supported by the server")
|
||||||
|
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
request: RequestParams,
|
||||||
size: Size,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
highres: HighresParams,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
params = request.image
|
||||||
|
size = request.size
|
||||||
|
upscale = request.upscale
|
||||||
|
highres = request.highres
|
||||||
|
|
||||||
# if using panorama, the pipeline will tile itself (views)
|
# if using panorama, the pipeline will tile itself (views)
|
||||||
tile_size = get_base_tile(params, size)
|
tile_size = get_base_tile(params, size)
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
|
add_prompt_filter(server, chain)
|
||||||
if server.has_feature("prompt-filter"):
|
|
||||||
chain.stage(
|
|
||||||
TextPromptStage(),
|
|
||||||
StageParams(),
|
|
||||||
)
|
|
||||||
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
SourceTxt2ImgStage(),
|
SourceTxt2ImgStage(),
|
||||||
StageParams(
|
StageParams(
|
||||||
tile_size=tile_size,
|
tile_size=tile_size,
|
||||||
),
|
),
|
||||||
size=size,
|
size=request.size,
|
||||||
prompt_index=0,
|
prompt_index=0,
|
||||||
overlap=params.vae_overlap,
|
overlap=params.vae_overlap,
|
||||||
)
|
)
|
||||||
|
@ -145,13 +157,15 @@ def run_txt2img_pipeline(
|
||||||
def run_img2img_pipeline(
|
def run_img2img_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
request: RequestParams,
|
||||||
upscale: UpscaleParams,
|
|
||||||
highres: HighresParams,
|
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
strength: float,
|
strength: float,
|
||||||
source_filter: Optional[str] = None,
|
source_filter: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
params = request.image
|
||||||
|
upscale = request.upscale
|
||||||
|
highres = request.highres
|
||||||
|
|
||||||
# run filter on the source image
|
# run filter on the source image
|
||||||
if source_filter is not None:
|
if source_filter is not None:
|
||||||
f = get_source_filters().get(source_filter, None)
|
f = get_source_filters().get(source_filter, None)
|
||||||
|
@ -246,10 +260,7 @@ def run_img2img_pipeline(
|
||||||
def run_inpaint_pipeline(
|
def run_inpaint_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
request: RequestParams,
|
||||||
size: Size,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
highres: HighresParams,
|
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
mask: Image.Image,
|
mask: Image.Image,
|
||||||
border: Border,
|
border: Border,
|
||||||
|
@ -260,6 +271,11 @@ def run_inpaint_pipeline(
|
||||||
full_res_inpaint: bool,
|
full_res_inpaint: bool,
|
||||||
full_res_inpaint_padding: float,
|
full_res_inpaint_padding: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
params = request.image
|
||||||
|
size = request.size
|
||||||
|
upscale = request.upscale
|
||||||
|
highres = request.highres
|
||||||
|
|
||||||
logger.debug("building inpaint pipeline")
|
logger.debug("building inpaint pipeline")
|
||||||
tile_size = get_base_tile(params, size)
|
tile_size = get_base_tile(params, size)
|
||||||
|
|
||||||
|
@ -453,12 +469,14 @@ def run_inpaint_pipeline(
|
||||||
def run_upscale_pipeline(
|
def run_upscale_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
request: RequestParams,
|
||||||
size: Size,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
highres: HighresParams,
|
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
params = request.image
|
||||||
|
size = request.size
|
||||||
|
upscale = request.upscale
|
||||||
|
highres = request.highres
|
||||||
|
|
||||||
# set up the chain pipeline, no base stage for upscaling
|
# set up the chain pipeline, no base stage for upscaling
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
tile_size = get_base_tile(params, size)
|
tile_size = get_base_tile(params, size)
|
||||||
|
@ -521,13 +539,14 @@ def run_upscale_pipeline(
|
||||||
def run_blend_pipeline(
|
def run_blend_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
request: RequestParams,
|
||||||
size: Size,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
# highres: HighresParams,
|
|
||||||
sources: List[Image.Image],
|
sources: List[Image.Image],
|
||||||
mask: Image.Image,
|
mask: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
params = request.image
|
||||||
|
size = request.size
|
||||||
|
upscale = request.upscale
|
||||||
|
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
tile_size = get_base_tile(params, size)
|
tile_size = get_base_tile(params, size)
|
||||||
|
|
|
@ -12,6 +12,8 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
Param = Union[str, int, float]
|
Param = Union[str, int, float]
|
||||||
Point = Tuple[int, int]
|
Point = Tuple[int, int]
|
||||||
|
UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"]
|
||||||
|
UpscaleMethod = Literal["bilinear", "lanczos", "upscale"]
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
|
@ -425,9 +427,6 @@ class StageParams:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"]
|
|
||||||
|
|
||||||
|
|
||||||
class UpscaleParams:
|
class UpscaleParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -532,9 +531,6 @@ class UpscaleParams:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
UpscaleMethod = Literal["bilinear", "lanczos", "upscale"]
|
|
||||||
|
|
||||||
|
|
||||||
class HighresParams:
|
class HighresParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -593,6 +589,85 @@ class HighresParams:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LatentSymmetryParams:
|
||||||
|
enabled: bool
|
||||||
|
gradient_start: float
|
||||||
|
gradient_end: float
|
||||||
|
line_of_symmetry: float
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enabled: bool,
|
||||||
|
gradient_start: float,
|
||||||
|
gradient_end: float,
|
||||||
|
line_of_symmetry: float,
|
||||||
|
) -> None:
|
||||||
|
self.enabled = enabled
|
||||||
|
self.gradient_start = gradient_start
|
||||||
|
self.gradient_end = gradient_end
|
||||||
|
self.line_of_symmetry = line_of_symmetry
|
||||||
|
|
||||||
|
|
||||||
|
class PromptEditingParams:
|
||||||
|
enabled: bool
|
||||||
|
filter: str
|
||||||
|
remove_tokens: str
|
||||||
|
add_suffix: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enabled: bool,
|
||||||
|
filter: str,
|
||||||
|
remove_tokens: str,
|
||||||
|
add_suffix: str,
|
||||||
|
) -> None:
|
||||||
|
self.enabled = enabled
|
||||||
|
self.filter = filter
|
||||||
|
self.remove_tokens = remove_tokens
|
||||||
|
self.add_suffix = add_suffix
|
||||||
|
|
||||||
|
|
||||||
|
class ExperimentalParams:
|
||||||
|
latent_symmetry: LatentSymmetryParams
|
||||||
|
prompt_editing: PromptEditingParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
latent_symmetry: LatentSymmetryParams,
|
||||||
|
prompt_editing: PromptEditingParams,
|
||||||
|
) -> None:
|
||||||
|
self.latent_symmetry = latent_symmetry
|
||||||
|
self.prompt_editing = prompt_editing
|
||||||
|
|
||||||
|
|
||||||
|
class RequestParams:
|
||||||
|
device: DeviceParams
|
||||||
|
image: ImageParams
|
||||||
|
size: Size | None
|
||||||
|
border: Border | None
|
||||||
|
upscale: UpscaleParams | None
|
||||||
|
highres: HighresParams | None
|
||||||
|
experimental: ExperimentalParams | None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: DeviceParams,
|
||||||
|
image: ImageParams,
|
||||||
|
size: Optional[Size] = None,
|
||||||
|
border: Optional[Border] = None,
|
||||||
|
upscale: Optional[UpscaleParams] = None,
|
||||||
|
highres: Optional[HighresParams] = None,
|
||||||
|
experimental: Optional[ExperimentalParams] = None,
|
||||||
|
) -> None:
|
||||||
|
self.device = device
|
||||||
|
self.image = image
|
||||||
|
self.size = size
|
||||||
|
self.border = border
|
||||||
|
self.upscale = upscale
|
||||||
|
self.highres = highres
|
||||||
|
self.experimental = experimental
|
||||||
|
|
||||||
|
|
||||||
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
|
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
|
||||||
if val is None:
|
if val is None:
|
||||||
return SizeChart.auto
|
return SizeChart.auto
|
||||||
|
|
|
@ -52,13 +52,7 @@ from .load import (
|
||||||
get_upscaling_models,
|
get_upscaling_models,
|
||||||
get_wildcard_data,
|
get_wildcard_data,
|
||||||
)
|
)
|
||||||
from .params import (
|
from .params import build_border, build_upscale, get_request_params, pipeline_from_json
|
||||||
build_border,
|
|
||||||
build_highres,
|
|
||||||
build_upscale,
|
|
||||||
pipeline_from_json,
|
|
||||||
pipeline_from_request,
|
|
||||||
)
|
|
||||||
from .utils import wrap_route
|
from .utils import wrap_route
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -261,15 +255,13 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return error_reply("source image is required")
|
return error_reply("source image is required")
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
size = Size(source.width, source.height)
|
|
||||||
|
|
||||||
device, params, _size = pipeline_from_request(server, "img2img")
|
# TODO: look up the correct request field
|
||||||
upscale = build_upscale()
|
|
||||||
highres = build_highres()
|
|
||||||
source_filter = get_from_list(
|
source_filter = get_from_list(
|
||||||
request.args, "sourceFilter", list(get_source_filters().keys())
|
request.args, "sourceFilter", list(get_source_filters().keys())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: look up the correct request field
|
||||||
strength = get_and_clamp_float(
|
strength = get_and_clamp_float(
|
||||||
request.args,
|
request.args,
|
||||||
"strength",
|
"strength",
|
||||||
|
@ -278,20 +270,22 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
get_config_value("strength", "min"),
|
get_config_value("strength", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
params = get_request_params(server, JobType.IMG2IMG.value)
|
||||||
|
params.size = Size(source.width, source.height)
|
||||||
|
replace_wildcards(params.image, get_wildcard_data())
|
||||||
|
|
||||||
job_name = make_job_name("img2img", params, size, extras=[strength])
|
job_name = make_job_name(
|
||||||
|
JobType.IMG2IMG.value, params, params.size, extras=[strength]
|
||||||
|
)
|
||||||
queue = pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.IMG2IMG,
|
JobType.IMG2IMG,
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
upscale,
|
|
||||||
highres,
|
|
||||||
source,
|
source,
|
||||||
strength,
|
strength,
|
||||||
needs_device=device,
|
needs_device=params.device,
|
||||||
source_filter=source_filter,
|
source_filter=source_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -301,24 +295,17 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
|
|
||||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(server, "txt2img")
|
params = get_request_params()
|
||||||
upscale = build_upscale()
|
replace_wildcards(params.image, get_wildcard_data())
|
||||||
highres = build_highres()
|
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
|
||||||
|
|
||||||
job_name = make_job_name("txt2img", params, size)
|
|
||||||
|
|
||||||
|
job_name = make_job_name(JobType.TXT2IMG.value, params.image, params.size)
|
||||||
queue = pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.TXT2IMG,
|
JobType.TXT2IMG,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
needs_device=params.device,
|
||||||
upscale,
|
|
||||||
highres,
|
|
||||||
needs_device=device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("txt2img job queued for: %s", job_name)
|
logger.info("txt2img job queued for: %s", job_name)
|
||||||
|
@ -343,6 +330,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
mask.alpha_composite(mask_top_layer)
|
mask.alpha_composite(mask_top_layer)
|
||||||
mask.convert(mode="L")
|
mask.convert(mode="L")
|
||||||
|
|
||||||
|
# TODO: look up the correct request field
|
||||||
full_res_inpaint = get_boolean(
|
full_res_inpaint = get_boolean(
|
||||||
request.args, "fullresInpaint", get_config_value("fullresInpaint")
|
request.args, "fullresInpaint", get_config_value("fullresInpaint")
|
||||||
)
|
)
|
||||||
|
@ -354,10 +342,8 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
get_config_value("fullresInpaintPadding", "min"),
|
get_config_value("fullresInpaintPadding", "min"),
|
||||||
)
|
)
|
||||||
|
|
||||||
device, params, _size = pipeline_from_request(server, "inpaint")
|
params = get_request_params(server, JobType.INPAINT.value)
|
||||||
expand = build_border()
|
replace_wildcards(params.image, get_wildcard_data())
|
||||||
upscale = build_upscale()
|
|
||||||
highres = build_highres()
|
|
||||||
|
|
||||||
fill_color = get_not_empty(request.args, "fillColor", "white")
|
fill_color = get_not_empty(request.args, "fillColor", "white")
|
||||||
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
|
mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
|
||||||
|
@ -367,17 +353,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
tile_order = TileOrder.spiral
|
tile_order = TileOrder.spiral
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
|
||||||
|
|
||||||
job_name = make_job_name(
|
job_name = make_job_name(
|
||||||
"inpaint",
|
JobType.INPAINT.value,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
extras=[
|
extras=[
|
||||||
expand.left,
|
params.border.left,
|
||||||
expand.right,
|
params.border.right,
|
||||||
expand.top,
|
params.border.top,
|
||||||
expand.bottom,
|
params.border.bottom,
|
||||||
mask_filter.__name__,
|
mask_filter.__name__,
|
||||||
noise_source.__name__,
|
noise_source.__name__,
|
||||||
fill_color,
|
fill_color,
|
||||||
|
@ -391,19 +375,15 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
run_inpaint_pipeline,
|
run_inpaint_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
|
||||||
upscale,
|
|
||||||
highres,
|
|
||||||
source,
|
source,
|
||||||
mask,
|
mask,
|
||||||
expand,
|
|
||||||
noise_source,
|
noise_source,
|
||||||
mask_filter,
|
mask_filter,
|
||||||
fill_color,
|
fill_color,
|
||||||
tile_order,
|
tile_order,
|
||||||
full_res_inpaint,
|
full_res_inpaint,
|
||||||
full_res_inpaint_padding,
|
full_res_inpaint_padding,
|
||||||
needs_device=device,
|
needs_device=params.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("inpaint job queued for: %s", job_name)
|
logger.info("inpaint job queued for: %s", job_name)
|
||||||
|
@ -418,24 +398,18 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server)
|
params = get_request_params(server)
|
||||||
upscale = build_upscale()
|
replace_wildcards(params.image, get_wildcard_data())
|
||||||
highres = build_highres()
|
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
job_name = make_job_name("upscale", params.image, params.size)
|
||||||
|
|
||||||
job_name = make_job_name("upscale", params, size)
|
|
||||||
queue = pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.UPSCALE,
|
JobType.UPSCALE,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
|
||||||
upscale,
|
|
||||||
highres,
|
|
||||||
source,
|
source,
|
||||||
needs_device=device,
|
needs_device=params.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
@ -571,22 +545,18 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
sources.append(source)
|
sources.append(source)
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server)
|
params = get_request_params(server)
|
||||||
upscale = build_upscale()
|
|
||||||
|
|
||||||
job_name = make_job_name("blend", params, size)
|
job_name = make_job_name("blend", params.image, params.size)
|
||||||
queue = pool.submit(
|
queue = pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
JobType.BLEND,
|
JobType.BLEND,
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
|
||||||
upscale,
|
|
||||||
# TODO: highres
|
|
||||||
sources,
|
sources,
|
||||||
mask,
|
mask,
|
||||||
needs_device=device,
|
needs_device=params.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
@ -595,9 +565,9 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
|
|
||||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(server)
|
params = get_request_params(server)
|
||||||
|
|
||||||
job_name = make_job_name("txt2txt", params, size)
|
job_name = make_job_name("txt2txt", params.image, params.size)
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
queue = pool.submit(
|
queue = pool.submit(
|
||||||
|
@ -606,8 +576,7 @@ def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
run_txt2txt_pipeline,
|
run_txt2txt_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
needs_device=params.device,
|
||||||
needs_device=device,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return job_reply(job_name, queue=queue)
|
return job_reply(job_name, queue=queue)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
|
@ -8,8 +8,12 @@ from ..diffusers.utils import random_seed
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
DeviceParams,
|
DeviceParams,
|
||||||
|
ExperimentalParams,
|
||||||
HighresParams,
|
HighresParams,
|
||||||
ImageParams,
|
ImageParams,
|
||||||
|
LatentSymmetryParams,
|
||||||
|
PromptEditingParams,
|
||||||
|
RequestParams,
|
||||||
Size,
|
Size,
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
|
@ -345,6 +349,79 @@ def build_highres(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_latent_symmetry(
|
||||||
|
data: Dict[str, str] = None,
|
||||||
|
) -> LatentSymmetryParams:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
|
enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry"))
|
||||||
|
|
||||||
|
gradient_start = get_and_clamp_float(
|
||||||
|
data,
|
||||||
|
"gradientStart",
|
||||||
|
get_config_value("gradientStart"),
|
||||||
|
get_config_value("gradientStart", "max"),
|
||||||
|
get_config_value("gradientStart", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
|
gradient_end = get_and_clamp_float(
|
||||||
|
data,
|
||||||
|
"gradientEnd",
|
||||||
|
get_config_value("gradientEnd"),
|
||||||
|
get_config_value("gradientEnd", "max"),
|
||||||
|
get_config_value("gradientEnd", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
|
line_of_symmetry = get_and_clamp_float(
|
||||||
|
data,
|
||||||
|
"lineOfSymmetry",
|
||||||
|
get_config_value("lineOfSymmetry"),
|
||||||
|
get_config_value("lineOfSymmetry", "max"),
|
||||||
|
get_config_value("lineOfSymmetry", "min"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry)
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt_editing(
|
||||||
|
data: Dict[str, str] = None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
|
enabled = get_boolean(data, "enabled", get_config_value("promptEditing"))
|
||||||
|
|
||||||
|
prompt_filter = data.get("promptFilter", "")
|
||||||
|
remove_tokens = data.get("removeTokens", "")
|
||||||
|
add_suffix = data.get("addSuffix", "")
|
||||||
|
|
||||||
|
return PromptEditingParams(enabled, prompt_filter, remove_tokens, add_suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def build_experimental(
|
||||||
|
data: Dict[str, str] = None,
|
||||||
|
) -> ExperimentalParams:
|
||||||
|
if data is None:
|
||||||
|
data = request.args
|
||||||
|
|
||||||
|
latent_symmetry_data = data.get("latentSymmetry", {})
|
||||||
|
latent_symmetry = build_latent_symmetry(latent_symmetry_data)
|
||||||
|
|
||||||
|
prompt_editing_data = data.get("promptEditing", {})
|
||||||
|
prompt_editing = build_prompt_editing(prompt_editing_data)
|
||||||
|
|
||||||
|
return ExperimentalParams(latent_symmetry, prompt_editing)
|
||||||
|
|
||||||
|
|
||||||
|
def is_json_request() -> bool:
|
||||||
|
return request.mimetype == "application/json"
|
||||||
|
|
||||||
|
|
||||||
|
def is_json_form_request() -> bool:
|
||||||
|
return request.mimetype == "multipart/form-data" and "json" in request.form
|
||||||
|
|
||||||
|
|
||||||
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
|
PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
|
||||||
|
|
||||||
|
|
||||||
|
@ -353,36 +430,11 @@ def pipeline_from_json(
|
||||||
data: Dict[str, Union[str, Dict[str, str]]],
|
data: Dict[str, Union[str, Dict[str, str]]],
|
||||||
default_pipeline: str = "txt2img",
|
default_pipeline: str = "txt2img",
|
||||||
) -> PipelineParams:
|
) -> PipelineParams:
|
||||||
"""
|
|
||||||
Like pipeline_from_request but expects a nested structure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
device = build_device(server, data.get("device", data))
|
device = build_device(server, data.get("device", data))
|
||||||
params = build_params(server, default_pipeline, data.get("params", data))
|
params = build_params(server, default_pipeline, data.get("params", data))
|
||||||
size = build_size(server, data.get("params", data))
|
size = build_size(server, data.get("params", data))
|
||||||
|
|
||||||
return (device, params, size)
|
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request(
|
|
||||||
server: ServerContext,
|
|
||||||
default_pipeline: str = "txt2img",
|
|
||||||
) -> PipelineParams:
|
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
mime = request.mimetype
|
|
||||||
|
|
||||||
if mime == "application/json":
|
|
||||||
device, params, size = pipeline_from_json(
|
|
||||||
server, request.json, default_pipeline
|
|
||||||
)
|
|
||||||
elif mime == "multipart/form-data":
|
|
||||||
form_json = load_config_str(request.form.get("json"))
|
|
||||||
device, params, size = pipeline_from_json(server, form_json, default_pipeline)
|
|
||||||
else:
|
|
||||||
device = build_device(server, request.args)
|
|
||||||
params = build_params(server, default_pipeline, request.args)
|
|
||||||
size = build_size(server, request.args)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
"request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
|
||||||
user,
|
user,
|
||||||
|
@ -399,3 +451,39 @@ def pipeline_from_request(
|
||||||
)
|
)
|
||||||
|
|
||||||
return (device, params, size)
|
return (device, params, size)
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_data(key: str | None = None) -> Any:
|
||||||
|
if is_json_request():
|
||||||
|
json = request.json
|
||||||
|
elif is_json_form_request():
|
||||||
|
json = load_config_str(request.form.get("json"))
|
||||||
|
else:
|
||||||
|
json = None
|
||||||
|
|
||||||
|
if key is not None and json is not None:
|
||||||
|
json = json.get(key)
|
||||||
|
|
||||||
|
return json or request.args
|
||||||
|
|
||||||
|
|
||||||
|
def get_request_params(
|
||||||
|
server: ServerContext, default_pipeline: str = None
|
||||||
|
) -> RequestParams:
|
||||||
|
data = get_request_data()
|
||||||
|
|
||||||
|
device, params, size = pipeline_from_json(server, default_pipeline)
|
||||||
|
border = build_border(data["border"])
|
||||||
|
upscale = build_upscale(data["upscale"])
|
||||||
|
highres = build_highres(data["highres"])
|
||||||
|
experimental = build_experimental(data["experimental"])
|
||||||
|
|
||||||
|
return RequestParams(
|
||||||
|
device,
|
||||||
|
params,
|
||||||
|
size=size,
|
||||||
|
border=border,
|
||||||
|
upscale=upscale,
|
||||||
|
highres=highres,
|
||||||
|
experimental=experimental,
|
||||||
|
)
|
||||||
|
|
|
@ -33,14 +33,6 @@ import {
|
||||||
import { range } from '../utils.js';
|
import { range } from '../utils.js';
|
||||||
import { ApiClient } from './base.js';
|
import { ApiClient } from './base.js';
|
||||||
|
|
||||||
const FORM_HEADERS = {
|
|
||||||
'Content-Type': 'multipart/form-data',
|
|
||||||
};
|
|
||||||
|
|
||||||
const JSON_HEADERS = {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
};
|
|
||||||
|
|
||||||
export function equalResponse(a: JobResponse, b: JobResponse): boolean {
|
export function equalResponse(a: JobResponse, b: JobResponse): boolean {
|
||||||
return a.name === b.name;
|
return a.name === b.name;
|
||||||
}
|
}
|
||||||
|
@ -348,7 +340,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
|
|
||||||
const job = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body: form,
|
body: form,
|
||||||
headers: FORM_HEADERS,
|
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
|
@ -376,7 +367,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
|
|
||||||
const job = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body: form,
|
body: form,
|
||||||
headers: FORM_HEADERS,
|
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
|
@ -407,7 +397,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
|
|
||||||
const job = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body: form,
|
body: form,
|
||||||
headers: FORM_HEADERS,
|
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
|
@ -437,7 +426,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
|
|
||||||
const job = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body: form,
|
body: form,
|
||||||
headers: FORM_HEADERS,
|
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
|
@ -465,7 +453,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
|
|
||||||
const job = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body: form,
|
body: form,
|
||||||
headers: FORM_HEADERS,
|
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
|
@ -497,7 +484,6 @@ export function makeClient(root: string, batchInterval: number, token: Maybe<str
|
||||||
|
|
||||||
const job = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body: form,
|
body: form,
|
||||||
headers: FORM_HEADERS,
|
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
|
|
Loading…
Reference in New Issue