1
0
Fork 0

fix(api): handle models with integer timestamps

This commit is contained in:
Sean Sube 2023-12-23 22:19:15 -06:00
parent ebe813d035
commit abeeddeeb2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 3 additions and 1 deletions

View File

@ -47,13 +47,15 @@ class UNetWrapper(object):
self.prompt_index += 1 self.prompt_index += 1
if self.xl: if self.xl:
# for XL, the sample and hidden states should match
if sample.dtype != encoder_hidden_states.dtype: if sample.dtype != encoder_hidden_states.dtype:
logger.trace( logger.trace(
"converting UNet sample to hidden state dtype for XL: %s", "converting UNet sample to hidden state dtype for XL: %s",
encoder_hidden_states.dtype, encoder_hidden_states.dtype,
) )
sample = sample.astype(encoder_hidden_states.dtype) sample = sample.astype(encoder_hidden_states.dtype)
else: elif timestep.dtype != np.int64:
# the optimum converter uses an int timestep
if sample.dtype != timestep.dtype: if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype") logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype) sample = sample.astype(timestep.dtype)