2023-01-29 19:49:30 +00:00
|
|
|
from logging import getLogger
|
2023-11-26 21:23:28 +00:00
|
|
|
from typing import Any, List, Optional
|
2023-01-26 03:04:00 +00:00
|
|
|
|
2023-11-26 21:23:28 +00:00
|
|
|
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel
|
2023-11-28 00:53:39 +00:00
|
|
|
from diffusers.pipelines.stable_diffusion import (
|
|
|
|
OnnxStableDiffusionUpscalePipeline as BasePipeline,
|
|
|
|
)
|
2023-04-12 00:29:25 +00:00
|
|
|
from diffusers.schedulers import DDPMScheduler
|
2023-01-26 03:04:00 +00:00
|
|
|
|
2023-01-29 19:49:30 +00:00
|
|
|
# Module-level logger, named after this module so records can be filtered
# by module path.
logger = getLogger(name=__name__)
|
|
|
|
|
2023-02-05 21:54:17 +00:00
|
|
|
|
2023-04-12 00:29:25 +00:00
|
|
|
class FakeConfig:
    """Minimal stand-in for a diffusers-style ``config`` object.

    Only the two attributes assigned in ``__init__`` are provided. The default
    values are the ones this module has always used; presumably they describe
    the x4 upscaler VAE (TODO confirm against the model's ``config.json``).

    Args:
        block_out_channels: Channel widths to expose as ``block_out_channels``.
            Defaults to ``[128, 256, 512]``. The sequence is copied, so each
            instance owns its own list and callers' lists are never aliased.
        scaling_factor: Value to expose as ``scaling_factor``.
            Defaults to ``0.08333``.
    """

    block_out_channels: List[int]
    scaling_factor: float

    def __init__(
        self,
        block_out_channels: Optional[List[int]] = None,
        scaling_factor: float = 0.08333,
    ) -> None:
        # None sentinel rather than a list default: a mutable default would be
        # shared (and mutable) across every instance.
        if block_out_channels is None:
            block_out_channels = [128, 256, 512]
        self.block_out_channels = list(block_out_channels)
        self.scaling_factor = scaling_factor
|
|
|
|
|
2023-01-26 03:04:00 +00:00
|
|
|
|
2023-11-26 21:23:28 +00:00
|
|
|
class OnnxStableDiffusionUpscalePipeline(BasePipeline):
    """ONNX upscale pipeline that guarantees ``vae`` carries a ``config``.

    Thin subclass of diffusers' ``OnnxStableDiffusionUpscalePipeline``: it only
    patches the VAE wrapper when needed and then defers to the base class.
    """

    def __init__(
        self,
        vae: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: Any,
        unet: OnnxRuntimeModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: Any,
        max_noise_level: int = 350,
    ):
        """Patch ``vae`` if necessary and initialise the base pipeline.

        Args:
            vae: VAE model. If it has no ``config`` attribute, a ``FakeConfig``
                instance is attached to it in place (this mutates ``vae``).
            text_encoder: Forwarded positionally to the base class.
            tokenizer: Forwarded positionally to the base class.
            unet: Forwarded positionally to the base class.
            low_res_scheduler: Forwarded positionally to the base class.
            scheduler: Forwarded positionally to the base class.
            max_noise_level: Forwarded by keyword to the base class.

        Any other base-class parameters are left at the base class defaults.
        """
        # NOTE(review): presumably the base pipeline reads ``vae.config``
        # (e.g. ``block_out_channels`` / ``scaling_factor``) and the ONNX
        # wrapper lacks one -- confirm against the diffusers version in use.
        if not hasattr(vae, "config"):
            vae.config = FakeConfig()

        positional = (
            vae,
            text_encoder,
            tokenizer,
            unet,
            low_res_scheduler,
            scheduler,
        )
        super().__init__(*positional, max_noise_level=max_noise_level)
|