fix(api): handle models with integer timestamps
This commit is contained in:
parent
ebe813d035
commit
abeeddeeb2
|
@ -47,13 +47,15 @@ class UNetWrapper(object):
|
||||||
self.prompt_index += 1
|
self.prompt_index += 1
|
||||||
|
|
||||||
if self.xl:
|
if self.xl:
|
||||||
|
# for XL, the sample and hidden states should match
|
||||||
if sample.dtype != encoder_hidden_states.dtype:
|
if sample.dtype != encoder_hidden_states.dtype:
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"converting UNet sample to hidden state dtype for XL: %s",
|
"converting UNet sample to hidden state dtype for XL: %s",
|
||||||
encoder_hidden_states.dtype,
|
encoder_hidden_states.dtype,
|
||||||
)
|
)
|
||||||
sample = sample.astype(encoder_hidden_states.dtype)
|
sample = sample.astype(encoder_hidden_states.dtype)
|
||||||
else:
|
elif timestep.dtype != np.int64:
|
||||||
|
# the optimum converter uses an int timestep
|
||||||
if sample.dtype != timestep.dtype:
|
if sample.dtype != timestep.dtype:
|
||||||
logger.trace("converting UNet sample to timestep dtype")
|
logger.trace("converting UNet sample to timestep dtype")
|
||||||
sample = sample.astype(timestep.dtype)
|
sample = sample.astype(timestep.dtype)
|
||||||
|
|
Loading…
Reference in New Issue