1
0
Fork 0
onnx-web/api/onnx_web/chain/upscale_highres.py

109 lines
3.2 KiB
Python
Raw Normal View History

from logging import getLogger
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.upscale import append_upscale_correction
from ..diffusers.utils import parse_prompt
from ..params import HighresParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
logger = getLogger(__name__)
def upscale_highres(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
highres: HighresParams,
upscale: UpscaleParams,
size: Size,
stage_source: Optional[Image.Image] = None,
pipeline: Optional[Any] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
image = stage_source or source
if highres.scale <= 1:
return image
# load img2img pipeline once
pipe_type = params.get_valid_pipeline("img2img")
logger.debug("using %s pipeline for highres", pipe_type)
_prompt_pairs, loras, inversions = parse_prompt(params)
highres_pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
scaled_size = (source.width * highres.scale, source.height * highres.scale)
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
upscale = append_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
)
source = upscale(
job,
server,
source,
callback=callback,
)
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=callback,
)
return result.images[0]
else:
rng = np.random.RandomState(params.seed)
result = highres_pipe(
params.prompt,
source,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=callback,
)
return result.images[0]