apply lint, fix up log levels
This commit is contained in:
parent
6f283c5c02
commit
edd32f6044
|
@ -2,8 +2,8 @@ from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
|
@ -121,7 +121,9 @@ def load_pipeline(
|
||||||
inversions = inversions or []
|
inversions = inversions or []
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
|
|
||||||
torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
torch_dtype = (
|
||||||
|
torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||||
|
)
|
||||||
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
||||||
pipe_key = (
|
pipe_key = (
|
||||||
pipeline.__name__,
|
pipeline.__name__,
|
||||||
|
@ -349,6 +351,7 @@ def optimize_pipeline(
|
||||||
|
|
||||||
timestep_dtype = None
|
timestep_dtype = None
|
||||||
|
|
||||||
|
|
||||||
class UNetWrapper(object):
|
class UNetWrapper(object):
|
||||||
def __init__(self, wrapped):
|
def __init__(self, wrapped):
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
|
@ -357,12 +360,16 @@ class UNetWrapper(object):
|
||||||
global timestep_dtype
|
global timestep_dtype
|
||||||
timestep_dtype = timestep.dtype
|
timestep_dtype = timestep.dtype
|
||||||
|
|
||||||
logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
||||||
if sample.dtype != timestep.dtype:
|
if sample.dtype != timestep.dtype:
|
||||||
logger.warning("converting UNet sample dtype")
|
logger.info("converting UNet sample dtype")
|
||||||
sample = sample.astype(timestep.dtype)
|
sample = sample.astype(timestep.dtype)
|
||||||
|
|
||||||
return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
|
return self.wrapped(
|
||||||
|
sample=sample,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
return getattr(self.wrapped, attr)
|
return getattr(self.wrapped, attr)
|
||||||
|
@ -375,10 +382,10 @@ class VAEWrapper(object):
|
||||||
def __call__(self, latent_sample=None):
|
def __call__(self, latent_sample=None):
|
||||||
global timestep_dtype
|
global timestep_dtype
|
||||||
|
|
||||||
logger.info("VAE parameter types: %s", latent_sample.dtype)
|
logger.trace("VAE parameter types: %s", latent_sample.dtype)
|
||||||
if latent_sample.dtype != timestep_dtype:
|
if latent_sample.dtype != timestep_dtype:
|
||||||
logger.warning("converting VAE sample dtype")
|
logger.info("converting VAE sample dtype")
|
||||||
sample = sample.astype(timestep_dtype)
|
latent_sample = latent_sample.astype(timestep_dtype)
|
||||||
|
|
||||||
return self.wrapped(latent_sample=latent_sample)
|
return self.wrapped(latent_sample=latent_sample)
|
||||||
|
|
||||||
|
@ -386,7 +393,6 @@ class VAEWrapper(object):
|
||||||
return getattr(self.wrapped, attr)
|
return getattr(self.wrapped, attr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def patch_pipeline(
|
def patch_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
|
|
Loading…
Reference in New Issue