1
0
Fork 0
onnx-web/api/onnx_web/diffusers/patches/unet.py

81 lines
2.4 KiB
Python
Raw Normal View History

2023-04-28 03:50:11 +00:00
from logging import getLogger
from typing import List, Optional
2023-04-28 18:56:36 +00:00
import numpy as np
from diffusers import OnnxRuntimeModel
2023-04-28 03:50:11 +00:00
from ...server import ServerContext
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
class UNetWrapper(object):
    """Callable proxy around an ONNX UNet that adjusts its inputs before each call.

    Two adjustments are made in ``__call__``:

    * If per-step prompt embeddings were provided with :meth:`set_prompts`,
      they replace ``encoder_hidden_states``; successive calls cycle through
      the list (the index wraps around with ``%``).
    * Input dtypes are aligned: for XL models ``sample`` is cast to the dtype
      of ``encoder_hidden_states``; otherwise both ``sample`` and
      ``encoder_hidden_states`` are cast to the dtype of ``timestep``.

    Any attribute not defined on the wrapper is looked up on the wrapped model.
    """

    # Per-step replacements for encoder_hidden_states; None/empty disables replacement.
    prompt_embeds: Optional[List[np.ndarray]] = None
    # Count of substituted calls since the last set_prompts(); selects the next entry.
    prompt_index: int = 0
    # Stored by __init__; not read by any method of this class.
    server: ServerContext
    # The model that __call__ forwards to and __getattr__ delegates to.
    wrapped: OnnxRuntimeModel
    # Selects the XL dtype-alignment rule in __call__.
    xl: bool

    def __init__(
        self,
        server: ServerContext,
        wrapped: OnnxRuntimeModel,
        xl: bool,
    ):
        """Wrap ``wrapped``.

        :param server: server context; stored on the instance.
        :param wrapped: the ONNX UNet model that ``__call__`` forwards to.
        :param xl: True to use the XL dtype-alignment rule.
        """
        self.server = server
        self.wrapped = wrapped
        self.xl = xl

    def __call__(
        self,
        sample: np.ndarray = None,
        timestep: np.ndarray = None,
        encoder_hidden_states: np.ndarray = None,
        **kwargs,
    ):
        """Adjust the inputs and forward them to the wrapped model.

        :param sample: latent input; cast as described on the class.
        :param timestep: current timestep; its dtype is the cast target for
            non-XL models.
        :param encoder_hidden_states: text conditioning; replaced by the next
            entry of ``prompt_embeds`` when that list is non-empty.
        :param kwargs: passed through to the wrapped model unchanged.
        :returns: whatever the wrapped model returns.
        """
        # NOTE(review): ``trace`` is not a standard ``logging.Logger`` method;
        # this assumes the project registers a TRACE level elsewhere — confirm.
        logger.trace(
            "UNet parameter types: %s, %s, %s",
            sample.dtype,
            timestep.dtype,
            encoder_hidden_states.dtype,
        )

        # Truthiness rather than ``is not None``: an empty list would otherwise
        # reach ``% len(...)`` below and raise ZeroDivisionError.
        if self.prompt_embeds:
            step_index = self.prompt_index % len(self.prompt_embeds)
            logger.trace("multiple prompt embeds found, using step: %s", step_index)
            encoder_hidden_states = self.prompt_embeds[step_index]
            self.prompt_index += 1

        if self.xl:
            if sample.dtype != encoder_hidden_states.dtype:
                logger.trace(
                    "converting UNet sample to hidden state dtype for XL: %s",
                    encoder_hidden_states.dtype,
                )
                sample = sample.astype(encoder_hidden_states.dtype)
        else:
            # presumably so the ONNX session receives inputs in the dtype the
            # exported graph expects — TODO confirm
            if sample.dtype != timestep.dtype:
                logger.trace("converting UNet sample to timestep dtype")
                sample = sample.astype(timestep.dtype)

            if encoder_hidden_states.dtype != timestep.dtype:
                logger.trace("converting UNet hidden states to timestep dtype")
                encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            **kwargs,
        )

    def __getattr__(self, attr):
        """Delegate attributes not found on the wrapper to the wrapped model.

        ``__getattr__`` only runs after normal lookup fails. Reading
        ``self.wrapped`` here would re-enter this method forever when
        ``wrapped`` has not been set yet (e.g. on an instance created with
        ``__new__`` by ``copy``/``pickle`` before its state is restored), so
        it is fetched from the instance dict directly and a missing value
        becomes a plain AttributeError.
        """
        try:
            wrapped = self.__dict__["wrapped"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(wrapped, attr)

    def set_prompts(self, prompt_embeds: List[np.ndarray]):
        """Set the per-step prompt embeddings and restart from the first one.

        :param prompt_embeds: embeddings that ``__call__`` substitutes for
            ``encoder_hidden_states``, one per call, cycling when exhausted.
            An empty list disables substitution.
        """
        logger.debug(
            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
        )
        self.prompt_embeds = prompt_embeds
        self.prompt_index = 0