add fallback dtypes to unet patch
This commit is contained in:
parent
ef256280b4
commit
e2c9389d6e
|
@ -15,7 +15,9 @@ class UNetWrapper(object):
|
|||
input_types: Optional[Dict[str, np.dtype]] = None
|
||||
prompt_embeds: Optional[List[np.ndarray]] = None
|
||||
prompt_index: int = 0
|
||||
sample_dtype: np.dtype
|
||||
server: ServerContext
|
||||
timestep_dtype: np.dtype
|
||||
wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
|
||||
xl: bool
|
||||
|
||||
|
@ -24,10 +26,14 @@ class UNetWrapper(object):
|
|||
server: ServerContext,
|
||||
wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
|
||||
xl: bool,
|
||||
sample_dtype: Optional[np.dtype] = None,
|
||||
timestep_dtype: np.dtype = np.int64,
|
||||
):
|
||||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
self.xl = xl
|
||||
self.sample_dtype = sample_dtype or server.torch_dtype
|
||||
self.timestep_dtype = timestep_dtype
|
||||
|
||||
self.cache_input_types()
|
||||
|
||||
|
@ -54,19 +60,36 @@ class UNetWrapper(object):
|
|||
if self.input_types is None:
|
||||
self.cache_input_types()
|
||||
|
||||
if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
|
||||
logger.trace("converting UNet hidden states to input dtype")
|
||||
encoder_hidden_states_input_dtype = self.input_types.get(
|
||||
"encoder_hidden_states", self.sample_dtype
|
||||
)
|
||||
if encoder_hidden_states.dtype != encoder_hidden_states_input_dtype:
|
||||
logger.debug(
|
||||
"converting UNet hidden states to input dtype from %s to %s",
|
||||
encoder_hidden_states.dtype,
|
||||
encoder_hidden_states_input_dtype,
|
||||
)
|
||||
encoder_hidden_states = encoder_hidden_states.astype(
|
||||
self.input_types["encoder_hidden_states"]
|
||||
encoder_hidden_states_input_dtype
|
||||
)
|
||||
|
||||
if sample.dtype != self.input_types["sample"]:
|
||||
logger.trace("converting UNet sample to input dtype")
|
||||
sample = sample.astype(self.input_types["sample"])
|
||||
sample_input_dtype = self.input_types.get("sample", self.sample_dtype)
|
||||
if sample.dtype != sample_input_dtype:
|
||||
logger.debug(
|
||||
"converting UNet sample to input dtype from %s to %s",
|
||||
sample.dtype,
|
||||
sample_input_dtype,
|
||||
)
|
||||
sample = sample.astype(sample_input_dtype)
|
||||
|
||||
if timestep.dtype != self.input_types["timestep"]:
|
||||
logger.trace("converting UNet timestep to input dtype")
|
||||
timestep = timestep.astype(self.input_types["timestep"])
|
||||
timestep_input_dtype = self.input_types.get("timestep", self.timestep_dtype)
|
||||
if timestep.dtype != timestep_input_dtype:
|
||||
logger.debug(
|
||||
"converting UNet timestep to input dtype from %s to %s",
|
||||
timestep.dtype,
|
||||
timestep_input_dtype,
|
||||
)
|
||||
timestep = timestep.astype(timestep_input_dtype)
|
||||
|
||||
return self.wrapped(
|
||||
sample=sample,
|
||||
|
@ -79,7 +102,6 @@ class UNetWrapper(object):
|
|||
return getattr(self.wrapped, attr)
|
||||
|
||||
def cache_input_types(self):
|
||||
# TODO: use server dtype as default
|
||||
if isinstance(self.wrapped, ORTModelUnet):
|
||||
session = self.wrapped.session
|
||||
elif isinstance(self.wrapped, OnnxRuntimeModel):
|
||||
|
|
Loading…
Reference in New Issue