1
0
Fork 0
onnx-web/api/onnx_web/diffusers/pipelines/upscale.py

45 lines
1.1 KiB
Python

from logging import getLogger
from typing import Any, List
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import (
OnnxStableDiffusionUpscalePipeline as BasePipeline,
)
from diffusers.schedulers import DDPMScheduler
logger = getLogger(__name__)
class FakeConfig:
block_out_channels: List[int]
scaling_factor: float
def __init__(self) -> None:
self.block_out_channels = [128, 256, 512]
self.scaling_factor = 0.08333
class OnnxStableDiffusionUpscalePipeline(BasePipeline):
def __init__(
self,
vae: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: Any,
unet: OnnxRuntimeModel,
low_res_scheduler: DDPMScheduler,
scheduler: Any,
max_noise_level: int = 350,
):
if not hasattr(vae, "config"):
setattr(vae, "config", FakeConfig())
super().__init__(
vae,
text_encoder,
tokenizer,
unet,
low_res_scheduler,
scheduler,
max_noise_level=max_noise_level,
)