1
0
Fork 0
onnx-web/api/onnx_web/diffusers/pipelines/upscale.py

43 lines
1.1 KiB
Python
Raw Normal View History

2023-01-29 19:49:30 +00:00
from logging import getLogger
from typing import Any, List
from diffusers.pipelines.onnx_utils import OnnxRuntimeModel
from diffusers.pipelines.stable_diffusion import OnnxStableDiffusionUpscalePipeline as BasePipeline
from diffusers.schedulers import DDPMScheduler
2023-01-29 19:49:30 +00:00
# Module-level logger named after this module (standard `logging` convention).
# NOTE(review): not referenced by any code visible in this file — confirm it is still needed.
logger = getLogger(__name__)
2023-02-05 21:54:17 +00:00
class FakeConfig:
    """Placeholder for the ``config`` attribute of an ONNX VAE model.

    ``OnnxStableDiffusionUpscalePipeline`` attaches an instance of this class
    to a VAE object that has no ``config`` of its own. It exposes only two
    attributes: ``block_out_channels`` and ``scaling_factor``.
    """

    # Declared here for type checkers; the values are set per instance below.
    block_out_channels: List[int]
    scaling_factor: float

    def __init__(self) -> None:
        # NOTE(review): presumably mirrors the upscaler VAE's own settings —
        # confirm against the exported model before changing either value.
        self.scaling_factor = 0.08333  # approximately 1/12
        # Build a new list for every instance so configs never share state.
        self.block_out_channels = list((128, 256, 512))
class OnnxStableDiffusionUpscalePipeline(BasePipeline):
    """ONNX upscale pipeline that tolerates a VAE without a ``config`` attribute.

    Before delegating to the diffusers base pipeline, a ``FakeConfig`` is
    attached to the VAE if it lacks ``config``. All constructor arguments are
    otherwise forwarded unchanged.
    """

    def __init__(
        self,
        vae: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: Any,
        unet: OnnxRuntimeModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: Any,
        max_noise_level: int = 350,
    ):
        """Build the pipeline, attaching a placeholder VAE config when missing.

        Args:
            vae: ONNX VAE model. It is given a ``FakeConfig`` as ``config``
                when it has no ``config`` attribute.
            text_encoder: ONNX text encoder model, forwarded to the base class.
            tokenizer: Tokenizer, forwarded to the base class.
            unet: ONNX UNet model, forwarded to the base class.
            low_res_scheduler: Scheduler forwarded as the base pipeline's
                low-resolution scheduler.
            scheduler: Scheduler forwarded to the base class.
            max_noise_level: Forwarded to the base class as a keyword
                argument; defaults to 350.
        """
        if not hasattr(vae, "config"):
            # NOTE(review): presumably the base pipeline reads attributes from
            # ``vae.config`` — confirm against the installed diffusers version.
            # The attribute name is constant, so plain assignment replaces setattr().
            vae.config = FakeConfig()

        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            low_res_scheduler,
            scheduler,
            max_noise_level=max_noise_level,
        )