load additional components for SD upscaling
parent 819af824b0
commit fb5c46d90c
@@ -19,7 +19,7 @@ from .image import (
 )
 from .upscale import (
     make_resrgan,
-    run_upscale_pipeline,
+    run_upscale_correction,
     upscale_gfpgan,
     upscale_resrgan,
     UpscaleParams,
@@ -249,7 +249,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str):
             torch.randn(2).to(device=training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
                 device=training_device, dtype=dtype),
-            False,
+            # TODO: needs to be Int or Long for upscaling, Bool for regular
+            4,
+            # False,
         ),
         output_path=unet_path,
         ordered_input_names=["sample", "timestep",
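The TODO in this hunk marks a dtype problem with the UNet's dummy export inputs: the regular UNet can take a plain boolean in the trailing position, but the upscaling UNet conditions on an integer noise level, so the exported graph needs an Int/Long input there. A minimal sketch of how that choice might be branched, assuming a hypothetical is_upscale flag and reusing the argument names visible in the diff:

# Sketch only. is_upscale and sample_input are hypothetical stand-ins; the
# other names (training_device, dtype, num_tokens, text_hidden_size) come
# from the conversion code shown in the diff.
if is_upscale:
    # upscaling UNet: the trailing input is an integer noise level, so the
    # dummy value must be Int/Long for ONNX to record the right dtype
    extra_input = torch.tensor(4, dtype=torch.long)
else:
    # regular UNet: a boolean return_dict flag is enough here
    extra_input = False

model_args = (
    sample_input,  # latent sample tensor, as in the existing export call
    torch.randn(2).to(device=training_device, dtype=dtype),
    torch.randn(2, num_tokens, text_hidden_size).to(
        device=training_device, dtype=dtype),
    extra_input,
)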
@@ -16,7 +16,7 @@ from .image import (
     expand_image,
 )
 from .upscale import (
-    upscale_resrgan,
+    run_upscale_correction,
     UpscaleParams,
 )
 from .utils import (
@@ -117,7 +117,7 @@ def run_txt2img_pipeline(
         num_inference_steps=params.steps,
     )
     image = result.images[0]
-    image = run_upscale_pipeline(ctx, upscale, image)
+    image = run_upscale_correction(ctx, upscale, image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
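txt2img, img2img, and inpaint now share the same tail: take the first generated image, pass it through run_upscale_correction, and save the result. Judging from the `if params.scale > 1:` guard visible in the last hunk, the helper is meant to be safe to call unconditionally; a hedged sketch of that shape (the faces flag is an assumption, not part of this diff):

# Sketch only: a plausible shape for run_upscale_correction, inferred from
# the imports and the scale guard in this commit; params.faces is hypothetical.
def run_upscale_correction(ctx, params, image):
    if params.scale > 1:
        # upscale with Real-ESRGAN (or Stable Diffusion via upscale_stable_diffusion)
        image = upscale_resrgan(ctx, params, image)
    if getattr(params, 'faces', False):
        # optional face correction pass with GFPGAN
        image = upscale_gfpgan(ctx, params, image)
    return image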
@@ -151,7 +151,7 @@ def run_img2img_pipeline(
         strength=strength,
     )
     image = result.images[0]
-    image = run_upscale_pipeline(ctx, upscale, image)
+    image = run_upscale_correction(ctx, upscale, image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
@@ -215,7 +215,7 @@ def run_inpaint_pipeline(
     else:
         print('output image size does not match source, skipping post-blend')
 
-    image = run_upscale_pipeline(ctx, upscale, image)
+    image = run_upscale_correction(ctx, upscale, image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
@@ -234,7 +234,7 @@ def run_upscale_pipeline(
     upscale: UpscaleParams,
     source_image: Image
 ):
-    image = upscale_resrgan(ctx, upscale, source_image)
+    image = run_upscale_correction(ctx, upscale, source_image)
 
     dest = safer_join(ctx.output_path, output)
     image.save(dest)
@@ -3,13 +3,8 @@ from diffusers import (
     OnnxRuntimeModel,
     StableDiffusionUpscalePipeline,
 )
-from diffusers.models import (
-    CLIPTokenizer,
-)
-from diffusers.schedulers import (
-    KarrasDiffusionSchedulers,
-)
 from typing import (
+    Any,
     Callable,
     Union,
     List,
@@ -25,34 +20,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         self,
         vae: OnnxRuntimeModel,
         text_encoder: OnnxRuntimeModel,
-        tokenizer: CLIPTokenizer,
+        tokenizer: Any,
         unet: OnnxRuntimeModel,
         low_res_scheduler: DDPMScheduler,
-        scheduler: KarrasDiffusionSchedulers,
+        scheduler: Any,
         max_noise_level: int = 350,
     ):
-        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
+        super().__init__(vae, text_encoder, tokenizer, unet,
+                         low_res_scheduler, scheduler, max_noise_level)
 
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image,
-                     List[PIL.Image.Image]] = None,
-        num_inference_steps: int = 75,
-        guidance_scale: float = 9.0,
-        noise_level: int = 20,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        eta: float = 0.0,
-        generator: Optional[Union[torch.Generator,
-                                  List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        callback: Optional[Callable[[
-            int, int, torch.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
+        *args,
+        **kwargs,
     ):
-        pass
+        super().__call__(*args, **kwargs)
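The subclass keeps the parent pipeline's behaviour but relaxes the tokenizer and scheduler annotations to Any, since the ONNX version is built from OnnxRuntimeModel components and whatever scheduler the caller provides; the long __call__ signature collapses into *args/**kwargs forwarded to the parent (as written, the forwarded call does not yet return the parent's output). A hedged construction example, with every right-hand-side variable as a hypothetical, already-loaded component:

# Sketch only: building the ONNX upscale pipeline from parts; the component
# variables here are hypothetical stand-ins, not names from this commit.
pipeline = OnnxStableDiffusionUpscalePipeline(
    vae=vae_model,                        # OnnxRuntimeModel
    text_encoder=text_encoder_model,      # OnnxRuntimeModel
    tokenizer=tokenizer,                  # typed Any, so a plain CLIPTokenizer still works
    unet=unet_model,                      # OnnxRuntimeModel
    low_res_scheduler=low_res_scheduler,  # DDPMScheduler
    scheduler=scheduler,                  # typed Any, any diffusers scheduler
    max_noise_level=350,
)
# note: with __call__ as defined above, this forwards to the parent pipeline
# but returns None until a return statement is added
result = pipeline('a photo of an astronaut', image=low_res_image)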
@@ -1,4 +1,8 @@
 from basicsr.archs.rrdbnet_arch import RRDBNet
+from diffusers import (
+    AutoencoderKL,
+    DDPMScheduler,
+)
 from gfpgan import GFPGANer
 from os import path
 from PIL import Image
@@ -127,12 +131,20 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
 
 def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
     print('upscaling with Stable Diffusion')
-    pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(params.upscale_model)
+    model_path = '../models/%s' % params.upscale_model
+    # ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
+    # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
+    # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
+    pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
+        model_path,
+        vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
+        low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
+    )
     result = pipeline('', image=image)
     return result.images[0]
 
 
-def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
+def run_upscale_correction(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
     print('running upscale pipeline')
 
     if params.scale > 1:
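This hunk is what the commit title refers to: the converted ONNX folder only provides the scheduler, tokenizer, text_encoder, and unet (the ValueError quoted in the comments), so the missing vae and low_res_scheduler are loaded separately and handed to from_pretrained, which accepts explicit component instances as keyword overrides. The same pattern in isolation, with a hypothetical model name and the subfolder layout taken from the diff:

# Sketch of the component-override pattern; the model directory name is
# hypothetical, the vae_encoder/ and scheduler/ subfolders come from the diff.
from diffusers import AutoencoderKL, DDPMScheduler

model_path = '../models/stable-diffusion-x4-upscaler'

pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
    model_path,
    # components missing from the ONNX export, loaded from their subfolders
    vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
    low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
)

Components passed as keyword arguments take precedence over anything the pipeline folder would otherwise load, which is how the two missing pieces are filled in here.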