fix(api): enable Unet patch for SDXL
This commit is contained in:
parent
90d1812dec
commit
d48dbf7d6e
|
@ -547,10 +547,9 @@ def patch_pipeline(
|
|||
if not params.is_lpw() and not params.is_xl():
|
||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||
|
||||
if not params.is_xl():
|
||||
original_unet = pipe.unet
|
||||
pipe.unet = UNetWrapper(server, original_unet)
|
||||
logger.debug("patched UNet with wrapper")
|
||||
original_unet = pipe.unet
|
||||
pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
|
||||
logger.debug("patched UNet with wrapper")
|
||||
|
||||
if hasattr(pipe, "vae_decoder"):
|
||||
original_decoder = pipe.vae_decoder
|
||||
|
|
|
@ -14,14 +14,17 @@ class UNetWrapper(object):
|
|||
prompt_index: int = 0
|
||||
server: ServerContext
|
||||
wrapped: OnnxRuntimeModel
|
||||
xl: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: ServerContext,
|
||||
wrapped: OnnxRuntimeModel,
|
||||
xl: bool,
|
||||
):
|
||||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
self.xl = xl
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -43,13 +46,20 @@ class UNetWrapper(object):
|
|||
encoder_hidden_states = self.prompt_embeds[step_index]
|
||||
self.prompt_index += 1
|
||||
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet sample to timestep dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
if self.xl:
|
||||
logger.trace(
|
||||
"converting UNet sample to hidden state dtype for XL: %s",
|
||||
encoder_hidden_states.dtype,
|
||||
)
|
||||
sample = sample.astype(encoder_hidden_states.dtype)
|
||||
else:
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet sample to timestep dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
|
||||
if encoder_hidden_states.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet hidden states to timestep dtype")
|
||||
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
||||
if encoder_hidden_states.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet hidden states to timestep dtype")
|
||||
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
||||
|
||||
return self.wrapped(
|
||||
sample=sample,
|
||||
|
|
Loading…
Reference in New Issue