1
0
Fork 0

handle XL UNets

This commit is contained in:
Sean Sube 2023-12-24 22:36:39 -06:00
parent 80a255397e
commit 39ee4cbfcd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 27 additions and 19 deletions

View File

@ -1,9 +1,9 @@
from logging import getLogger
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import numpy as np
from diffusers import OnnxRuntimeModel
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
from ...server import ServerContext
@ -15,13 +15,13 @@ class UNetWrapper(object):
prompt_embeds: Optional[List[np.ndarray]] = None
prompt_index: int = 0
server: ServerContext
wrapped: OnnxRuntimeModel
wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
xl: bool
def __init__(
self,
server: ServerContext,
wrapped: OnnxRuntimeModel,
wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
xl: bool,
):
self.server = server
@ -79,23 +79,31 @@ class UNetWrapper(object):
def cache_input_types(self):
# TODO: use server dtype as default
self.input_types = dict(
[
(
input.name,
next(
[
TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
for field in input.type.ListFields()
],
np.float32,
),
)
for input in self.wrapped.model.graph.input
]
)
if isinstance(self.wrapped, ORTModelUnet):
session = self.wrapped.session
elif isinstance(self.wrapped, OnnxRuntimeModel):
session = self.wrapped.model
else:
raise ValueError()
inputs = session.get_inputs()
self.input_types = dict([(input.name, input.type) for input in inputs])
logger.debug("cached UNet input types: %s", self.input_types)
# [
# (
# input.name,
# next(
# [
# TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
# for field in input.type.ListFields()
# ],
# np.float32,
# ),
# )
# for input in self.wrapped.model.graph.input
# ]
def set_prompts(self, prompt_embeds: List[np.ndarray]):
logger.debug(
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]