handle XL UNets
This commit is contained in:
parent
80a255397e
commit
39ee4cbfcd
|
@ -1,9 +1,9 @@
|
|||
from logging import getLogger
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxRuntimeModel
|
||||
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
|
||||
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
|
||||
|
||||
from ...server import ServerContext
|
||||
|
||||
|
@ -15,13 +15,13 @@ class UNetWrapper(object):
|
|||
prompt_embeds: Optional[List[np.ndarray]] = None
|
||||
prompt_index: int = 0
|
||||
server: ServerContext
|
||||
wrapped: OnnxRuntimeModel
|
||||
wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
|
||||
xl: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: ServerContext,
|
||||
wrapped: OnnxRuntimeModel,
|
||||
wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
|
||||
xl: bool,
|
||||
):
|
||||
self.server = server
|
||||
|
@ -79,23 +79,31 @@ class UNetWrapper(object):
|
|||
|
||||
def cache_input_types(self):
|
||||
# TODO: use server dtype as default
|
||||
self.input_types = dict(
|
||||
[
|
||||
(
|
||||
input.name,
|
||||
next(
|
||||
[
|
||||
TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
|
||||
for field in input.type.ListFields()
|
||||
],
|
||||
np.float32,
|
||||
),
|
||||
)
|
||||
for input in self.wrapped.model.graph.input
|
||||
]
|
||||
)
|
||||
if isinstance(self.wrapped, ORTModelUnet):
|
||||
session = self.wrapped.session
|
||||
elif isinstance(self.wrapped, OnnxRuntimeModel):
|
||||
session = self.wrapped.model
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
inputs = session.get_inputs()
|
||||
self.input_types = dict([(input.name, input.type) for input in inputs])
|
||||
logger.debug("cached UNet input types: %s", self.input_types)
|
||||
|
||||
# [
|
||||
# (
|
||||
# input.name,
|
||||
# next(
|
||||
# [
|
||||
# TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
|
||||
# for field in input.type.ListFields()
|
||||
# ],
|
||||
# np.float32,
|
||||
# ),
|
||||
# )
|
||||
# for input in self.wrapped.model.graph.input
|
||||
# ]
|
||||
|
||||
def set_prompts(self, prompt_embeds: List[np.ndarray]):
|
||||
logger.debug(
|
||||
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
|
||||
|
|
Loading…
Reference in New Issue