From edd32f6044deb0ef545483d01e567d90f47d3ca7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Mar 2023 10:25:09 -0500 Subject: [PATCH] apply lint, fix up log levels --- api/onnx_web/diffusers/load.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 189b66dd..a90e58b9 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -2,8 +2,8 @@ from logging import getLogger from os import path from typing import Any, List, Optional, Tuple -import torch import numpy as np +import torch from diffusers import ( DDIMScheduler, DDPMScheduler, @@ -121,7 +121,9 @@ def load_pipeline( inversions = inversions or [] loras = loras or [] - torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32 + torch_dtype = ( + torch.float16 if "torch-fp16" in server.optimizations else torch.float32 + ) logger.debug("using Torch dtype %s for pipeline", torch_dtype) pipe_key = ( pipeline.__name__, @@ -349,6 +351,7 @@ def optimize_pipeline( timestep_dtype = None + class UNetWrapper(object): def __init__(self, wrapped): self.wrapped = wrapped @@ -357,12 +360,16 @@ class UNetWrapper(object): global timestep_dtype timestep_dtype = timestep.dtype - logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype) + logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype) if sample.dtype != timestep.dtype: - logger.warning("converting UNet sample dtype") + logger.info("converting UNet sample dtype") sample = sample.astype(timestep.dtype) - return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states) + return self.wrapped( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + ) def __getattr__(self, attr): return getattr(self.wrapped, attr) @@ -375,10 +382,10 @@ class VAEWrapper(object): def __call__(self, latent_sample=None): global timestep_dtype - logger.info("VAE parameter types: %s", latent_sample.dtype) + logger.trace("VAE parameter types: %s", latent_sample.dtype) if latent_sample.dtype != timestep_dtype: - logger.warning("converting VAE sample dtype") - sample = sample.astype(timestep_dtype) + logger.info("converting VAE sample dtype") + latent_sample = latent_sample.astype(timestep_dtype) return self.wrapped(latent_sample=latent_sample) @@ -386,7 +393,6 @@ class VAEWrapper(object): return getattr(self.wrapped, attr) - def patch_pipeline( server: ServerContext, pipe: StableDiffusionPipeline,