From 39ee4cbfcd815560f6a4afbffb1adab281f298f5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 24 Dec 2023 22:36:39 -0600 Subject: [PATCH] handle XL UNets --- api/onnx_web/diffusers/patches/unet.py | 46 +++++++++++++++----------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index dc7c1c94..cef594c1 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -1,9 +1,9 @@ from logging import getLogger -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np from diffusers import OnnxRuntimeModel -from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE +from optimum.onnxruntime.modeling_diffusion import ORTModelUnet from ...server import ServerContext @@ -15,13 +15,13 @@ class UNetWrapper(object): prompt_embeds: Optional[List[np.ndarray]] = None prompt_index: int = 0 server: ServerContext - wrapped: OnnxRuntimeModel + wrapped: Union[OnnxRuntimeModel, ORTModelUnet] xl: bool def __init__( self, server: ServerContext, - wrapped: OnnxRuntimeModel, + wrapped: Union[OnnxRuntimeModel, ORTModelUnet], xl: bool, ): self.server = server @@ -79,23 +79,31 @@ class UNetWrapper(object): def cache_input_types(self): # TODO: use server dtype as default - self.input_types = dict( - [ - ( - input.name, - next( - [ - TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type] - for field in input.type.ListFields() - ], - np.float32, - ), - ) - for input in self.wrapped.model.graph.input - ] - ) + if isinstance(self.wrapped, ORTModelUnet): + session = self.wrapped.session + elif isinstance(self.wrapped, OnnxRuntimeModel): + session = self.wrapped.model + else: + raise ValueError() + + inputs = session.get_inputs() + self.input_types = dict([(input.name, input.type) for input in inputs]) logger.debug("cached UNet input types: %s", self.input_types) + # [ + # ( + # input.name, + # next( + # [ + # TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type] + # for field in input.type.ListFields() + # ], + # np.float32, + # ), + # ) + # for input in self.wrapped.model.graph.input + # ] + def set_prompts(self, prompt_embeds: List[np.ndarray]): logger.debug( "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]