from logging import getLogger
from typing import Dict, Optional, Tuple

import numpy as np
from flask import request

from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..params import (
    Border,
    DeviceParams,
    HighresParams,
    ImageParams,
    Size,
    UpscaleParams,
)
from ..utils import (
    get_and_clamp_float,
    get_and_clamp_int,
    get_boolean,
    get_from_list,
    get_not_empty,
)
from .context import ServerContext
from .load import (
    get_available_platforms,
    get_config_value,
    get_correction_models,
    get_highres_methods,
    get_network_models,
    get_upscaling_models,
)
from .utils import get_model_path

logger = getLogger(__name__)


def build_device(
    server: ServerContext,
    data: Dict[str, str],
) -> Optional[DeviceParams]:
    # platform stuff
    device = None
    device_name = data.get("platform")

    if device_name is not None and device_name != "any":
        for platform in get_available_platforms():
            if platform.device == device_name:
                device = platform

    return device


def build_params(
    server: ServerContext,
    default_pipeline: str,
    data: Dict[str, str],
) -> ImageParams:
    # diffusion model
    model = get_not_empty(data, "model", get_config_value("model"))
    model_path = get_model_path(server, model)

    control = None
    control_name = data.get("control")
    for network in get_network_models():
        if network.name == control_name:
            control = network

    # pipeline stuff
    pipeline = get_from_list(
        data, "pipeline", get_available_pipelines(), default_pipeline
    )
    scheduler = get_from_list(data, "scheduler", get_pipeline_schedulers())

    if scheduler is None:
        scheduler = get_config_value("scheduler")

    # prompt does not come from config
    prompt = data.get("prompt", "")
    negative_prompt = data.get("negativePrompt", None)

    if negative_prompt is not None and negative_prompt.strip() == "":
        negative_prompt = None

    # image params
    batch = get_and_clamp_int(
        data,
        "batch",
        get_config_value("batch"),
        get_config_value("batch", "max"),
        get_config_value("batch", "min"),
    )
    cfg = get_and_clamp_float(
        data,
        "cfg",
        get_config_value("cfg"),
        get_config_value("cfg", "max"),
        get_config_value("cfg", "min"),
    )
    eta = get_and_clamp_float(
        data,
        "eta",
        get_config_value("eta"),
        get_config_value("eta", "max"),
        get_config_value("eta", "min"),
    )
    loopback = get_and_clamp_int(
        data,
        "loopback",
        get_config_value("loopback"),
        get_config_value("loopback", "max"),
        get_config_value("loopback", "min"),
    )
    steps = get_and_clamp_int(
        data,
        "steps",
        get_config_value("steps"),
        get_config_value("steps", "max"),
        get_config_value("steps", "min"),
    )
    tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE"))
    tiles = get_and_clamp_int(
        data,
        "tiles",
        get_config_value("tiles"),
        get_config_value("tiles", "max"),
        get_config_value("tiles", "min"),
    )
    overlap = get_and_clamp_float(
        data,
        "overlap",
        get_config_value("overlap"),
        get_config_value("overlap", "max"),
        get_config_value("overlap", "min"),
    )
    stride = get_and_clamp_int(
        data,
        "stride",
        get_config_value("stride"),
        get_config_value("stride", "max"),
        get_config_value("stride", "min"),
    )

    if stride > tiles:
        logger.info("limiting stride to tile size, %s > %s", stride, tiles)
        stride = tiles

    seed = int(data.get("seed", -1))
    if seed == -1:
        # this one can safely use np.random because it produces a single value
        seed = np.random.randint(np.iinfo(np.int32).max)

    params = ImageParams(
        model_path,
        pipeline,
        scheduler,
        prompt,
        cfg,
        steps,
        seed,
        eta=eta,
        negative_prompt=negative_prompt,
        batch=batch,
        control=control,
        loopback=loopback,
        tiled_vae=tiled_vae,
        tiles=tiles,
        overlap=overlap,
        stride=stride,
    )

    return params
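# Usage sketch (illustrative, not part of the original module): build_params reads a
# flat string dict and falls back to the server config defaults for any missing key.
# The dict values below are hypothetical placeholders, not required keys:
#
#   params = build_params(
#       server,
#       "txt2img",
#       {
#           "model": "example-model",   # resolved against the server's model path
#           "prompt": "a painted landscape",
#           "steps": "30",              # clamped to the configured min/max
#           "seed": "-1",               # -1 selects a random int32 seed
#       },
#   )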
def build_size(
    server: ServerContext,
    data: Dict[str, str],
) -> Size:
    height = get_and_clamp_int(
        data,
        "height",
        get_config_value("height"),
        get_config_value("height", "max"),
        get_config_value("height", "min"),
    )
    width = get_and_clamp_int(
        data,
        "width",
        get_config_value("width"),
        get_config_value("width", "max"),
        get_config_value("width", "min"),
    )
    return Size(width, height)


def build_border(
    data: Optional[Dict[str, str]] = None,
) -> Border:
    if data is None:
        data = request.args

    left = get_and_clamp_int(
        data,
        "left",
        get_config_value("left"),
        get_config_value("left", "max"),
        get_config_value("left", "min"),
    )
    right = get_and_clamp_int(
        data,
        "right",
        get_config_value("right"),
        get_config_value("right", "max"),
        get_config_value("right", "min"),
    )
    top = get_and_clamp_int(
        data,
        "top",
        get_config_value("top"),
        get_config_value("top", "max"),
        get_config_value("top", "min"),
    )
    bottom = get_and_clamp_int(
        data,
        "bottom",
        get_config_value("bottom"),
        get_config_value("bottom", "max"),
        get_config_value("bottom", "min"),
    )

    return Border(left, right, top, bottom)


def build_upscale(
    data: Optional[Dict[str, str]] = None,
) -> UpscaleParams:
    if data is None:
        data = request.args

    denoise = get_and_clamp_float(
        data,
        "denoise",
        get_config_value("denoise"),
        get_config_value("denoise", "max"),
        get_config_value("denoise", "min"),
    )
    scale = get_and_clamp_int(
        data,
        "scale",
        get_config_value("scale"),
        get_config_value("scale", "max"),
        get_config_value("scale", "min"),
    )
    outscale = get_and_clamp_int(
        data,
        "outscale",
        get_config_value("outscale"),
        get_config_value("outscale", "max"),
        get_config_value("outscale", "min"),
    )
    upscaling = get_from_list(data, "upscaling", get_upscaling_models())
    correction = get_from_list(data, "correction", get_correction_models())
    faces = get_not_empty(data, "faces", "false") == "true"
    face_outscale = get_and_clamp_int(
        data,
        "faceOutscale",
        get_config_value("faceOutscale"),
        get_config_value("faceOutscale", "max"),
        get_config_value("faceOutscale", "min"),
    )
    face_strength = get_and_clamp_float(
        data,
        "faceStrength",
        get_config_value("faceStrength"),
        get_config_value("faceStrength", "max"),
        get_config_value("faceStrength", "min"),
    )
    upscale_order = data.get("upscaleOrder", "correction-first")

    return UpscaleParams(
        upscaling,
        correction_model=correction,
        denoise=denoise,
        faces=faces,
        face_outscale=face_outscale,
        face_strength=face_strength,
        format="onnx",
        outscale=outscale,
        scale=scale,
        upscale_order=upscale_order,
    )


def build_highres(
    data: Optional[Dict[str, str]] = None,
) -> HighresParams:
    if data is None:
        data = request.args

    enabled = get_boolean(data, "highres", get_config_value("highres"))
    iterations = get_and_clamp_int(
        data,
        "highresIterations",
        get_config_value("highresIterations"),
        get_config_value("highresIterations", "max"),
        get_config_value("highresIterations", "min"),
    )
    method = get_from_list(data, "highresMethod", get_highres_methods())
    scale = get_and_clamp_int(
        data,
        "highresScale",
        get_config_value("highresScale"),
        get_config_value("highresScale", "max"),
        get_config_value("highresScale", "min"),
    )
    steps = get_and_clamp_int(
        data,
        "highresSteps",
        get_config_value("highresSteps"),
        get_config_value("highresSteps", "max"),
        get_config_value("highresSteps", "min"),
    )
    strength = get_and_clamp_float(
        data,
        "highresStrength",
        get_config_value("highresStrength"),
        get_config_value("highresStrength", "max"),
        get_config_value("highresStrength", "min"),
    )

    return HighresParams(
        enabled,
        scale,
        steps,
        strength,
        method=method,
        iterations=iterations,
    )


PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size]
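# Behavior sketch (illustrative): the builders that take an optional dict fall back to
# Flask's request.args, so the same helpers serve query-string and JSON handlers.
# The variable names below are hypothetical:
#
#   upscale = build_upscale()                 # inside a request context: uses request.args
#   highres = build_highres(body["params"])   # an explicit dict takes precedence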
def pipeline_from_json(
    server: ServerContext,
    data: Dict[str, str],
    default_pipeline: str = "txt2img",
) -> PipelineParams:
    """
    Like pipeline_from_request but expects a nested structure.
    """
    device = build_device(server, data.get("device", data))
    params = build_params(server, default_pipeline, data.get("params", data))
    size = build_size(server, data.get("params", data))

    return (device, params, size)


def pipeline_from_request(
    server: ServerContext,
    default_pipeline: str = "txt2img",
) -> PipelineParams:
    user = request.remote_addr

    device = build_device(server, request.args)
    params = build_params(server, default_pipeline, request.args)
    size = build_size(server, request.args)

    logger.info(
        "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s",
        user,
        params.steps,
        params.scheduler,
        params.model,
        params.pipeline,
        device or "any device",
        size.width,
        size.height,
        params.cfg,
        params.seed,
        params.prompt,
    )

    return (device, params, size)
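# Usage sketch for the two entry points, with a hypothetical Flask handler; the route
# and request body shape are illustrative, not part of this module:
#
#   @app.route("/api/txt2img", methods=["POST"])
#   def txt2img():
#       if request.is_json:
#           device, params, size = pipeline_from_json(server, request.get_json())
#       else:
#           device, params, size = pipeline_from_request(server)
#       ...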