diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py
index b6c372f6..0bdc79d3 100644
--- a/api/onnx_web/__init__.py
+++ b/api/onnx_web/__init__.py
@@ -19,7 +19,7 @@ from .image import (
 )
 from .upscale import (
     make_resrgan,
-    run_upscale_pipeline,
+    run_upscale_correction,
     upscale_gfpgan,
     upscale_resrgan,
     UpscaleParams,
diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py
index 094eac0e..9e8e8f62 100644
--- a/api/onnx_web/convert.py
+++ b/api/onnx_web/convert.py
@@ -249,7 +249,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
             torch.randn(2).to(device=training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
                 device=training_device, dtype=dtype),
-            False,
+            # TODO: needs to be Int or Long for upscaling, Bool for regular
+            4,
+            # False,
         ),
         output_path=unet_path,
         ordered_input_names=["sample", "timestep",
diff --git a/api/onnx_web/diffusion.py b/api/onnx_web/diffusion.py
index f58dd5d7..7e65c00f 100644
--- a/api/onnx_web/diffusion.py
+++ b/api/onnx_web/diffusion.py
@@ -16,7 +16,7 @@ from .image import (
     expand_image,
 )
 from .upscale import (
-    upscale_resrgan,
+    run_upscale_correction,
     UpscaleParams,
 )
 from .utils import (
@@ -117,7 +117,7 @@ def run_txt2img_pipeline(
         num_inference_steps=params.steps,
     )
     image = result.images[0]
-    image = run_upscale_pipeline(ctx, upscale, image)
+    image = run_upscale_correction(ctx, upscale, image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
@@ -151,7 +151,7 @@ def run_img2img_pipeline(
         strength=strength,
     )
     image = result.images[0]
-    image = run_upscale_pipeline(ctx, upscale, image)
+    image = run_upscale_correction(ctx, upscale, image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
@@ -215,7 +215,7 @@ def run_inpaint_pipeline(
     else:
         print('output image size does not match source, skipping post-blend')
 
-    image = run_upscale_pipeline(ctx, upscale, image)
+    image = run_upscale_correction(ctx, upscale, image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
@@ -234,7 +234,7 @@ def run_upscale_pipeline(
     upscale: UpscaleParams,
     source_image: Image
 ):
-    image = upscale_resrgan(ctx, upscale, source_image)
+    image = run_upscale_correction(ctx, upscale, source_image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
diff --git a/api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py
index 48077b74..970c6600 100644
--- a/api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py
+++ b/api/onnx_web/onnx/pipeline_onnx_stable_diffusion_upscale.py
@@ -3,13 +3,8 @@ from diffusers import (
     OnnxRuntimeModel,
     StableDiffusionUpscalePipeline,
 )
-from diffusers.models import (
-    CLIPTokenizer,
-)
-from diffusers.schedulers import (
-    KarrasDiffusionSchedulers,
-)
 from typing import (
+    Any,
     Callable,
     Union,
     List,
@@ -25,34 +20,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         self,
         vae: OnnxRuntimeModel,
         text_encoder: OnnxRuntimeModel,
-        tokenizer: CLIPTokenizer,
+        tokenizer: Any,
         unet: OnnxRuntimeModel,
         low_res_scheduler: DDPMScheduler,
-        scheduler: KarrasDiffusionSchedulers,
+        scheduler: Any,
         max_noise_level: int = 350,
     ):
-        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
+        super().__init__(vae, text_encoder, tokenizer, unet,
+                         low_res_scheduler, scheduler, max_noise_level)
 
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image,
-                     List[PIL.Image.Image]] = None,
-        num_inference_steps: int = 75,
-        guidance_scale: float = 9.0,
-        noise_level: int = 20,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        eta: float = 0.0,
-        generator: Optional[Union[torch.Generator,
-                                  List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        callback: Optional[Callable[[
-            int, int, torch.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
+        *args,
+        **kwargs,
     ):
-        pass
+        return super().__call__(*args, **kwargs)
diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py
index b5eb85f9..eea3d0be 100644
--- a/api/onnx_web/upscale.py
+++ b/api/onnx_web/upscale.py
@@ -1,4 +1,8 @@
 from basicsr.archs.rrdbnet_arch import RRDBNet
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+)
 from gfpgan import GFPGANer
 from os import path
 from PIL import Image
@@ -127,12 +131,20 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
 def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
     print('upscaling with Stable Diffusion')
 
-    pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(params.upscale_model)
+    model_path = '../models/%s' % params.upscale_model
+    # ValueError: Pipeline
+    # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
+    # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
+    pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+        model_path,
+        vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
+        low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
+    )
     result = pipeline('', image=image)
     return result.images[0]
 
 
-def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
+def run_upscale_correction(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
     print('running upscale pipeline')
 
     if params.scale > 1:
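
A note on the convert.py hunk: for regular Stable Diffusion models the fourth dummy input in the export tuple is declared as return_dict (hence the original False), while the x4 upscaler's UNet takes class_labels, the integer noise level, in that position, which is what the TODO is pointing at. A minimal sketch of the two tuples, with shapes and hidden sizes as illustrative assumptions rather than values read from convert.py:

import torch

# Regular SD v1.x UNet export: the fourth dummy input is traced as
# return_dict, so a plain Bool works here.
regular_args = (
    torch.randn(2, 4, 64, 64),    # sample: latents only (assumed shape)
    torch.randn(2),               # timestep
    torch.randn(2, 77, 768),      # encoder_hidden_states (assumed size)
    False,                        # return_dict
)

# x4 upscaler UNet: the fourth positional parameter of forward is
# class_labels, which carries the noise level and must be an Int/Long
# tensor in the exported graph.
upscale_args = (
    torch.randn(2, 7, 128, 128),  # sample: 4 latent + 3 image channels (assumed)
    torch.randn(2),               # timestep
    torch.randn(2, 77, 1024),     # encoder_hidden_states (assumed size)
    torch.tensor(4, dtype=torch.long),  # class_labels / noise_level
)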
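
And a minimal usage sketch for the upscale.py hunk, showing how the explicitly loaded vae and low_res_scheduler fill the gap named in the quoted ValueError. The model directory name and input filename are hypothetical, and the layout is assumed to match convert.py output (unet, text_encoder, tokenizer, scheduler, and vae_encoder subfolders):

from diffusers import AutoencoderKL, DDPMScheduler
from PIL import Image

from onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale import (
    OnnxStableDiffusionUpscalePipeline,
)

model_path = '../models/upscaling-stable-diffusion-x4'  # hypothetical path

# from_pretrained cannot discover a vae or low_res_scheduler in the
# converted ONNX layout, so both components are loaded separately and
# passed in explicitly, mirroring upscale_stable_diffusion above.
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
    model_path,
    vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
    low_res_scheduler=DDPMScheduler.from_pretrained(
        model_path, subfolder='scheduler'),
)

source = Image.open('input.png').convert('RGB')
result = pipeline('', image=source)
result.images[0].save('output-x4.png')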