diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py
index 7cf5d069..bbd3ca9e 100644
--- a/api/onnx_web/chain/__init__.py
+++ b/api/onnx_web/chain/__init__.py
@@ -19,6 +19,15 @@ from .persist_disk import (
 from .persist_s3 import (
     persist_s3,
 )
+from .reduce_crop import (
+    reduce_crop,
+)
+from .reduce_thumbnail import (
+    reduce_thumbnail,
+)
+from .source_noise import (
+    source_noise,
+)
 from .source_txt2img import (
     source_txt2img,
 )
diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py
index 954fc794..56e08233 100644
--- a/api/onnx_web/chain/base.py
+++ b/api/onnx_web/chain/base.py
@@ -70,11 +70,11 @@ class ChainPipeline:
             kwargs = stage_kwargs or {}
             kwargs = {**pipeline_kwargs, **kwargs}
 
-            logger.info('running stage %s on result image with dimensions %sx%s, %s',
+            logger.info('running stage %s on image with dimensions %sx%s, %s',
                         name, image.width, image.height, kwargs.keys())
 
             if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
-                logger.info('source image larger than tile size of %s, tiling stage',
+                logger.info('image larger than tile size of %s, tiling stage',
                             stage_params.tile_size)
 
                 def stage_tile(tile: Image.Image, _dims) -> Image.Image:
@@ -89,7 +89,7 @@ class ChainPipeline:
                 image = process_tile_grid(
                     image, stage_params.tile_size, stage_params.outscale, [stage_tile])
             else:
-                logger.info('source image within tile size, running stage')
+                logger.info('image within tile size, running stage')
                 image = stage_pipe(ctx, stage_params, params, image,
                                    **kwargs)
 
diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py
index 6e74429b..6c25e1db 100644
--- a/api/onnx_web/chain/persist_disk.py
+++ b/api/onnx_web/chain/persist_disk.py
@@ -1,7 +1,6 @@
 from logging import getLogger
 
 from PIL import Image
-
 from ..params import (
     ImageParams,
     StageParams,
diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py
new file mode 100644
index 00000000..3e3a5f96
--- /dev/null
+++ b/api/onnx_web/chain/reduce_crop.py
@@ -0,0 +1,30 @@
+from logging import getLogger
+from PIL import Image
+
+from ..params import (
+    ImageParams,
+    Size,
+    StageParams,
+)
+from ..utils import (
+    ServerContext,
+)
+
+logger = getLogger(__name__)
+
+
+def reduce_crop(
+    ctx: ServerContext,
+    _stage: StageParams,
+    _params: ImageParams,
+    source_image: Image.Image,
+    *,
+    origin: Size,
+    size: Size,
+    **kwargs,
+) -> Image.Image:
+    image = source_image.crop(
+        (origin.width, origin.height, size.width, size.height))
+    logger.info('created crop with dimensions: %sx%s',
+                image.width, image.height)
+    return image
diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py
new file mode 100644
index 00000000..8c3f0e19
--- /dev/null
+++ b/api/onnx_web/chain/reduce_thumbnail.py
@@ -0,0 +1,30 @@
+from logging import getLogger
+from PIL import Image
+
+from ..params import (
+    ImageParams,
+    Size,
+    StageParams,
+)
+from ..utils import (
+    ServerContext,
+)
+
+logger = getLogger(__name__)
+
+
+def reduce_thumbnail(
+    ctx: ServerContext,
+    _stage: StageParams,
+    _params: ImageParams,
+    source_image: Image.Image,
+    *,
+    size: Size,
+    **kwargs,
+) -> Image.Image:
+    # Image.thumbnail resizes in place and returns None, so work on a copy
+    image = source_image.copy()
+    image.thumbnail((size.width, size.height))
+    logger.info('created thumbnail with dimensions: %sx%s',
+                image.width, image.height)
+    return image
diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py
new file mode 100644
index 00000000..33c6c226
--- /dev/null
+++ b/api/onnx_web/chain/source_noise.py
@@ -0,0 +1,37 @@
+from logging import getLogger
+from PIL import Image
+from typing import Callable
+
+from ..params import (
+    ImageParams,
+    Size,
+    StageParams,
+)
+from ..utils import (
+    ServerContext,
+)
+
+
+logger = getLogger(__name__)
+
+
+def source_noise(
+    ctx: ServerContext,
+    stage: StageParams,
+    params: ImageParams,
+    source_image: Image.Image,
+    *,
+    size: Size,
+    noise_source: Callable,
+    **kwargs,
+) -> Image.Image:
+    logger.info('generating image from noise source')
+
+    if source_image is not None:
+        logger.warning(
+            'a source image was passed to a noise stage and will be discarded')
+
+    output = noise_source(source_image, (size.width, size.height), (0, 0))
+
+    logger.info('final output image size: %sx%s', output.width, output.height)
+    return output
diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py
index 6405d768..08df9462 100644
--- a/api/onnx_web/chain/upscale_stable_diffusion.py
+++ b/api/onnx_web/chain/upscale_stable_diffusion.py
@@ -1,6 +1,4 @@
 from diffusers import (
-    AutoencoderKL,
-    DDPMScheduler,
     StableDiffusionUpscalePipeline,
 )
 from logging import getLogger
@@ -40,19 +38,9 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
         return last_pipeline_instance
 
     if upscale.format == 'onnx':
-        # ValueError: Pipeline
-        # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
-        # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
-            model_path,
-            vae=AutoencoderKL.from_pretrained(
-                model_path, subfolder='vae_encoder'),
-            low_res_scheduler=DDPMScheduler.from_pretrained(
-                model_path, subfolder='scheduler'),
-        )
+        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path)
     else:
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
-            'stabilityai/stable-diffusion-x4-upscaler')
+        pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path)
 
     last_pipeline_instance = pipeline
     last_pipeline_params = cache_params
diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py
index 443b0d11..2e390425 100644
--- a/api/onnx_web/convert.py
+++ b/api/onnx_web/convert.py
@@ -6,6 +6,7 @@ from diffusers import (
     OnnxRuntimeModel,
     OnnxStableDiffusionPipeline,
     StableDiffusionPipeline,
+    StableDiffusionUpscalePipeline,
 )
 from logging import getLogger
 from onnx import load, save_model
@@ -202,7 +203,7 @@ def onnx_export(
 
 
 @torch.no_grad()
-def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
+def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
     '''
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     '''
@@ -212,6 +213,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
     # diffusers go into a directory rather than .onnx file
     logger.info('converting Diffusers model: %s -> %s/', name, dest_path)
 
+    if single_vae:
+        logger.info('converting model with single VAE')
+
     if path.isdir(dest_path):
         logger.info('ONNX model already exists, skipping.')
         return
@@ -295,50 +299,75 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
     )
     del pipeline.unet
 
-    # VAE ENCODER
-    vae_encoder = pipeline.vae
-    vae_in_channels = vae_encoder.config.in_channels
-    vae_sample_size = vae_encoder.config.sample_size
-    # need to get the raw tensor output (sample) from the encoder
-    vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
-        sample, return_dict)[0].sample()
-    onnx_export(
-        vae_encoder,
-        model_args=(
-            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                device=training_device, dtype=dtype),
-            False,
-        ),
-        output_path=output_path / "vae_encoder" / "model.onnx",
-        ordered_input_names=["sample", "return_dict"],
-        output_names=["latent_sample"],
-        dynamic_axes={
-            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-        },
-        opset=opset,
-    )
+    if single_vae:
+        # SINGLE VAE
+        vae_only = pipeline.vae
+        vae_in_channels = vae_only.config.in_channels
+        vae_sample_size = vae_only.config.sample_size
+        # need to get the raw tensor output (sample) from the encoder
+        vae_only.forward = lambda sample, return_dict: vae_only.encode(
+            sample, return_dict)[0].sample()
+        onnx_export(
+            vae_only,
+            model_args=(
+                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+                    device=training_device, dtype=dtype),
+                False,
+            ),
+            output_path=output_path / "vae" / "model.onnx",
+            ordered_input_names=["sample", "return_dict"],
+            output_names=["latent_sample"],
+            dynamic_axes={
+                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            },
+            opset=opset,
+        )
+    else:
+        # VAE ENCODER
+        vae_encoder = pipeline.vae
+        vae_in_channels = vae_encoder.config.in_channels
+        vae_sample_size = vae_encoder.config.sample_size
+        # need to get the raw tensor output (sample) from the encoder
+        vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
+            sample, return_dict)[0].sample()
+        onnx_export(
+            vae_encoder,
+            model_args=(
+                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
+                    device=training_device, dtype=dtype),
+                False,
+            ),
+            output_path=output_path / "vae_encoder" / "model.onnx",
+            ordered_input_names=["sample", "return_dict"],
+            output_names=["latent_sample"],
+            dynamic_axes={
+                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            },
+            opset=opset,
+        )
+
+        # VAE DECODER
+        vae_decoder = pipeline.vae
+        vae_latent_channels = vae_decoder.config.latent_channels
+        vae_out_channels = vae_decoder.config.out_channels
+        # forward only through the decoder part
+        vae_decoder.forward = vae_encoder.decode
+        onnx_export(
+            vae_decoder,
+            model_args=(
+                torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
+                    device=training_device, dtype=dtype),
+                False,
+            ),
+            output_path=output_path / "vae_decoder" / "model.onnx",
+            ordered_input_names=["latent_sample", "return_dict"],
+            output_names=["sample"],
+            dynamic_axes={
+                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+            },
+            opset=opset,
+        )
+
-    # VAE DECODER
-    vae_decoder = pipeline.vae
-    vae_latent_channels = vae_decoder.config.latent_channels
-    vae_out_channels = vae_decoder.config.out_channels
-    # forward only through the decoder part
-    vae_decoder.forward = vae_encoder.decode
-    onnx_export(
-        vae_decoder,
-        model_args=(
-            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
-                device=training_device, dtype=dtype),
-            False,
-        ),
-        output_path=output_path / "vae_decoder" / "model.onnx",
-        ordered_input_names=["latent_sample", "return_dict"],
-        output_names=["sample"],
-        dynamic_axes={
-            "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-        },
-        opset=opset,
-    )
     del pipeline.vae
 
     # SAFETY CHECKER
@@ -376,20 +405,32 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
         safety_checker = None
         feature_extractor = None
 
-    onnx_pipeline = OnnxStableDiffusionPipeline(
-        vae_encoder=OnnxRuntimeModel.from_pretrained(
-            output_path / "vae_encoder"),
-        vae_decoder=OnnxRuntimeModel.from_pretrained(
-            output_path / "vae_decoder"),
-        text_encoder=OnnxRuntimeModel.from_pretrained(
-            output_path / "text_encoder"),
-        tokenizer=pipeline.tokenizer,
-        unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
-        scheduler=pipeline.scheduler,
-        safety_checker=safety_checker,
-        feature_extractor=feature_extractor,
-        requires_safety_checker=safety_checker is not None,
-    )
+    if single_vae:
+        onnx_pipeline = StableDiffusionUpscalePipeline(
+            vae=OnnxRuntimeModel.from_pretrained(
+                output_path / "vae"),
+            text_encoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "text_encoder"),
+            tokenizer=pipeline.tokenizer,
+            low_res_scheduler=pipeline.scheduler,
+            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
+            scheduler=pipeline.scheduler,
+        )
+    else:
+        onnx_pipeline = OnnxStableDiffusionPipeline(
+            vae_encoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "vae_encoder"),
+            vae_decoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "vae_decoder"),
+            text_encoder=OnnxRuntimeModel.from_pretrained(
+                output_path / "text_encoder"),
+            tokenizer=pipeline.tokenizer,
+            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
+            scheduler=pipeline.scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+            requires_safety_checker=safety_checker is not None,
+        )
 
     logger.info('exporting ONNX model')
 
@@ -398,8 +439,15 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
     del pipeline
     del onnx_pipeline
 
-    _ = OnnxStableDiffusionPipeline.from_pretrained(
-        output_path, provider="CPUExecutionProvider")
+
+    if single_vae:
+        _ = StableDiffusionUpscalePipeline.from_pretrained(
+            output_path, provider="CPUExecutionProvider"
+        )
+    else:
+        _ = OnnxStableDiffusionPipeline.from_pretrained(
+            output_path, provider="CPUExecutionProvider")
+
     logger.info("ONNX pipeline is loadable")
 
@@ -409,7 +457,8 @@ def load_models(args, models: Models):
         if source[0] in args.skip:
             logger.info('Skipping model: %s', source[0])
         else:
-            convert_diffuser(*source, args.opset, args.half, args.token)
+            single_vae = 'upscaling' in source[0]
+            convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)
 
     if args.upscaling:
         for source in models.get('upscaling'):
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index 18c79cc3..bc1198f9 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -32,6 +32,9 @@ from .chain import (
     correct_gfpgan,
     persist_disk,
     persist_s3,
+    reduce_crop,
+    reduce_thumbnail,
+    source_noise,
     source_txt2img,
     upscale_outpaint,
     upscale_resrgan,
@@ -61,7 +64,6 @@ from .params import (
     Border,
     ImageParams,
     Size,
-    SizeChart,
     StageParams,
     UpscaleParams,
 )
@@ -129,6 +131,9 @@ chain_stages = {
     'correct-gfpgan': correct_gfpgan,
     'persist-disk': persist_disk,
     'persist-s3': persist_s3,
+    'reduce-crop': reduce_crop,
+    'reduce-thumbnail': reduce_thumbnail,
+    'source-noise': source_noise,
     'source-txt2img': source_txt2img,
     'upscale-outpaint': upscale_outpaint,
     'upscale-resrgan': upscale_resrgan,
diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index f80a3c2b..e73f770c 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -56,6 +56,9 @@
       "spinalcase",
       "stabilityai",
       "stringcase",
+      "uncond",
+      "unet",
+      "untruncated",
       "upsampler",
       "upscaling",
       "venv",
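
Usage note: a minimal sketch of driving the two new source/reduce stages directly from Python, outside the chain endpoint, using only the signatures shown in this patch. The ServerContext, StageParams, ImageParams, and Size constructor arguments below are placeholders (their constructors are not part of this diff, and the Size(width, height) order is assumed), and flat_noise is a hypothetical noise source matching the (source, dims, origin) callable contract that source_noise invokes:

    from PIL import Image

    from onnx_web.chain import reduce_thumbnail, source_noise
    from onnx_web.params import ImageParams, Size, StageParams
    from onnx_web.utils import ServerContext


    def flat_noise(source: Image.Image, dims, origin) -> Image.Image:
        # hypothetical noise source: ignores the (discarded) source image
        # and returns a solid gray image of the requested dimensions
        return Image.new('RGB', dims, color=(128, 128, 128))


    ctx = ServerContext()    # placeholder arguments
    stage = StageParams()    # placeholder arguments
    params = ImageParams()   # placeholder arguments

    # generate a 512x512 image from the noise source, then shrink it
    noise = source_noise(ctx, stage, params, None,
                         size=Size(512, 512), noise_source=flat_noise)
    thumb = reduce_thumbnail(ctx, stage, params, noise, size=Size(128, 128))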