1
0
Fork 0

fix(api): pass additional params to new stages

This commit is contained in:
Sean Sube 2023-06-30 07:20:49 -05:00
parent 7a951065e4
commit 7a73c9ff61
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 50 additions and 83 deletions

View File

@ -64,7 +64,6 @@ def blend_img2img(
image=source, image=source,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength,
callback=callback, callback=callback,
**pipe_params, **pipe_params,
) )
@ -81,7 +80,6 @@ def blend_img2img(
image=source, image=source,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
strength=strength,
callback=callback, callback=callback,
**pipe_params, **pipe_params,
) )

View File

@ -1,13 +1,11 @@
from logging import getLogger from logging import getLogger
from typing import Any, Optional from typing import Any, Optional
import numpy as np
import torch
from PIL import Image from PIL import Image
from ..diffusers.load import load_pipeline from ..chain.base import ChainPipeline
from ..chain.img2img import blend_img2img
from ..diffusers.upscale import append_upscale_correction from ..diffusers.upscale import append_upscale_correction
from ..diffusers.utils import parse_prompt
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import WorkerContext from ..worker import WorkerContext
@ -30,25 +28,12 @@ def upscale_highres(
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
image = stage_source or source source = stage_source or source
if highres.scale <= 1: if highres.scale <= 1:
return image return source
# load img2img pipeline once
pipe_type = params.get_valid_pipeline("img2img")
logger.debug("using %s pipeline for highres", pipe_type)
_prompt_pairs, loras, inversions = parse_prompt(params)
highres_pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
chain = ChainPipeline()
scaled_size = (source.width * highres.scale, source.height * highres.scale) scaled_size = (source.width * highres.scale, source.height * highres.scale)
# TODO: upscaling within the same stage prevents tiling from happening and causes OOM # TODO: upscaling within the same stage prevents tiling from happening and causes OOM
@ -60,7 +45,7 @@ def upscale_highres(
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else: else:
logger.debug("using upscaling pipeline for highres") logger.debug("using upscaling pipeline for highres")
upscale = append_upscale_correction( append_upscale_correction(
StageParams(), StageParams(),
params, params,
upscale=upscale.with_args( upscale=upscale.with_args(
@ -68,41 +53,24 @@ def upscale_highres(
scale=highres.scale, scale=highres.scale,
outscale=highres.scale, outscale=highres.scale,
), ),
) chain=chain,
source = upscale(
job,
server,
source,
callback=callback,
) )
if pipe_type == "lpw": chain.append(
rng = torch.manual_seed(params.seed) (
result = highres_pipe.img2img( blend_img2img,
source, StageParams(),
params.prompt, {
generator=rng, "overlap": params.overlap,
guidance_scale=params.cfg, "strength": highres.strength,
negative_prompt=params.negative_prompt, },
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=callback,
) )
return result.images[0] )
else:
rng = np.random.RandomState(params.seed) return chain(
result = highres_pipe( job,
params.prompt, server,
source, params,
generator=rng, source,
guidance_scale=params.cfg, callback=callback,
negative_prompt=params.negative_prompt, )
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=callback,
)
return result.images[0]

View File

@ -63,16 +63,17 @@ def run_txt2img_pipeline(
) )
# apply highres # apply highres
chain.append( for _i in range(highres.iterations):
( chain.append(
upscale_highres, (
stage, upscale_highres,
{ stage,
"highres": highres, {
"upscale": upscale, "highres": highres,
}, "upscale": upscale,
},
)
) )
)
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
append_upscale_correction( append_upscale_correction(

View File

@ -1,16 +1,13 @@
from logging import getLogger from logging import getLogger
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ..chain import ( from ..chain import ChainPipeline, PipelineStage
ChainPipeline, from ..chain.correct_codeformer import correct_codeformer
PipelineStage, from ..chain.correct_gfpgan import correct_gfpgan
correct_codeformer, from ..chain.upscale_bsrgan import upscale_bsrgan
correct_gfpgan, from ..chain.upscale_resrgan import upscale_resrgan
upscale_bsrgan, from ..chain.upscale_stable_diffusion import upscale_stable_diffusion
upscale_resrgan, from ..chain.upscale_swinir import upscale_swinir
upscale_stable_diffusion,
upscale_swinir,
)
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
logger = getLogger(__name__) logger = getLogger(__name__)
@ -65,6 +62,9 @@ def append_upscale_correction(
for stage, params in pre_stages: for stage, params in pre_stages:
chain.append((stage, params)) chain.append((stage, params))
upscale_opts = {
"upscale": upscale,
}
upscale_stage = None upscale_stage = None
if upscale.scale > 1: if upscale.scale > 1:
if "bsrgan" in upscale.upscale_model: if "bsrgan" in upscale.upscale_model:
@ -72,23 +72,23 @@ def append_upscale_correction(
tile_size=stage.tile_size, tile_size=stage.tile_size,
outscale=upscale.outscale, outscale=upscale.outscale,
) )
upscale_stage = (upscale_bsrgan, bsrgan_params, None) upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts)
elif "esrgan" in upscale.upscale_model: elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams( esrgan_params = StageParams(
tile_size=stage.tile_size, tile_size=stage.tile_size,
outscale=upscale.outscale, outscale=upscale.outscale,
) )
upscale_stage = (upscale_resrgan, esrgan_params, None) upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts)
elif "stable-diffusion" in upscale.upscale_model: elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size) mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_params, None) upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts)
elif "swinir" in upscale.upscale_model: elif "swinir" in upscale.upscale_model:
swinir_params = StageParams( swinir_params = StageParams(
tile_size=stage.tile_size, tile_size=stage.tile_size,
outscale=upscale.outscale, outscale=upscale.outscale,
) )
upscale_stage = (upscale_swinir, swinir_params, None) upscale_stage = (upscale_swinir, swinir_params, upscale_opts)
else: else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model) logger.warn("unknown upscaling model: %s", upscale.upscale_model)

View File

@ -87,14 +87,14 @@ class Size:
def tojson(self) -> Dict[str, int]: def tojson(self) -> Dict[str, int]:
return { return {
"height": self.height,
"width": self.width, "width": self.width,
"height": self.height,
} }
def with_args(self, **kwargs): def with_args(self, **kwargs):
return Size( return Size(
kwargs.get("height", self.height),
kwargs.get("width", self.width), kwargs.get("width", self.width),
kwargs.get("height", self.height),
) )