add message to unet type error
This commit is contained in:
parent
e2c9389d6e
commit
ae0e8447f9
|
@ -107,7 +107,7 @@ class UNetWrapper(object):
|
|||
elif isinstance(self.wrapped, OnnxRuntimeModel):
|
||||
session = self.wrapped.model
|
||||
else:
|
||||
raise ValueError()
|
||||
raise ValueError("unknown UNet class")
|
||||
|
||||
inputs = session.get_inputs()
|
||||
self.input_types = dict(
|
||||
|
@ -115,20 +115,6 @@ class UNetWrapper(object):
|
|||
)
|
||||
logger.debug("cached UNet input types: %s", self.input_types)
|
||||
|
||||
# [
|
||||
# (
|
||||
# input.name,
|
||||
# next(
|
||||
# [
|
||||
# TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
|
||||
# for field in input.type.ListFields()
|
||||
# ],
|
||||
# np.float32,
|
||||
# ),
|
||||
# )
|
||||
# for input in self.wrapped.model.graph.input
|
||||
# ]
|
||||
|
||||
def set_prompts(self, prompt_embeds: List[np.ndarray]):
|
||||
logger.debug(
|
||||
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
|
||||
|
|
Loading…
Reference in New Issue