2023-01-28 20:56:06 +00:00
|
|
|
from enum import IntEnum
|
2023-02-18 21:44:39 +00:00
|
|
|
from logging import getLogger
|
2023-04-29 19:23:00 +00:00
|
|
|
from math import ceil
|
2023-02-18 21:44:39 +00:00
|
|
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
2023-01-28 04:48:06 +00:00
|
|
|
|
2023-04-13 04:26:16 +00:00
|
|
|
from .models.meta import NetworkModel
|
2023-02-26 21:21:58 +00:00
|
|
|
from .torch_before_ort import GraphOptimizationLevel, SessionOptions
|
2023-02-18 21:44:39 +00:00
|
|
|
|
|
|
|
# module-level logger named after this module
logger = getLogger(name=__name__)

# a scalar parameter value: text, integer, or float
Param = Union[str, int, float]

# an (x, y)-style pair of integers
Point = Tuple[int, int]
|
|
|
|
|
|
|
|
|
2023-01-28 20:56:06 +00:00
|
|
|
class SizeChart(IntEnum):
    """
    Named integer size steps, each a power of two, from 64 up to 2**32.
    """

    micro = 1 << 6
    mini = 1 << 7  # small tile for very expensive models
    half = 1 << 8  # half tile for outpainting
    auto = 1 << 9  # auto tile size
    hd1k = 1 << 10
    hd2k = 1 << 11
    hd4k = 1 << 12
    hd8k = 1 << 13
    hd16k = 1 << 14
    hd32k = 1 << 15
    hd64k = 1 << 16
    max = 1 << 32  # should be a reasonable upper limit for now
|
2023-01-28 20:56:06 +00:00
|
|
|
|
|
|
|
|
2023-02-12 00:00:18 +00:00
|
|
|
class TileOrder:
    """
    String names for the tile orderings; each constant's value equals its name.
    """

    grid: str = "grid"
    kernel: str = "kernel"
    spiral: str = "spiral"
|
|
|
|
|
|
|
|
|
2023-01-28 04:48:06 +00:00
|
|
|
class Border:
    """
    Per-side amounts for the left, right, top and bottom edges.
    """

    def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
        self.left = left
        self.right = right
        self.top = top
        self.bottom = bottom

    def __str__(self) -> str:
        # ``!s`` keeps the same str() conversion as the previous ``%s`` format
        return f"({self.left!s}, {self.right!s}, {self.top!s}, {self.bottom!s})"

    def isZero(self) -> bool:
        """Return True when all four sides equal 0."""
        return all(
            value == 0 for value in (self.left, self.right, self.top, self.bottom)
        )

    def tojson(self):
        """Return the four sides as a dict keyed by side name."""
        return {
            "left": self.left,
            "right": self.right,
            "top": self.top,
            "bottom": self.bottom,
        }

    def with_args(self, **kwargs):
        """
        Return a new Border with any of ``left``, ``right``, ``top`` or
        ``bottom`` replaced by keyword; unspecified sides are copied and
        unknown keywords are ignored.
        """
        return Border(
            kwargs.get("left", self.left),
            kwargs.get("right", self.right),
            kwargs.get("top", self.top),
            kwargs.get("bottom", self.bottom),
        )

    @classmethod
    def even(cls, all: int):
        """
        Build a border with the same amount on every side.

        Constructs through ``cls`` so that subclasses get an instance of their
        own type. The parameter name shadows the ``all`` builtin but is kept
        so existing keyword callers (``even(all=...)``) still work.
        """
        return cls(all, all, all, all)
|
|
|
|
|
2023-01-28 04:48:06 +00:00
|
|
|
|
|
|
|
class Size:
    """
    A width/height pair; iterating yields width then height.
    """

    def __init__(self, width: int, height: int) -> None:
        self.width = width
        self.height = height

    def __iter__(self):
        # fixed two-item sequence, captured when iteration starts
        return iter((self.width, self.height))

    def __str__(self) -> str:
        # ``!s`` matches the str() conversion done by ``%s`` formatting
        return f"{self.width!s}x{self.height!s}"

    def add_border(self, border: Border):
        """Return a new Size grown by the border on every side."""
        grown_width = border.left + self.width + border.right
        grown_height = border.top + self.height + border.bottom
        return Size(grown_width, grown_height)

    def max(self, width: int, height: int):
        """Return a new Size using the larger value for each dimension."""
        # ``max`` here is the builtin: method names are not in local scope
        return Size(max(self.width, width), max(self.height, height))

    def min(self, width: int, height: int):
        """Return a new Size using the smaller value for each dimension."""
        return Size(min(self.width, width), min(self.height, height))

    def round_to_tile(self, tile=512):
        """Return a new Size with each dimension rounded up to a multiple of ``tile``."""
        tiles_wide = ceil(self.width / tile)
        tiles_high = ceil(self.height / tile)
        return Size(tiles_wide * tile, tiles_high * tile)

    def tojson(self) -> Dict[str, int]:
        """Return the dimensions as ``{"width": ..., "height": ...}``."""
        return dict(width=self.width, height=self.height)

    def with_args(self, **kwargs):
        """
        Return a new Size with ``width`` and/or ``height`` replaced by keyword;
        unknown keywords are ignored.
        """
        return Size(
            width=kwargs.get("width", self.width),
            height=kwargs.get("height", self.height),
        )
|
|
|
|
|
2023-01-28 04:48:06 +00:00
|
|
|
|
2023-02-04 19:49:34 +00:00
|
|
|
class DeviceParams:
    """
    An execution device: its name, ONNX runtime provider, optional provider
    options, and a list of optimization flag strings.
    """

    def __init__(
        self,
        device: str,
        provider: str,
        options: Optional[dict] = None,
        optimizations: Optional[List[str]] = None,
    ) -> None:
        self.device = device
        self.provider = provider
        self.options = options
        # a missing flag list becomes empty so membership tests always work
        self.optimizations = optimizations or []
        # filled in by sess_options(cache=True)
        self.sess_options_cache = None

    def __str__(self) -> str:
        return f"{self.device!s} - {self.provider!s} ({self.options!s})"

    def ort_provider(
        self, model_type: Optional[str] = None
    ) -> Union[str, Tuple[str, Any]]:
        """
        Return the provider argument for an ONNX runtime session.

        If ``model_type`` is given and the ``onnx-cpu-<model_type>`` flag is
        set, ``"CPUExecutionProvider"`` is returned. Otherwise the provider name
        is returned alone, or as ``(provider, options)`` when options are set.
        """
        # TODO: check whether the CPU device is allowed
        pinned_to_cpu = (
            model_type is not None
            and f"onnx-cpu-{model_type}" in self.optimizations
        )
        if pinned_to_cpu:
            return "CPUExecutionProvider"

        if self.options is not None:
            return (self.provider, self.options)

        return self.provider

    def sess_options(self, cache=True) -> SessionOptions:
        """
        Build SessionOptions from the optimization flags.

        With ``cache`` set, a previously cached instance is returned as-is and
        a newly built one is stored for later calls.
        """
        if cache and self.sess_options_cache is not None:
            return self.sess_options_cache

        flags = self.optimizations
        sess = SessionOptions()

        if "onnx-low-memory" in flags:
            logger.debug("enabling ONNX low-memory optimizations")
            sess.enable_cpu_mem_arena = False
            sess.enable_mem_pattern = False
            sess.enable_mem_reuse = False

        # graph levels in priority order; only the first matching flag applies
        graph_levels = (
            (
                "onnx-graph-disable",
                "disabling all ONNX graph optimizations",
                GraphOptimizationLevel.ORT_DISABLE_ALL,
            ),
            (
                "onnx-graph-basic",
                "enabling basic ONNX graph optimizations",
                GraphOptimizationLevel.ORT_ENABLE_BASIC,
            ),
            (
                "onnx-graph-all",
                "enabling all ONNX graph optimizations",
                GraphOptimizationLevel.ORT_ENABLE_ALL,
            ),
        )
        for flag, message, level in graph_levels:
            if flag in flags:
                logger.debug(message)
                sess.graph_optimization_level = level
                break

        if "onnx-deterministic-compute" in flags:
            logger.debug("enabling ONNX deterministic compute")
            sess.use_deterministic_compute = True

        if cache:
            self.sess_options_cache = sess

        return sess

    def torch_str(self) -> str:
        """
        Return the torch device string for this device.

        Devices starting with ``cuda`` keep their name. Devices starting with
        ``rocm`` map to ``cuda``. Either gets a ``:<device_id>`` suffix when
        ``options`` contains ``device_id``. Anything else maps to ``cpu``.
        """
        if self.device.startswith("cuda"):
            base = self.device
        elif self.device.startswith("rocm"):
            # ROCm devices are addressed with the "cuda" prefix here
            base = "cuda"
        else:
            return "cpu"

        if self.options is not None and "device_id" in self.options:
            return f"{base}:{self.options['device_id']}"

        return base
|
2023-02-04 19:49:34 +00:00
|
|
|
|
|
|
|
|
2023-01-28 04:48:06 +00:00
|
|
|
class ImageParams:
    """
    Parameters for one image request: model, pipeline and scheduler names,
    prompts, guidance/step/seed settings, and UNet/VAE tiling options.
    """

    model: str
    pipeline: str
    scheduler: str
    prompt: str
    cfg: float
    steps: int
    seed: int
    negative_prompt: Optional[str]
    eta: float
    batch: int
    control: Optional[NetworkModel]
    input_prompt: str
    input_negative_prompt: Optional[str]
    loopback: int
    tiled_vae: bool
    unet_tile: int
    unet_overlap: float
    vae_tile: int
    vae_overlap: float
    denoise: int

    def __init__(
        self,
        model: str,
        pipeline: str,
        scheduler: str,
        prompt: str,
        cfg: float,
        steps: int,
        seed: int,
        negative_prompt: Optional[str] = None,
        eta: float = 0.0,
        batch: int = 1,
        control: Optional[NetworkModel] = None,
        input_prompt: Optional[str] = None,
        input_negative_prompt: Optional[str] = None,
        loopback: int = 0,
        tiled_vae: bool = False,
        unet_overlap: float = 0.25,
        unet_tile: int = 512,
        vae_overlap: float = 0.25,
        vae_tile: int = 512,
        denoise: int = 3,
    ) -> None:
        self.model = model
        self.pipeline = pipeline
        self.scheduler = scheduler
        self.prompt = prompt
        self.cfg = cfg
        self.steps = steps
        self.seed = seed
        self.negative_prompt = negative_prompt
        self.eta = eta
        self.batch = batch
        self.control = control
        # input prompts fall back to the main prompts when missing or empty
        self.input_prompt = input_prompt or prompt
        self.input_negative_prompt = input_negative_prompt or negative_prompt
        self.loopback = loopback
        self.tiled_vae = tiled_vae
        self.unet_overlap = unet_overlap
        self.unet_tile = unet_tile
        self.vae_overlap = vae_overlap
        self.vae_tile = vae_tile
        self.denoise = denoise

    def do_cfg(self):
        """Return True when ``cfg`` is strictly greater than 1.0."""
        return self.cfg > 1.0

    def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
        """
        Pick the pipeline to run for ``group``.

        ``pipeline`` (or ``self.pipeline`` when it is falsy) is returned if it
        equals ``group`` or is listed as compatible with ``group``. For the
        ``img2img`` group, ``txt2img-sdxl`` is mapped to ``img2img-sdxl``.
        Anything else is logged at debug level and ``group`` is returned.
        """
        requested = pipeline or self.pipeline

        # the requested pipeline already is the group's own pipeline
        if requested == group:
            return requested

        # additional pipelines each group may run as-is
        compatible = {
            "img2img": (
                "controlnet",
                "img2img-sdxl",
                "lpw",
                "panorama",
                "panorama-sdxl",
                "pix2pix",
            ),
            "inpaint": ("controlnet", "lpw", "panorama"),
            "txt2img": ("lpw", "panorama", "panorama-sdxl", "txt2img-sdxl"),
        }
        if requested in compatible.get(group, ()):
            return requested

        if group == "img2img" and requested == "txt2img-sdxl":
            return "img2img-sdxl"

        logger.debug("pipeline %s is not valid for %s", requested, group)
        return group

    def is_control(self):
        """Return True for the ``controlnet`` pipeline."""
        return self.pipeline == "controlnet"

    def is_lpw(self):
        """Return True for the ``lpw`` pipeline."""
        return self.pipeline == "lpw"

    def is_panorama(self):
        """Return True for the ``panorama`` or ``panorama-sdxl`` pipelines."""
        return self.pipeline in ("panorama", "panorama-sdxl")

    def is_pix2pix(self):
        """Return True for the ``pix2pix`` pipeline."""
        return self.pipeline == "pix2pix"

    def is_xl(self):
        """Return True when the pipeline name ends with ``-sdxl``."""
        return self.pipeline.endswith("-sdxl")

    def tojson(self) -> Dict[str, Optional[Param]]:
        """
        Serialize to a flat dict in a fixed key order. ``control`` is replaced
        by its ``name``, or ``""`` when no control network is set.
        """
        keys = (
            "model",
            "pipeline",
            "scheduler",
            "prompt",
            "negative_prompt",
            "cfg",
            "seed",
            "steps",
            "eta",
            "batch",
            "control",
            "input_prompt",
            "input_negative_prompt",
            "loopback",
            "tiled_vae",
            "unet_overlap",
            "unet_tile",
            "vae_overlap",
            "vae_tile",
            "denoise",
        )
        data = {key: getattr(self, key) for key in keys}
        # overwrite in place so the "control" key keeps its position
        data["control"] = self.control.name if self.control is not None else ""
        return data

    def with_args(self, **kwargs):
        """
        Return a copy with any constructor field overridden by keyword; fields
        not given are copied from this instance and unknown keywords are ignored.
        """
        return ImageParams(
            model=kwargs.get("model", self.model),
            pipeline=kwargs.get("pipeline", self.pipeline),
            scheduler=kwargs.get("scheduler", self.scheduler),
            prompt=kwargs.get("prompt", self.prompt),
            cfg=kwargs.get("cfg", self.cfg),
            steps=kwargs.get("steps", self.steps),
            seed=kwargs.get("seed", self.seed),
            negative_prompt=kwargs.get("negative_prompt", self.negative_prompt),
            eta=kwargs.get("eta", self.eta),
            batch=kwargs.get("batch", self.batch),
            control=kwargs.get("control", self.control),
            input_prompt=kwargs.get("input_prompt", self.input_prompt),
            input_negative_prompt=kwargs.get(
                "input_negative_prompt", self.input_negative_prompt
            ),
            loopback=kwargs.get("loopback", self.loopback),
            tiled_vae=kwargs.get("tiled_vae", self.tiled_vae),
            unet_overlap=kwargs.get("unet_overlap", self.unet_overlap),
            unet_tile=kwargs.get("unet_tile", self.unet_tile),
            vae_overlap=kwargs.get("vae_overlap", self.vae_overlap),
            vae_tile=kwargs.get("vae_tile", self.vae_tile),
            denoise=kwargs.get("denoise", self.denoise),
        )
|
|
|
|
|
2023-01-28 04:48:06 +00:00
|
|
|
|
|
|
|
class StageParams:
    """
    Parameters for a chained pipeline stage: an optional name, an output
    scale, a tile order (default ``TileOrder.spiral``) and a tile size
    (default ``SizeChart.auto``).
    """

    def __init__(
        self,
        name: Optional[str] = None,
        outscale: int = 1,
        tile_order: str = TileOrder.spiral,
        tile_size: int = SizeChart.auto,
        # batch_size: int = 1,
    ) -> None:
        self.name = name
        self.outscale = outscale
        self.tile_order = tile_order
        self.tile_size = tile_size

    def with_args(
        self,
        **kwargs,
    ):
        """
        Return a copy with any of ``name``, ``outscale``, ``tile_order`` or
        ``tile_size`` overridden by keyword; other keywords are ignored.
        """
        overrides = {
            field: kwargs.get(field, getattr(self, field))
            for field in ("name", "outscale", "tile_order", "tile_size")
        }
        return StageParams(**overrides)
|
|
|
|
|
2023-01-28 04:48:06 +00:00
|
|
|
|
2023-02-05 13:53:26 +00:00
|
|
|
class UpscaleParams:
    """
    Settings for an upscaling model plus optional face correction, including
    scale factors, padding, and the order in which correction is applied.
    """

    def __init__(
        self,
        upscale_model: str,
        correction_model: Optional[str] = None,
        denoise: float = 0.5,
        faces: bool = True,
        face_outscale: int = 1,
        face_strength: float = 0.5,
        format: Literal["onnx", "pth"] = "onnx",  # TODO: deprecated, remove
        outscale: int = 1,
        scale: int = 4,
        pre_pad: int = 0,
        tile_pad: int = 10,
        upscale_order: Literal[
            "correction-first", "correction-last", "correction-both"
        ] = "correction-first",
    ) -> None:
        self.upscale_model = upscale_model
        self.correction_model = correction_model
        self.denoise = denoise
        self.faces = faces
        self.face_outscale = face_outscale
        self.face_strength = face_strength
        self.format = format
        self.outscale = outscale
        self.pre_pad = pre_pad
        self.scale = scale
        self.tile_pad = tile_pad
        self.upscale_order = upscale_order

    def rescale(self, scale: int):
        """Return a copy with both ``scale`` and ``outscale`` set to ``scale``."""
        return self.with_args(outscale=scale, scale=scale)

    def resize(self, size: Size) -> Size:
        """
        Return ``size`` multiplied by ``outscale`` and the face outscale. With
        ``correction-both`` ordering the face outscale is applied twice.
        """
        if self.upscale_order == "correction-both":
            face_factor = self.face_outscale * self.face_outscale
        else:
            face_factor = self.face_outscale

        return Size(
            size.width * self.outscale * face_factor,
            size.height * self.outscale * face_factor,
        )

    def tojson(self):
        """Serialize every setting to a flat dict."""
        keys = (
            "upscale_model",
            "correction_model",
            "denoise",
            "faces",
            "face_outscale",
            "face_strength",
            "format",
            "outscale",
            "pre_pad",
            "scale",
            "tile_pad",
            "upscale_order",
        )
        return {key: getattr(self, key) for key in keys}

    def with_args(self, **kwargs):
        """
        Return a copy with any constructor field overridden by keyword; fields
        not given are copied and unknown keywords are ignored.
        """
        return UpscaleParams(
            upscale_model=kwargs.get("upscale_model", self.upscale_model),
            correction_model=kwargs.get("correction_model", self.correction_model),
            denoise=kwargs.get("denoise", self.denoise),
            faces=kwargs.get("faces", self.faces),
            face_outscale=kwargs.get("face_outscale", self.face_outscale),
            face_strength=kwargs.get("face_strength", self.face_strength),
            format=kwargs.get("format", self.format),
            outscale=kwargs.get("outscale", self.outscale),
            scale=kwargs.get("scale", self.scale),
            pre_pad=kwargs.get("pre_pad", self.pre_pad),
            tile_pad=kwargs.get("tile_pad", self.tile_pad),
            upscale_order=kwargs.get("upscale_order", self.upscale_order),
        )
|
2023-04-01 16:26:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
class HighresParams:
    """
    Settings for a highres pass: whether it is enabled, the per-iteration
    scale, steps, strength, resize method and number of iterations.
    """

    def __init__(
        self,
        enabled: bool,
        scale: int,
        steps: int,
        strength: float,
        method: Literal["bilinear", "lanczos", "upscale"] = "lanczos",
        iterations: int = 1,
    ) -> None:
        self.enabled = enabled
        self.scale = scale
        self.steps = steps
        self.strength = strength
        self.method = method
        self.iterations = iterations

    def outscale(self) -> int:
        """Return the total scale after all iterations: ``scale ** iterations``."""
        return self.scale**self.iterations

    def resize(self, size: Size) -> Size:
        """Return ``size`` with both dimensions multiplied by ``outscale()``."""
        factor = self.outscale()
        scaled_width = size.width * factor
        scaled_height = size.height * factor
        return Size(scaled_width, scaled_height)

    def tojson(self):
        """Serialize every setting to a flat dict."""
        keys = ("enabled", "iterations", "method", "scale", "steps", "strength")
        return {key: getattr(self, key) for key in keys}
|