From ad35c41c9de3195c0fd80f57226c046223548ed2 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 14 Apr 2023 08:54:21 -0500 Subject: [PATCH] feat(api): add highres to img2img mode for all pipelines --- api/onnx_web/diffusers/run.py | 252 ++++++++++++++++++++-------------- api/onnx_web/server/api.py | 4 +- 2 files changed, 153 insertions(+), 103 deletions(-) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index f69421e5..72fee1f7 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import numpy as np import torch @@ -22,6 +22,7 @@ from ..server import ServerContext from ..server.load import get_source_filters from ..utils import run_gc from ..worker import WorkerContext +from ..worker.context import ProgressCallback from .load import get_latents_from_seed, load_pipeline from .upscale import run_upscale_correction from .utils import get_inversions_from_prompt, get_loras_from_prompt @@ -29,6 +30,126 @@ from .utils import get_inversions_from_prompt, get_loras_from_prompt logger = getLogger(__name__) +def run_highres( + job: WorkerContext, + server: ServerContext, + params: ImageParams, + size: Size, + upscale: UpscaleParams, + highres: HighresParams, + image: Image.Image, + progress: ProgressCallback, + inversions: List[Tuple[str, float]], + loras: List[Tuple[str, float]], +) -> None: + highres_progress = ChainProgress.from_progress(progress) + + if upscale.faces and ( + upscale.upscale_order == "correction-both" + or upscale.upscale_order == "correction-first" + ): + image = run_upscale_correction( + job, + server, + StageParams(), + params, + image, + upscale=upscale.with_args( + scale=1, + outscale=1, + ), + callback=highres_progress, + ) + + # load img2img pipeline once + highres_pipe = load_pipeline( + server, + "img2img", + params.model, + params.scheduler, + job.get_device(), + inversions=inversions, + loras=loras, + ) + + def highres_tile(tile: Image.Image, dims): + if highres.method == "bilinear": + logger.debug("using bilinear interpolation for highres") + tile = tile.resize( + (size.height, size.width), resample=Image.Resampling.BILINEAR + ) + elif highres.method == "lanczos": + logger.debug("using Lanczos interpolation for highres") + tile = tile.resize( + (size.height, size.width), resample=Image.Resampling.LANCZOS + ) + else: + logger.debug("using upscaling pipeline for highres") + tile = run_upscale_correction( + job, + server, + StageParams(), + params, + tile, + upscale=upscale.with_args( + faces=False, + scale=highres.scale, + outscale=highres.scale, + ), + callback=highres_progress, + ) + + if params.lpw(): + logger.debug("using LPW pipeline for highres") + rng = torch.manual_seed(params.seed) + result = highres_pipe.img2img( + tile, + params.prompt, + generator=rng, + guidance_scale=params.cfg, + negative_prompt=params.negative_prompt, + num_images_per_prompt=1, + num_inference_steps=highres.steps, + strength=highres.strength, + eta=params.eta, + callback=highres_progress, + ) + return result.images[0] + else: + rng = np.random.RandomState(params.seed) + result = highres_pipe( + params.prompt, + tile, + generator=rng, + guidance_scale=params.cfg, + negative_prompt=params.negative_prompt, + num_images_per_prompt=1, + num_inference_steps=highres.steps, + strength=highres.strength, + eta=params.eta, + callback=highres_progress, + ) + return result.images[0] + + logger.info( + "running highres fix for %s iterations at %s scale", + highres.iterations, + highres.scale, + ) + + for _i in range(highres.iterations): + image = process_tile_order( + TileOrder.grid, + image, + size.height // highres.scale, + highres.scale, + [highres_tile], + overlap=0, + ) + + return image + + def run_txt2img_pipeline( job: WorkerContext, server: ServerContext, @@ -93,110 +214,19 @@ def run_txt2img_pipeline( for image, output in image_outputs: if highres.scale > 1: - highres_progress = ChainProgress.from_progress(progress) - - if upscale.faces and ( - upscale.upscale_order == "correction-both" - or upscale.upscale_order == "correction-first" - ): - image = run_upscale_correction( - job, - server, - StageParams(), - params, - image, - upscale=upscale.with_args( - scale=1, - outscale=1, - ), - callback=highres_progress, - ) - - # load img2img pipeline once - highres_pipe = load_pipeline( + image = run_highres( + job, server, - "img2img", - params.model, - params.scheduler, - job.get_device(), - inversions=inversions, - loras=loras, + params, + size, + upscale, + highres, + image, + progress, + inversions, + loras, ) - def highres_tile(tile: Image.Image, dims): - if highres.method == "bilinear": - logger.debug("using bilinear interpolation for highres") - tile = tile.resize( - (size.height, size.width), resample=Image.Resampling.BILINEAR - ) - elif highres.method == "lanczos": - logger.debug("using Lanczos interpolation for highres") - tile = tile.resize( - (size.height, size.width), resample=Image.Resampling.LANCZOS - ) - else: - logger.debug("using upscaling pipeline for highres") - tile = run_upscale_correction( - job, - server, - StageParams(), - params, - tile, - upscale=upscale.with_args( - faces=False, - scale=highres.scale, - outscale=highres.scale, - ), - callback=highres_progress, - ) - - if params.lpw(): - logger.debug("using LPW pipeline for highres") - rng = torch.manual_seed(params.seed) - result = highres_pipe.img2img( - tile, - params.prompt, - generator=rng, - guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, - num_images_per_prompt=1, - num_inference_steps=highres.steps, - strength=highres.strength, - eta=params.eta, - callback=highres_progress, - ) - return result.images[0] - else: - rng = np.random.RandomState(params.seed) - result = highres_pipe( - params.prompt, - tile, - generator=rng, - guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, - num_images_per_prompt=1, - num_inference_steps=highres.steps, - strength=highres.strength, - eta=params.eta, - callback=highres_progress, - ) - return result.images[0] - - logger.info( - "running highres fix for %s iterations at %s scale", - highres.iterations, - highres.scale, - ) - for _i in range(highres.iterations): - image = process_tile_order( - TileOrder.grid, - image, - size.height // highres.scale, - highres.scale, - [highres_tile], - overlap=0, - ) - image = run_upscale_correction( job, server, @@ -221,6 +251,7 @@ def run_img2img_pipeline( params: ImageParams, outputs: List[str], upscale: UpscaleParams, + highres: HighresParams, source: Image.Image, strength: float, source_filter: Optional[str] = None, @@ -290,6 +321,20 @@ def run_img2img_pipeline( images.append(source) for image, output in zip(images, outputs): + if highres.scale > 1: + image = run_highres( + job, + server, + params, + Size(source.width, source.height), + upscale, + highres, + image, + progress, + inversions, + loras, + ) + image = run_upscale_correction( job, server, @@ -316,6 +361,7 @@ def run_inpaint_pipeline( size: Size, outputs: List[str], upscale: UpscaleParams, + highres: HighresParams, source: Image.Image, mask: Image.Image, border: Border, @@ -372,6 +418,7 @@ def run_upscale_pipeline( size: Size, outputs: List[str], upscale: UpscaleParams, + highres: HighresParams, source: Image.Image, ) -> None: progress = job.get_progress_callback() @@ -398,6 +445,7 @@ def run_blend_pipeline( size: Size, outputs: List[str], upscale: UpscaleParams, + highres: HighresParams, sources: List[Image.Image], mask: Image.Image, ) -> None: diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 3d5247af..ebbfc984 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -180,7 +180,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): if source_filter is not None: output_count += 1 - output = make_output_name(server, "img2img", params, size, extras=[strength], count=output_count) + output = make_output_name( + server, "img2img", params, size, extras=[strength], count=output_count + ) job_name = output[0] logger.info("img2img job queued for: %s", job_name)