add fallback dtypes to unet patch
This commit is contained in:
parent
ef256280b4
commit
e2c9389d6e
|
@@ -15,7 +15,9 @@ class UNetWrapper(object):
|
||||||
input_types: Optional[Dict[str, np.dtype]] = None
|
input_types: Optional[Dict[str, np.dtype]] = None
|
||||||
prompt_embeds: Optional[List[np.ndarray]] = None
|
prompt_embeds: Optional[List[np.ndarray]] = None
|
||||||
prompt_index: int = 0
|
prompt_index: int = 0
|
||||||
|
sample_dtype: np.dtype
|
||||||
server: ServerContext
|
server: ServerContext
|
||||||
|
timestep_dtype: np.dtype
|
||||||
wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
|
wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
|
||||||
xl: bool
|
xl: bool
|
||||||
|
|
||||||
|
@@ -24,10 +26,14 @@ class UNetWrapper(object):
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
|
wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
|
||||||
xl: bool,
|
xl: bool,
|
||||||
|
sample_dtype: Optional[np.dtype] = None,
|
||||||
|
timestep_dtype: np.dtype = np.int64,
|
||||||
):
|
):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.xl = xl
|
self.xl = xl
|
||||||
|
self.sample_dtype = sample_dtype or server.torch_dtype
|
||||||
|
self.timestep_dtype = timestep_dtype
|
||||||
|
|
||||||
self.cache_input_types()
|
self.cache_input_types()
|
||||||
|
|
||||||
|
@@ -54,19 +60,36 @@ class UNetWrapper(object):
|
||||||
if self.input_types is None:
|
if self.input_types is None:
|
||||||
self.cache_input_types()
|
self.cache_input_types()
|
||||||
|
|
||||||
if encoder_hidden_states.dtype != self.input_types["encoder_hidden_states"]:
|
encoder_hidden_states_input_dtype = self.input_types.get(
|
||||||
logger.trace("converting UNet hidden states to input dtype")
|
"encoder_hidden_states", self.sample_dtype
|
||||||
|
)
|
||||||
|
if encoder_hidden_states.dtype != encoder_hidden_states_input_dtype:
|
||||||
|
logger.debug(
|
||||||
|
"converting UNet hidden states to input dtype from %s to %s",
|
||||||
|
encoder_hidden_states.dtype,
|
||||||
|
encoder_hidden_states_input_dtype,
|
||||||
|
)
|
||||||
encoder_hidden_states = encoder_hidden_states.astype(
|
encoder_hidden_states = encoder_hidden_states.astype(
|
||||||
self.input_types["encoder_hidden_states"]
|
encoder_hidden_states_input_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
if sample.dtype != self.input_types["sample"]:
|
sample_input_dtype = self.input_types.get("sample", self.sample_dtype)
|
||||||
logger.trace("converting UNet sample to input dtype")
|
if sample.dtype != sample_input_dtype:
|
||||||
sample = sample.astype(self.input_types["sample"])
|
logger.debug(
|
||||||
|
"converting UNet sample to input dtype from %s to %s",
|
||||||
|
sample.dtype,
|
||||||
|
sample_input_dtype,
|
||||||
|
)
|
||||||
|
sample = sample.astype(sample_input_dtype)
|
||||||
|
|
||||||
if timestep.dtype != self.input_types["timestep"]:
|
timestep_input_dtype = self.input_types.get("timestep", self.timestep_dtype)
|
||||||
logger.trace("converting UNet timestep to input dtype")
|
if timestep.dtype != timestep_input_dtype:
|
||||||
timestep = timestep.astype(self.input_types["timestep"])
|
logger.debug(
|
||||||
|
"converting UNet timestep to input dtype from %s to %s",
|
||||||
|
timestep.dtype,
|
||||||
|
timestep_input_dtype,
|
||||||
|
)
|
||||||
|
timestep = timestep.astype(timestep_input_dtype)
|
||||||
|
|
||||||
return self.wrapped(
|
return self.wrapped(
|
||||||
sample=sample,
|
sample=sample,
|
||||||
|
@@ -79,7 +102,6 @@ class UNetWrapper(object):
|
||||||
return getattr(self.wrapped, attr)
|
return getattr(self.wrapped, attr)
|
||||||
|
|
||||||
def cache_input_types(self):
|
def cache_input_types(self):
|
||||||
# TODO: use server dtype as default
|
|
||||||
if isinstance(self.wrapped, ORTModelUnet):
|
if isinstance(self.wrapped, ORTModelUnet):
|
||||||
session = self.wrapped.session
|
session = self.wrapped.session
|
||||||
elif isinstance(self.wrapped, OnnxRuntimeModel):
|
elif isinstance(self.wrapped, OnnxRuntimeModel):
|
||||||
|
|
Loading…
Reference in New Issue