1
0
Fork 0

handle XL UNets

This commit is contained in:
Sean Sube 2023-12-24 22:36:39 -06:00
parent 80a255397e
commit 39ee4cbfcd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 27 additions and 19 deletions

View File

@ -1,9 +1,9 @@
from logging import getLogger from logging import getLogger
from typing import Dict, List, Optional from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from diffusers import OnnxRuntimeModel from diffusers import OnnxRuntimeModel
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
from ...server import ServerContext from ...server import ServerContext
@ -15,13 +15,13 @@ class UNetWrapper(object):
prompt_embeds: Optional[List[np.ndarray]] = None prompt_embeds: Optional[List[np.ndarray]] = None
prompt_index: int = 0 prompt_index: int = 0
server: ServerContext server: ServerContext
wrapped: OnnxRuntimeModel wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
xl: bool xl: bool
def __init__( def __init__(
self, self,
server: ServerContext, server: ServerContext,
wrapped: OnnxRuntimeModel, wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
xl: bool, xl: bool,
): ):
self.server = server self.server = server
@ -79,23 +79,31 @@ class UNetWrapper(object):
def cache_input_types(self): def cache_input_types(self):
# TODO: use server dtype as default # TODO: use server dtype as default
self.input_types = dict( if isinstance(self.wrapped, ORTModelUnet):
[ session = self.wrapped.session
( elif isinstance(self.wrapped, OnnxRuntimeModel):
input.name, session = self.wrapped.model
next( else:
[ raise ValueError()
TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
for field in input.type.ListFields() inputs = session.get_inputs()
], self.input_types = dict([(input.name, input.type) for input in inputs])
np.float32,
),
)
for input in self.wrapped.model.graph.input
]
)
logger.debug("cached UNet input types: %s", self.input_types) logger.debug("cached UNet input types: %s", self.input_types)
# [
# (
# input.name,
# next(
# [
# TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
# for field in input.type.ListFields()
# ],
# np.float32,
# ),
# )
# for input in self.wrapped.model.graph.input
# ]
def set_prompts(self, prompt_embeds: List[np.ndarray]): def set_prompts(self, prompt_embeds: List[np.ndarray]):
logger.debug( logger.debug(
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds] "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]