1
0
Fork 0

fix(api): enable Unet patch for SDXL

This commit is contained in:
Sean Sube 2023-09-20 19:28:34 -05:00
parent 90d1812dec
commit d48dbf7d6e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 20 additions and 11 deletions

View File

@ -547,10 +547,9 @@ def patch_pipeline(
if not params.is_lpw() and not params.is_xl(): if not params.is_lpw() and not params.is_xl():
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
if not params.is_xl(): original_unet = pipe.unet
original_unet = pipe.unet pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
pipe.unet = UNetWrapper(server, original_unet) logger.debug("patched UNet with wrapper")
logger.debug("patched UNet with wrapper")
if hasattr(pipe, "vae_decoder"): if hasattr(pipe, "vae_decoder"):
original_decoder = pipe.vae_decoder original_decoder = pipe.vae_decoder

View File

@ -14,14 +14,17 @@ class UNetWrapper(object):
prompt_index: int = 0 prompt_index: int = 0
server: ServerContext server: ServerContext
wrapped: OnnxRuntimeModel wrapped: OnnxRuntimeModel
xl: bool
def __init__( def __init__(
self, self,
server: ServerContext, server: ServerContext,
wrapped: OnnxRuntimeModel, wrapped: OnnxRuntimeModel,
xl: bool,
): ):
self.server = server self.server = server
self.wrapped = wrapped self.wrapped = wrapped
self.xl = xl
def __call__( def __call__(
self, self,
@ -43,13 +46,20 @@ class UNetWrapper(object):
encoder_hidden_states = self.prompt_embeds[step_index] encoder_hidden_states = self.prompt_embeds[step_index]
self.prompt_index += 1 self.prompt_index += 1
if sample.dtype != timestep.dtype: if self.xl:
logger.trace("converting UNet sample to timestep dtype") logger.trace(
sample = sample.astype(timestep.dtype) "converting UNet sample to hidden state dtype for XL: %s",
encoder_hidden_states.dtype,
)
sample = sample.astype(encoder_hidden_states.dtype)
else:
if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)
if encoder_hidden_states.dtype != timestep.dtype: if encoder_hidden_states.dtype != timestep.dtype:
logger.trace("converting UNet hidden states to timestep dtype") logger.trace("converting UNet hidden states to timestep dtype")
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype) encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
return self.wrapped( return self.wrapped(
sample=sample, sample=sample,

View File

@ -124,4 +124,4 @@ class TestSlicePrompt(unittest.TestCase):
def test_slice_outside_range(self): def test_slice_outside_range(self):
slice = slice_prompt("foo || bar", 9) slice = slice_prompt("foo || bar", 9)
self.assertEqual(slice, " bar") self.assertEqual(slice, " bar")