handle XL UNets
parent 80a255397e
commit 39ee4cbfcd
@@ -1,9 +1,9 @@
 from logging import getLogger
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 from diffusers import OnnxRuntimeModel
-from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
 
 from ...server import ServerContext
 
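Note on the new import: modeling_diffusion is an internal optimum module, and its path has moved between optimum releases. A guarded import (a sketch, not part of this commit; the placeholder class is an assumption) would keep the module importable either way:

try:
    from optimum.onnxruntime.modeling_diffusion import ORTModelUnet
except ImportError:
    # Assumption: some optimum versions may not expose this internal path;
    # a placeholder class keeps the isinstance() checks below valid.
    class ORTModelUnet:
        pass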
@@ -15,13 +15,13 @@ class UNetWrapper(object):
     prompt_embeds: Optional[List[np.ndarray]] = None
     prompt_index: int = 0
     server: ServerContext
-    wrapped: OnnxRuntimeModel
+    wrapped: Union[OnnxRuntimeModel, ORTModelUnet]
     xl: bool
 
     def __init__(
         self,
         server: ServerContext,
-        wrapped: OnnxRuntimeModel,
+        wrapped: Union[OnnxRuntimeModel, ORTModelUnet],
         xl: bool,
     ):
         self.server = server
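With the widened Union type, the wrapper accepts either a bare diffusers OnnxRuntimeModel or the UNet component of an optimum ONNX pipeline. A rough usage sketch, assuming optimum's ORTStableDiffusionXLPipeline and an already-constructed ServerContext; the model ID is illustrative:

from optimum.onnxruntime import ORTStableDiffusionXLPipeline

pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", export=True
)

# pipeline.unet is an ORTModelUnet; a plain OnnxRuntimeModel also satisfies
# the annotation. server is the app's ServerContext instance.
wrapper = UNetWrapper(server, pipeline.unet, xl=True)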
@@ -79,23 +79,31 @@ class UNetWrapper(object):
 
     def cache_input_types(self):
         # TODO: use server dtype as default
-        self.input_types = dict(
-            [
-                (
-                    input.name,
-                    next(
-                        [
-                            TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
-                            for field in input.type.ListFields()
-                        ],
-                        np.float32,
-                    ),
-                )
-                for input in self.wrapped.model.graph.input
-            ]
-        )
+        if isinstance(self.wrapped, ORTModelUnet):
+            session = self.wrapped.session
+        elif isinstance(self.wrapped, OnnxRuntimeModel):
+            session = self.wrapped.model
+        else:
+            raise ValueError()
+
+        inputs = session.get_inputs()
+        self.input_types = dict([(input.name, input.type) for input in inputs])
         logger.debug("cached UNet input types: %s", self.input_types)
 
+        # [
+        #     (
+        #         input.name,
+        #         next(
+        #             [
+        #                 TENSOR_TYPE_TO_NP_TYPE[field[1].elem_type]
+        #                 for field in input.type.ListFields()
+        #             ],
+        #             np.float32,
+        #         ),
+        #     )
+        #     for input in self.wrapped.model.graph.input
+        # ]
+
     def set_prompts(self, prompt_embeds: List[np.ndarray]):
         logger.debug(
             "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]