2023-04-28 03:50:11 +00:00
|
|
|
from logging import getLogger
|
|
|
|
from typing import List, Optional
|
|
|
|
|
2023-04-28 18:56:36 +00:00
|
|
|
import numpy as np
|
|
|
|
from diffusers import OnnxRuntimeModel
|
|
|
|
|
2023-04-28 03:50:11 +00:00
|
|
|
from ...server import ServerContext
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class UNetWrapper(object):
    """Callable proxy around an ONNX UNet model.

    Before each call it can swap in a per-step prompt embedding (stored via
    ``set_prompts``) and align the dtypes of the inputs, then forwards the call
    to the wrapped model. Attributes not defined on the wrapper are looked up
    on the wrapped model.
    """

    # Embeddings substituted for ``encoder_hidden_states``, cycled one per call;
    # ``None`` (the default) passes the caller's hidden states through.
    prompt_embeds: Optional[List[np.ndarray]] = None
    # Number of calls since the last ``set_prompts``; selects the embedding.
    prompt_index: int = 0
    server: ServerContext
    wrapped: OnnxRuntimeModel
    # Chooses which dtype-alignment rule ``__call__`` applies.
    xl: bool

    def __init__(
        self,
        server: ServerContext,
        wrapped: OnnxRuntimeModel,
        xl: bool,
    ):
        self.server = server
        self.wrapped = wrapped
        self.xl = xl

    def __call__(
        self,
        sample: Optional[np.ndarray] = None,
        timestep: Optional[np.ndarray] = None,
        encoder_hidden_states: Optional[np.ndarray] = None,
        **kwargs,
    ):
        """Run the wrapped UNet after prompt substitution and dtype alignment.

        Args:
            sample: model input that may be cast to match the other inputs.
            timestep: for non-XL models, its dtype is the target that
                ``sample`` and ``encoder_hidden_states`` are cast to.
            encoder_hidden_states: replaced by the next entry of
                ``prompt_embeds`` when prompt embeddings have been set.
            **kwargs: forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model returns.
        """
        # getattr keeps the trace None-safe: every argument is optional, and
        # encoder_hidden_states may be omitted when it is about to be replaced
        # by a stored prompt embedding.
        logger.trace(
            "UNet parameter types: %s, %s, %s",
            getattr(sample, "dtype", None),
            getattr(timestep, "dtype", None),
            getattr(encoder_hidden_states, "dtype", None),
        )

        # An empty container would make the modulo below divide by zero; treat
        # it like "no embeddings set". len() rather than truthiness so that
        # array-like containers also work.
        if self.prompt_embeds is not None and len(self.prompt_embeds) > 0:
            step_index = self.prompt_index % len(self.prompt_embeds)
            logger.trace("multiple prompt embeds found, using step: %s", step_index)
            encoder_hidden_states = self.prompt_embeds[step_index]
            self.prompt_index += 1

        if self.xl:
            # XL: the sample follows the hidden-state dtype.
            if (
                sample is not None
                and encoder_hidden_states is not None
                and sample.dtype != encoder_hidden_states.dtype
            ):
                logger.trace(
                    "converting UNet sample to hidden state dtype for XL: %s",
                    encoder_hidden_states.dtype,
                )
                sample = sample.astype(encoder_hidden_states.dtype)
        elif timestep is not None:
            # Non-XL: sample and hidden states both follow the timestep dtype.
            if sample is not None and sample.dtype != timestep.dtype:
                logger.trace("converting UNet sample to timestep dtype")
                sample = sample.astype(timestep.dtype)

            if (
                encoder_hidden_states is not None
                and encoder_hidden_states.dtype != timestep.dtype
            ):
                logger.trace("converting UNet hidden states to timestep dtype")
                encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            **kwargs,
        )

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the wrapped model."""
        # __getattr__ only runs when normal lookup fails. If ``wrapped`` itself
        # is missing (e.g. an instance built via __new__ by copy/pickle before
        # its state is restored), ``self.wrapped`` would re-enter this method
        # and recurse until RecursionError instead of raising AttributeError.
        if attr == "wrapped":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        return getattr(self.wrapped, attr)

    def set_prompts(self, prompt_embeds: List[np.ndarray]):
        """Store per-step prompt embeddings and restart the step counter.

        The i-th call (0-based) after this one uses
        ``prompt_embeds[i % len(prompt_embeds)]`` as its hidden states; an
        empty list leaves the caller's hidden states untouched.
        """
        logger.debug(
            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
        )
        self.prompt_embeds = prompt_embeds
        self.prompt_index = 0
|