load additional components for SD upscaling
This commit is contained in:
parent
819af824b0
commit
fb5c46d90c
|
@ -19,7 +19,7 @@ from .image import (
|
|||
)
|
||||
from .upscale import (
|
||||
make_resrgan,
|
||||
run_upscale_pipeline,
|
||||
run_upscale_correction,
|
||||
upscale_gfpgan,
|
||||
upscale_resrgan,
|
||||
UpscaleParams,
|
||||
|
|
|
@ -249,7 +249,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
|
|||
torch.randn(2).to(device=training_device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
# TODO: needs to be Int or Long for upscaling, Bool for regular
|
||||
4,
|
||||
# False,
|
||||
),
|
||||
output_path=unet_path,
|
||||
ordered_input_names=["sample", "timestep",
|
||||
|
|
|
@ -16,7 +16,7 @@ from .image import (
|
|||
expand_image,
|
||||
)
|
||||
from .upscale import (
|
||||
upscale_resrgan,
|
||||
run_upscale_correction,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
|
@ -117,7 +117,7 @@ def run_txt2img_pipeline(
|
|||
num_inference_steps=params.steps,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_pipeline(ctx, upscale, image)
|
||||
image = run_upscale_correction(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -151,7 +151,7 @@ def run_img2img_pipeline(
|
|||
strength=strength,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_pipeline(ctx, upscale, image)
|
||||
image = run_upscale_correction(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -215,7 +215,7 @@ def run_inpaint_pipeline(
|
|||
else:
|
||||
print('output image size does not match source, skipping post-blend')
|
||||
|
||||
image = run_upscale_pipeline(ctx, upscale, image)
|
||||
image = run_upscale_correction(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -234,7 +234,7 @@ def run_upscale_pipeline(
|
|||
upscale: UpscaleParams,
|
||||
source_image: Image
|
||||
):
|
||||
image = upscale_resrgan(ctx, upscale, source_image)
|
||||
image = run_upscale_correction(ctx, upscale, source_image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
|
|
@ -3,13 +3,8 @@ from diffusers import (
|
|||
OnnxRuntimeModel,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from diffusers.models import (
|
||||
CLIPTokenizer,
|
||||
)
|
||||
from diffusers.schedulers import (
|
||||
KarrasDiffusionSchedulers,
|
||||
)
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Union,
|
||||
List,
|
||||
|
@ -25,34 +20,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
|||
self,
|
||||
vae: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer: Any,
|
||||
unet: OnnxRuntimeModel,
|
||||
low_res_scheduler: DDPMScheduler,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
scheduler: Any,
|
||||
max_noise_level: int = 350,
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
|
||||
super().__init__(vae, text_encoder, tokenizer, unet,
|
||||
low_res_scheduler, scheduler, max_noise_level)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image,
|
||||
List[PIL.Image.Image]] = None,
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator,
|
||||
List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[
|
||||
int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
super().__call__(*args, **kwargs)
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
)
|
||||
from gfpgan import GFPGANer
|
||||
from os import path
|
||||
from PIL import Image
|
||||
|
@ -127,12 +131,20 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
|
|||
|
||||
def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
||||
print('upscaling with Stable Diffusion')
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(params.upscale_model)
|
||||
model_path = '../models/%s' % params.upscale_model
|
||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path,
|
||||
vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
|
||||
low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
|
||||
)
|
||||
result = pipeline('', image=image)
|
||||
return result.images[0]
|
||||
|
||||
|
||||
def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
||||
def run_upscale_correction(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
||||
print('running upscale pipeline')
|
||||
|
||||
if params.scale > 1:
|
||||
|
|
Loading…
Reference in New Issue