diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index f4609e83..5a5f2d3b 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -107,7 +107,7 @@ class UNetWrapper(object): elif isinstance(self.wrapped, OnnxRuntimeModel): session = self.wrapped.model else: - raise ValueError() + raise ValueError("unknown UNet class") inputs = session.get_inputs() self.input_types = dict( @@ -115,20 +115,6 @@ class UNetWrapper(object): ) logger.debug("cached UNet input types: %s", self.input_types) - # [ - # ( - # input.name, - # next( - # [ - # TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type] - # for field in input.type.ListFields() - # ], - # np.float32, - # ), - # ) - # for input in self.wrapped.model.graph.input - # ] - def set_prompts(self, prompt_embeds: List[np.ndarray]): logger.debug( "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]