2023-01-26 03:04:00 +00:00
|
|
|
from diffusers import (
|
|
|
|
DDPMScheduler,
|
|
|
|
OnnxRuntimeModel,
|
|
|
|
StableDiffusionUpscalePipeline,
|
|
|
|
)
|
|
|
|
from typing import (
|
2023-01-26 03:29:18 +00:00
|
|
|
Any,
|
2023-01-26 03:04:00 +00:00
|
|
|
Callable,
|
|
|
|
Union,
|
|
|
|
List,
|
|
|
|
Optional,
|
|
|
|
)
|
|
|
|
|
|
|
|
import PIL
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
    """Stable Diffusion upscale pipeline whose models are ONNX Runtime sessions.

    The VAE, text encoder and UNet are ``OnnxRuntimeModel`` instances.
    Construction and inference are both delegated to
    ``StableDiffusionUpscalePipeline``.

    NOTE(review): this class does not adapt inputs or outputs for ONNX (e.g.
    torch tensors vs. numpy arrays). Confirm that the parent's ``__call__``
    works with ``OnnxRuntimeModel`` components.
    """

    def __init__(
        self,
        vae: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: Any,
        unet: OnnxRuntimeModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: Any,
        max_noise_level: int = 350,
    ):
        """Register the ONNX components with the parent pipeline.

        Args:
            vae: ONNX Runtime model used as the VAE.
            text_encoder: ONNX Runtime model used to encode the prompt.
            tokenizer: Tokenizer for the text encoder (type left open).
            unet: ONNX Runtime model used as the denoising UNet.
            low_res_scheduler: Scheduler forwarded to the parent as
                ``low_res_scheduler``.
            scheduler: Scheduler forwarded to the parent as ``scheduler``
                (type left open).
            max_noise_level: Forwarded unchanged to the parent; defaults to 350.
        """
        # Pass by keyword so this wrapper does not depend on the parent's
        # positional parameter order (e.g. if parameters are inserted
        # before ``max_noise_level``).
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
            max_noise_level=max_noise_level,
        )

    def __call__(
        self,
        *args,
        **kwargs,
    ):
        """Run the parent pipeline and return its output.

        All positional and keyword arguments are forwarded unchanged to
        ``StableDiffusionUpscalePipeline.__call__``.

        Returns:
            Whatever the parent pipeline returns. The result was previously
            discarded, so callers always received ``None``.
        """
        return super().__call__(*args, **kwargs)
|