fix(api): enable Unet patch for SDXL
This commit is contained in:
parent
90d1812dec
commit
d48dbf7d6e
|
@ -547,10 +547,9 @@ def patch_pipeline(
|
||||||
if not params.is_lpw() and not params.is_xl():
|
if not params.is_lpw() and not params.is_xl():
|
||||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||||
|
|
||||||
if not params.is_xl():
|
original_unet = pipe.unet
|
||||||
original_unet = pipe.unet
|
pipe.unet = UNetWrapper(server, original_unet, params.is_xl())
|
||||||
pipe.unet = UNetWrapper(server, original_unet)
|
logger.debug("patched UNet with wrapper")
|
||||||
logger.debug("patched UNet with wrapper")
|
|
||||||
|
|
||||||
if hasattr(pipe, "vae_decoder"):
|
if hasattr(pipe, "vae_decoder"):
|
||||||
original_decoder = pipe.vae_decoder
|
original_decoder = pipe.vae_decoder
|
||||||
|
|
|
@ -14,14 +14,17 @@ class UNetWrapper(object):
|
||||||
prompt_index: int = 0
|
prompt_index: int = 0
|
||||||
server: ServerContext
|
server: ServerContext
|
||||||
wrapped: OnnxRuntimeModel
|
wrapped: OnnxRuntimeModel
|
||||||
|
xl: bool
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
wrapped: OnnxRuntimeModel,
|
wrapped: OnnxRuntimeModel,
|
||||||
|
xl: bool,
|
||||||
):
|
):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
|
self.xl = xl
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -43,13 +46,20 @@ class UNetWrapper(object):
|
||||||
encoder_hidden_states = self.prompt_embeds[step_index]
|
encoder_hidden_states = self.prompt_embeds[step_index]
|
||||||
self.prompt_index += 1
|
self.prompt_index += 1
|
||||||
|
|
||||||
if sample.dtype != timestep.dtype:
|
if self.xl:
|
||||||
logger.trace("converting UNet sample to timestep dtype")
|
logger.trace(
|
||||||
sample = sample.astype(timestep.dtype)
|
"converting UNet sample to hidden state dtype for XL: %s",
|
||||||
|
encoder_hidden_states.dtype,
|
||||||
|
)
|
||||||
|
sample = sample.astype(encoder_hidden_states.dtype)
|
||||||
|
else:
|
||||||
|
if sample.dtype != timestep.dtype:
|
||||||
|
logger.trace("converting UNet sample to timestep dtype")
|
||||||
|
sample = sample.astype(timestep.dtype)
|
||||||
|
|
||||||
if encoder_hidden_states.dtype != timestep.dtype:
|
if encoder_hidden_states.dtype != timestep.dtype:
|
||||||
logger.trace("converting UNet hidden states to timestep dtype")
|
logger.trace("converting UNet hidden states to timestep dtype")
|
||||||
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
|
||||||
|
|
||||||
return self.wrapped(
|
return self.wrapped(
|
||||||
sample=sample,
|
sample=sample,
|
||||||
|
|
|
@ -124,4 +124,4 @@ class TestSlicePrompt(unittest.TestCase):
|
||||||
|
|
||||||
def test_slice_outside_range(self):
|
def test_slice_outside_range(self):
|
||||||
slice = slice_prompt("foo || bar", 9)
|
slice = slice_prompt("foo || bar", 9)
|
||||||
self.assertEqual(slice, " bar")
|
self.assertEqual(slice, " bar")
|
||||||
|
|
Loading…
Reference in New Issue