1
0
Fork 0

fix(api): enable Unet patch for SDXL

This commit is contained in:
Sean Sube 2023-09-20 19:28:34 -05:00
parent 90d1812dec
commit d48dbf7d6e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 20 additions and 11 deletions

View File

@ -547,10 +547,9 @@ def patch_pipeline(
if not params.is_lpw() and not params.is_xl():
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
if not params.is_xl():
original_unet = pipe.unet
pipe.unet = UNetWrapper(server, original_unet)
logger.debug("patched UNet with wrapper")
original_unet = pipe.unet
pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
logger.debug("patched UNet with wrapper")
if hasattr(pipe, "vae_decoder"):
original_decoder = pipe.vae_decoder

View File

@ -14,14 +14,17 @@ class UNetWrapper(object):
prompt_index: int = 0
server: ServerContext
wrapped: OnnxRuntimeModel
xl: bool
def __init__(
self,
server: ServerContext,
wrapped: OnnxRuntimeModel,
xl: bool,
):
self.server = server
self.wrapped = wrapped
self.xl = xl
def __call__(
self,
@ -43,13 +46,20 @@ class UNetWrapper(object):
encoder_hidden_states = self.prompt_embeds[step_index]
self.prompt_index += 1
if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)
if self.xl:
logger.trace(
"converting UNet sample to hidden state dtype for XL: %s",
encoder_hidden_states.dtype,
)
sample = sample.astype(encoder_hidden_states.dtype)
else:
if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)
if encoder_hidden_states.dtype != timestep.dtype:
logger.trace("converting UNet hidden states to timestep dtype")
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
if encoder_hidden_states.dtype != timestep.dtype:
logger.trace("converting UNet hidden states to timestep dtype")
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
return self.wrapped(
sample=sample,

View File

@ -124,4 +124,4 @@ class TestSlicePrompt(unittest.TestCase):
def test_slice_outside_range(self):
slice = slice_prompt("foo || bar", 9)
self.assertEqual(slice, " bar")
self.assertEqual(slice, " bar")