1
0
Fork 0

add message to unet type error

This commit is contained in:
Sean Sube 2023-12-24 23:10:08 -06:00
parent e2c9389d6e
commit ae0e8447f9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 1 additions and 15 deletions

View File

@ -107,7 +107,7 @@ class UNetWrapper(object):
elif isinstance(self.wrapped, OnnxRuntimeModel):
session = self.wrapped.model
else:
raise ValueError()
raise ValueError("unknown UNet class")
inputs = session.get_inputs()
self.input_types = dict(
@ -115,20 +115,6 @@ class UNetWrapper(object):
)
logger.debug("cached UNet input types: %s", self.input_types)
# [
# (
# input.name,
# next(
# [
# TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
# for field in input.type.ListFields()
# ],
# np.float32,
# ),
# )
# for input in self.wrapped.model.graph.input
# ]
def set_prompts(self, prompt_embeds: List[np.ndarray]):
logger.debug(
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]