1
0
Fork 0

apply lint, fix up log levels

This commit is contained in:
Sean Sube 2023-03-19 10:25:09 -05:00
parent 6f283c5c02
commit edd32f6044
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 15 additions and 9 deletions

View File

@ -2,8 +2,8 @@ from logging import getLogger
from os import path
from typing import Any, List, Optional, Tuple
import torch
import numpy as np
import torch
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@ -121,7 +121,9 @@ def load_pipeline(
inversions = inversions or []
loras = loras or []
torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
torch_dtype = (
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
)
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
pipe_key = (
pipeline.__name__,
@ -349,6 +351,7 @@ def optimize_pipeline(
timestep_dtype = None
class UNetWrapper(object):
def __init__(self, wrapped):
self.wrapped = wrapped
@ -357,12 +360,16 @@ class UNetWrapper(object):
global timestep_dtype
timestep_dtype = timestep.dtype
logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
if sample.dtype != timestep.dtype:
logger.warning("converting UNet sample dtype")
logger.info("converting UNet sample dtype")
sample = sample.astype(timestep.dtype)
return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
return self.wrapped(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
)
def __getattr__(self, attr):
return getattr(self.wrapped, attr)
@ -375,10 +382,10 @@ class VAEWrapper(object):
def __call__(self, latent_sample=None):
global timestep_dtype
logger.info("VAE parameter types: %s", latent_sample.dtype)
logger.trace("VAE parameter types: %s", latent_sample.dtype)
if latent_sample.dtype != timestep_dtype:
logger.warning("converting VAE sample dtype")
sample = sample.astype(timestep_dtype)
logger.info("converting VAE sample dtype")
latent_sample = latent_sample.astype(timestep_dtype)
return self.wrapped(latent_sample=latent_sample)
@ -386,7 +393,6 @@ class VAEWrapper(object):
return getattr(self.wrapped, attr)
def patch_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,