fix(api): handle models with integer timestamps
This commit is contained in:
parent
ebe813d035
commit
abeeddeeb2
|
@ -47,13 +47,15 @@ class UNetWrapper(object):
|
|||
self.prompt_index += 1
|
||||
|
||||
if self.xl:
|
||||
# for XL, the sample and hidden states should match
|
||||
if sample.dtype != encoder_hidden_states.dtype:
|
||||
logger.trace(
|
||||
"converting UNet sample to hidden state dtype for XL: %s",
|
||||
encoder_hidden_states.dtype,
|
||||
)
|
||||
sample = sample.astype(encoder_hidden_states.dtype)
|
||||
else:
|
||||
elif timestep.dtype != np.int64:
|
||||
# the optimum converter uses an int timestep
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.trace("converting UNet sample to timestep dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
|
|
Loading…
Reference in New Issue