apply lint, fix up log levels
This commit is contained in:
parent
6f283c5c02
commit
edd32f6044
|
@ -2,8 +2,8 @@ from logging import getLogger
|
|||
from os import path
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
|
@ -121,7 +121,9 @@ def load_pipeline(
|
|||
inversions = inversions or []
|
||||
loras = loras or []
|
||||
|
||||
torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||
torch_dtype = (
|
||||
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||
)
|
||||
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
||||
pipe_key = (
|
||||
pipeline.__name__,
|
||||
|
@ -349,6 +351,7 @@ def optimize_pipeline(
|
|||
|
||||
timestep_dtype = None
|
||||
|
||||
|
||||
class UNetWrapper(object):
|
||||
def __init__(self, wrapped):
|
||||
self.wrapped = wrapped
|
||||
|
@ -357,12 +360,16 @@ class UNetWrapper(object):
|
|||
global timestep_dtype
|
||||
timestep_dtype = timestep.dtype
|
||||
|
||||
logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
||||
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.warning("converting UNet sample dtype")
|
||||
logger.info("converting UNet sample dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
|
||||
return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
|
||||
return self.wrapped(
|
||||
sample=sample,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
@ -375,10 +382,10 @@ class VAEWrapper(object):
|
|||
def __call__(self, latent_sample=None):
|
||||
global timestep_dtype
|
||||
|
||||
logger.info("VAE parameter types: %s", latent_sample.dtype)
|
||||
logger.trace("VAE parameter types: %s", latent_sample.dtype)
|
||||
if latent_sample.dtype != timestep_dtype:
|
||||
logger.warning("converting VAE sample dtype")
|
||||
sample = sample.astype(timestep_dtype)
|
||||
logger.info("converting VAE sample dtype")
|
||||
latent_sample = latent_sample.astype(timestep_dtype)
|
||||
|
||||
return self.wrapped(latent_sample=latent_sample)
|
||||
|
||||
|
@ -386,7 +393,6 @@ class VAEWrapper(object):
|
|||
return getattr(self.wrapped, attr)
|
||||
|
||||
|
||||
|
||||
def patch_pipeline(
|
||||
server: ServerContext,
|
||||
pipe: StableDiffusionPipeline,
|
||||
|
|
Loading…
Reference in New Issue