1
0
Fork 0

add fallback dtypes to unet patch

This commit is contained in:
Sean Sube 2023-12-24 22:57:02 -06:00
parent ef256280b4
commit e2c9389d6e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 32 additions and 10 deletions

View File

@ -15,7 +15,9 @@ class UNetWrapper(object):
input_types: Optional[Dict[str, np.dtype]] = None
prompt_embeds: Optional[List[np.ndarray]] = None
prompt_index: int = 0
sample_dtype: np.dtype
server: ServerContext
timestep_dtype: np.dtype
wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
xl: bool
@ -24,10 +26,14 @@ class UNetWrapper(object):
server: ServerContext,
wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
xl: bool,
sample_dtype: Optional[np.dtype] = None,
timestep_dtype: np.dtype = np.int64,
):
self.server = server
self.wrapped = wrapped
self.xl = xl
self.sample_dtype = sample_dtype or server.torch_dtype
self.timestep_dtype = timestep_dtype
self.cache_input_types()
@ -54,19 +60,36 @@ class UNetWrapper(object):
if self.input_types is None:
self.cache_input_types()
if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
logger.trace("converting UNet hidden states to input dtype")
encoder_hidden_states_input_dtype = self.input_types.get(
"encoder_hidden_states", self.sample_dtype
)
if encoder_hidden_states.dtype != encoder_hidden_states_input_dtype:
logger.debug(
"converting UNet hidden states to input dtype from %s to %s",
encoder_hidden_states.dtype,
encoder_hidden_states_input_dtype,
)
encoder_hidden_states = encoder_hidden_states.astype(
self.input_types["encoder_hidden_states"]
encoder_hidden_states_input_dtype
)
if sample.dtype != self.input_types["sample"]:
logger.trace("converting UNet sample to input dtype")
sample = sample.astype(self.input_types["sample"])
sample_input_dtype = self.input_types.get("sample", self.sample_dtype)
if sample.dtype != sample_input_dtype:
logger.debug(
"converting UNet sample to input dtype from %s to %s",
sample.dtype,
sample_input_dtype,
)
sample = sample.astype(sample_input_dtype)
if timestep.dtype != self.input_types["timestep"]:
logger.trace("converting UNet timestep to input dtype")
timestep = timestep.astype(self.input_types["timestep"])
timestep_input_dtype = self.input_types.get("timestep", self.timestep_dtype)
if timestep.dtype != timestep_input_dtype:
logger.debug(
"converting UNet timestep to input dtype from %s to %s",
timestep.dtype,
timestep_input_dtype,
)
timestep = timestep.astype(timestep_input_dtype)
return self.wrapped(
sample=sample,
@ -79,7 +102,6 @@ class UNetWrapper(object):
return getattr(self.wrapped, attr)
def cache_input_types(self):
# TODO: use server dtype as default
if isinstance(self.wrapped, ORTModelUnet):
session = self.wrapped.session
elif isinstance(self.wrapped, OnnxRuntimeModel):