diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 07310e74..642ccf77 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -267,12 +267,20 @@ def load_pipeline( # update panorama params if params.is_panorama(): unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8 - logger.debug("setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap) + logger.debug( + "setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", + params.unet_tile, + unet_stride, + params.vae_tile, + params.vae_overlap, + ) pipe.set_window_size(params.unet_tile // 8, unet_stride) for vae in VAE_COMPONENTS: if hasattr(pipe, vae): - getattr(pipe, vae).set_window_size(params.vae_tile // 8, params.vae_overlap) + getattr(pipe, vae).set_window_size( + params.vae_tile // 8, params.vae_overlap + ) run_gc([device]) diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index a0c92d07..7c46326c 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -303,7 +303,14 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix # 3.b. Encode region prompts regions = parse_regions(prompt) - region_embeds: List[Tuple[List[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]] = [] + region_embeds: List[ + Tuple[ + List[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + ] + ] = [] add_region_embeds: List[np.ndarray] = [] for _top, _left, _bottom, _right, _mode, region_prompt in regions: @@ -322,9 +329,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix current_region_embeds[0] = np.concatenate( (current_region_embeds[1], current_region_embeds[0]), axis=0 ) - add_region_embeds.append(np.concatenate( - (current_region_embeds[3], current_region_embeds[2]), axis=0 - )) + add_region_embeds.append( + np.concatenate( + (current_region_embeds[3], current_region_embeds[2]), axis=0 + ) + ) region_embeds.append(current_region_embeds) @@ -492,7 +501,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix ) latents_view_denoised = scheduler_output.prev_sample.numpy() - if mode: + if mode == "replace": value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] = 1 else: diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index b86be2b1..6c6ebc8e 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -43,8 +43,9 @@ def run_txt2img_pipeline( highres: HighresParams, ) -> None: # if using panorama, the pipeline will tile itself (views) - if params.is_panorama() or params.is_xl(): + if params.is_panorama(): tile_size = max(params.unet_tile, size.width, size.height) + logger.debug("adjusting tile size for panorama to %s", tile_size) else: tile_size = params.unet_tile diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 152b5bca..d06a28cd 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -3,7 +3,7 @@ from copy import deepcopy from logging import getLogger from math import ceil from re import Pattern, compile -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Literal, Optional, Tuple import numpy as np import torch @@ -21,6 +21,7 @@ CLIP_TOKEN = compile(r"\") INVERSION_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") +REGION_TOKEN = compile(r"\])\>") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") @@ -446,7 +447,8 @@ def slice_prompt(prompt: str, slice: int) -> str: return prompt -Region = Tuple[int, int, int, int, bool, str] +Region = Tuple[int, int, int, int, Literal["add", "replace"], str] + def parse_regions(prompt: str) -> List[Region]: - return [] + return get_tokens_from_prompt(prompt, REGION_TOKEN) diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index f51b51f4..b59bb73a 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -2,16 +2,8 @@ import sys from functools import partial from logging import getLogger from os import path -from pathlib import Path -from typing import Dict, Optional, Union from urllib.parse import urlparse -from optimum.onnxruntime.modeling_diffusion import ( - ORTModel, - ORTStableDiffusionPipelineBase, -) - -from ..torch_before_ort import SessionOptions from ..utils import run_gc from .context import ServerContext