diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 0775f3e7..2408885c 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -107,48 +107,6 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
-def optimize_pipeline(
-    server: ServerContext,
-    pipe: StableDiffusionPipeline,
-) -> None:
-    if "diffusers-attention-slicing" in server.optimizations:
-        logger.debug("enabling attention slicing on SD pipeline")
-        try:
-            pipe.enable_attention_slicing()
-        except Exception as e:
-            logger.warning("error while enabling attention slicing: %s", e)
-
-    if "diffusers-vae-slicing" in server.optimizations:
-        logger.debug("enabling VAE slicing on SD pipeline")
-        try:
-            pipe.enable_vae_slicing()
-        except Exception as e:
-            logger.warning("error while enabling VAE slicing: %s", e)
-
-    if "diffusers-cpu-offload-sequential" in server.optimizations:
-        logger.debug("enabling sequential CPU offload on SD pipeline")
-        try:
-            pipe.enable_sequential_cpu_offload()
-        except Exception as e:
-            logger.warning("error while enabling sequential CPU offload: %s", e)
-
-    elif "diffusers-cpu-offload-model" in server.optimizations:
-        # TODO: check for accelerate
-        logger.debug("enabling model CPU offload on SD pipeline")
-        try:
-            pipe.enable_model_cpu_offload()
-        except Exception as e:
-            logger.warning("error while enabling model CPU offload: %s", e)
-
-    if "diffusers-memory-efficient-attention" in server.optimizations:
-        # TODO: check for xformers
-        logger.debug("enabling memory efficient attention for SD pipeline")
-        try:
-            pipe.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning("error while enabling memory efficient attention: %s", e)
-
-
 def load_pipeline(
     server: ServerContext,
     pipeline: DiffusionPipeline,
@@ -330,9 +288,74 @@ def load_pipeline(
 
     # monkey-patch pipeline
     if not lpw:
-        pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
+        patch_pipeline(server, pipe, pipeline)
 
     server.cache.set("diffusion", pipe_key, pipe)
     server.cache.set("scheduler", scheduler_key, components["scheduler"])
 
     return pipe
+
+
+def optimize_pipeline(
+    server: ServerContext,
+    pipe: StableDiffusionPipeline,
+) -> None:
+    if "diffusers-attention-slicing" in server.optimizations:
+        logger.debug("enabling attention slicing on SD pipeline")
+        try:
+            pipe.enable_attention_slicing()
+        except Exception as e:
+            logger.warning("error while enabling attention slicing: %s", e)
+
+    if "diffusers-vae-slicing" in server.optimizations:
+        logger.debug("enabling VAE slicing on SD pipeline")
+        try:
+            pipe.enable_vae_slicing()
+        except Exception as e:
+            logger.warning("error while enabling VAE slicing: %s", e)
+
+    if "diffusers-cpu-offload-sequential" in server.optimizations:
+        logger.debug("enabling sequential CPU offload on SD pipeline")
+        try:
+            pipe.enable_sequential_cpu_offload()
+        except Exception as e:
+            logger.warning("error while enabling sequential CPU offload: %s", e)
+
+    elif "diffusers-cpu-offload-model" in server.optimizations:
+        # TODO: check for accelerate
+        logger.debug("enabling model CPU offload on SD pipeline")
+        try:
+            pipe.enable_model_cpu_offload()
+        except Exception as e:
+            logger.warning("error while enabling model CPU offload: %s", e)
+
+    if "diffusers-memory-efficient-attention" in server.optimizations:
+        # TODO: check for xformers
+        logger.debug("enabling memory efficient attention for SD pipeline")
+        try:
+            pipe.enable_xformers_memory_efficient_attention()
+        except Exception as e:
+            logger.warning("error while enabling memory efficient attention: %s", e)
+
+
+def patch_pipeline(
+    server: ServerContext,
+    pipe: StableDiffusionPipeline,
+    pipeline: Any,
+) -> None:
+    logger.debug("patching SD pipeline")
+    pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
+
+    original_unet = pipe.unet.__call__
+    original_vae = pipe.vae_decoder.__call__
+
+    def unet_call(sample=None, timestep=None, encoder_hidden_states=None):
+        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
+        return original_unet(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
+
+    def vae_call(latent_sample=None):
+        logger.info("VAE parameter types: %s", latent_sample.dtype)
+        return original_vae(latent_sample=latent_sample)
+
+    pipe.unet.__call__ = unet_call
+    pipe.vae_decoder.__call__ = vae_call
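
Review note on the new patch_pipeline: the UNet and VAE wrappers use a capture-and-delegate pattern, but assigning the wrapper to pipe.unet.__call__ stores it in the instance dictionary, while Python resolves the implicit call obj(...) through type(obj).__call__. Whether the dtype logging fires therefore depends on how the pipeline invokes the model object. A minimal standalone sketch of the lookup difference (the Model class below is illustrative, not part of this diff):

class Model:
    def __call__(self, x):
        return x + 1

model = Model()
original = model.__call__  # bound method, captured like original_unet above

def wrapped(x):
    print("wrapped call with", x)
    return original(x)

model.__call__ = wrapped   # instance attribute, like pipe.unet.__call__ = unet_call

model(1)           # implicit invocation: resolved on type(model), wrapper is bypassed
model.__call__(1)  # explicit attribute access: finds the instance attribute, wrapper runs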