1
0
Fork 0

apply lint, fix up log levels

This commit is contained in:
Sean Sube 2023-03-19 10:25:09 -05:00
parent 6f283c5c02
commit edd32f6044
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 15 additions and 9 deletions

View File

@ -2,8 +2,8 @@ from logging import getLogger
from os import path from os import path
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
import torch
import numpy as np import numpy as np
import torch
from diffusers import ( from diffusers import (
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
@ -121,7 +121,9 @@ def load_pipeline(
inversions = inversions or [] inversions = inversions or []
loras = loras or [] loras = loras or []
torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32 torch_dtype = (
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
)
logger.debug("using Torch dtype %s for pipeline", torch_dtype) logger.debug("using Torch dtype %s for pipeline", torch_dtype)
pipe_key = ( pipe_key = (
pipeline.__name__, pipeline.__name__,
@ -349,6 +351,7 @@ def optimize_pipeline(
timestep_dtype = None timestep_dtype = None
class UNetWrapper(object): class UNetWrapper(object):
def __init__(self, wrapped): def __init__(self, wrapped):
self.wrapped = wrapped self.wrapped = wrapped
@ -357,12 +360,16 @@ class UNetWrapper(object):
global timestep_dtype global timestep_dtype
timestep_dtype = timestep.dtype timestep_dtype = timestep.dtype
logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype) logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
if sample.dtype != timestep.dtype: if sample.dtype != timestep.dtype:
logger.warning("converting UNet sample dtype") logger.info("converting UNet sample dtype")
sample = sample.astype(timestep.dtype) sample = sample.astype(timestep.dtype)
return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states) return self.wrapped(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
)
def __getattr__(self, attr): def __getattr__(self, attr):
return getattr(self.wrapped, attr) return getattr(self.wrapped, attr)
@ -375,10 +382,10 @@ class VAEWrapper(object):
def __call__(self, latent_sample=None): def __call__(self, latent_sample=None):
global timestep_dtype global timestep_dtype
logger.info("VAE parameter types: %s", latent_sample.dtype) logger.trace("VAE parameter types: %s", latent_sample.dtype)
if latent_sample.dtype != timestep_dtype: if latent_sample.dtype != timestep_dtype:
logger.warning("converting VAE sample dtype") logger.info("converting VAE sample dtype")
sample = sample.astype(timestep_dtype) latent_sample = latent_sample.astype(timestep_dtype)
return self.wrapped(latent_sample=latent_sample) return self.wrapped(latent_sample=latent_sample)
@ -386,7 +393,6 @@ class VAEWrapper(object):
return getattr(self.wrapped, attr) return getattr(self.wrapped, attr)
def patch_pipeline( def patch_pipeline(
server: ServerContext, server: ServerContext,
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,