diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 799b7b8e..bd283b1a 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -39,20 +39,26 @@ class StageResult: def __len__(self) -> int: if self.arrays is not None: return len(self.arrays) - else: + elif self.images is not None: return len(self.images) + else: + raise ValueError("empty stage result") def as_numpy(self) -> List[np.ndarray]: if self.arrays is not None: return self.arrays - - return [np.array(i) for i in self.images] + elif self.images is not None: + return [np.array(i) for i in self.images] + else: + raise ValueError("empty stage result") def as_image(self) -> List[Image.Image]: if self.images is not None: return self.images - - return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays] + elif self.arrays is not None: + return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays] + else: + raise ValueError("empty stage result") def shape_mode(arr: np.ndarray) -> str: diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 6ade9580..08c07759 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -6,7 +6,14 @@ import numpy as np from PIL import Image from ..models.onnx import OnnxModel -from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + ImageParams, + Size, + SizeChart, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext @@ -17,7 +24,7 @@ logger = getLogger(__name__) class UpscaleBSRGANStage(BaseStage): - max_tile = 64 + max_tile = SizeChart.micro def load( self, diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index 62bd6102..ef7d421f 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -6,7 +6,7 @@ import numpy as np from PIL import Image from ..models.onnx import OnnxModel -from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams from ..server import ModelTypes, ServerContext from ..utils import run_gc from ..worker import WorkerContext @@ -17,7 +17,7 @@ logger = getLogger(__name__) class UpscaleSwinIRStage(BaseStage): - max_tile = 64 + max_tile = SizeChart.micro def load( self, diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index feaaeab3..810cf0e7 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -43,7 +43,6 @@ logger = logging.get_logger(__name__) # inpaint constants NUM_UNET_INPUT_CHANNELS = 9 -NUM_LATENT_CHANNELS = 4 DEFAULT_WINDOW = 32 DEFAULT_STRIDE = 8 @@ -1215,7 +1214,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): negative_prompt_embeds=negative_prompt_embeds, ) - num_channels_latents = NUM_LATENT_CHANNELS + num_channels_latents = LATENT_CHANNELS latents_shape = ( batch_size * num_images_per_prompt, num_channels_latents, diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index ff39d371..b9dc4394 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -88,8 +88,8 @@ def expand_prompt( negative_prompt: Optional[str] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, - skip_clip_states: Optional[int] = 0, -) -> "np.NDArray": + skip_clip_states: int = 0, +) -> np.ndarray: # self provides: # tokenizer: CLIPTokenizer # encoder: OnnxRuntimeModel @@ -144,6 +144,7 @@ def expand_prompt( last_state, _pooled_output, *hidden_states = text_result if skip_clip_states > 0: + # TODO: why is this normalized? layer_norm = torch.nn.LayerNorm(last_state.shape[2]) norm_state = layer_norm( torch.from_numpy( diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 34f1d070..d5e5ebff 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -14,6 +14,7 @@ Point = Tuple[int, int] class SizeChart(IntEnum): + micro = 64 mini = 128 # small tile for very expensive models half = 256 # half tile for outpainting auto = 512 # auto tile size @@ -201,7 +202,7 @@ class ImageParams: batch: int control: Optional[NetworkModel] input_prompt: str - input_negative_prompt: str + input_negative_prompt: Optional[str] loopback: int tiled_vae: bool unet_tile: int @@ -257,7 +258,7 @@ class ImageParams: def do_cfg(self): return self.cfg > 1.0 - def get_valid_pipeline(self, group: str, pipeline: str = None) -> str: + def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str: pipeline = pipeline or self.pipeline # if the correct pipeline was already requested, simply use that diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 67e7847d..6ef953c3 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -20,7 +20,7 @@ SAFE_CHARS = "._-" def split_list(val: str) -> List[str]: parts = [part.strip() for part in val.split(",")] - return [part for part in parts if len(part.strip()) > 0] + return [part for part in parts if len(part) > 0] def base_join(base: str, tail: str) -> str: @@ -75,13 +75,13 @@ def get_from_list( def get_from_map( - args: Any, key: str, values: Dict[str, TElem], default: TElem + args: Any, key: str, values: Dict[str, TElem], defaultKey: str ) -> TElem: - selected = args.get(key, default) + selected = args.get(key, defaultKey) if selected in values: return values[selected] else: - return values[default] + return values[defaultKey] def get_not_empty(args: Any, key: str, default: TElem) -> TElem: @@ -209,6 +209,8 @@ def load_config(file: str) -> Dict: return load_yaml(file) elif ext in [".json"]: return load_json(file) + else: + raise ValueError("unknown config file extension") def load_config_str(raw: str) -> Dict: diff --git a/api/pyproject.toml b/api/pyproject.toml index f0c3b689..bb47f16d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -9,12 +9,14 @@ skip_glob = ["*/lpw.py"] [tool.mypy] # ignore_missing_imports = true exclude = [ - "onnx_web.diffusers.lpw_stable_diffusion_onnx" + "onnx_web.diffusers.pipelines.controlnet", + "onnx_web.diffusers.pipelines.lpw", + "onnx_web.diffusers.pipelines.pix2pix" ] [[tool.mypy.overrides]] module = [ -"arpeggio", + "arpeggio", "basicsr.archs.rrdbnet_arch", "basicsr.utils.download_util", "basicsr.utils", @@ -27,8 +29,10 @@ module = [ "compel", "controlnet_aux", "cv2", + "debugpy", "diffusers", "diffusers.configuration_utils", + "diffusers.image_processor", "diffusers.loaders", "diffusers.models.attention_processor", "diffusers.models.autoencoder_kl", @@ -44,6 +48,7 @@ module = [ "diffusers.pipelines.pipeline_utils", "diffusers.pipelines.stable_diffusion", "diffusers.pipelines.stable_diffusion.convert_from_ckpt", + "diffusers.pipelines.stable_diffusion_xl", "diffusers.schedulers", "diffusers.utils.logging", "facexlib.utils", @@ -56,6 +61,11 @@ module = [ "mediapipe", "onnxruntime", "onnxruntime.transformers.float16", + "optimum.exporters.onnx", + "optimum.onnxruntime", + "optimum.onnxruntime.modeling_diffusion", + "optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img", + "optimum.pipelines.diffusers.pipeline_utils", "piexif", "piexif.helper", "realesrgan",