feat(api): start using chain pipelines for all images

Sean Sube 2023-06-29 23:06:36 -05:00
parent 4c3fcace5e
commit fd3e65eafc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
10 changed files with 362 additions and 606 deletions
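
At a high level, every run_* pipeline now builds a list of (stage callback, stage params, kwargs) tuples and hands them to a ChainPipeline, instead of driving each diffusers pipeline imperatively. A minimal sketch of that execution model, using simplified stand-in types rather than the real WorkerContext/ServerContext signatures:

# Minimal sketch of the chain-pipeline execution model this commit adopts.
# The context types and the stage signature are simplified stand-ins, not
# the real onnx-web API.
from typing import Any, Callable, Dict, List, Optional, Tuple

from PIL import Image

StageFn = Callable[..., Image.Image]
PipelineStage = Tuple[StageFn, Any, Optional[Dict[str, Any]]]

class ChainSketch:
    def __init__(self) -> None:
        self.stages: List[PipelineStage] = []

    def append(self, stage: PipelineStage) -> None:
        self.stages.append(stage)

    def __call__(self, job: Any, server: Any, params: Any, source: Optional[Image.Image]) -> Image.Image:
        # each stage consumes the previous stage's output; source stages
        # (like source_txt2img) ignore the incoming image entirely
        image = source
        for stage_fn, stage_params, kwargs in self.stages:
            image = stage_fn(job, server, stage_params, params, image, **(kwargs or {}))
        return image

The real ChainPipeline also handles tiling and progress callbacks, which this sketch omits.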

View File

@@ -17,7 +17,7 @@ from .diffusers.run import (
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import run_upscale_correction
from .diffusers.upscale import append_upscale_correction
from .image.utils import (
expand_image,
valid_image,

View File

@@ -1,10 +1,8 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_controlnet import blend_controlnet
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .blend_linear import blend_linear
from .blend_mask import blend_mask
from .blend_pix2pix import blend_pix2pix
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk
@@ -16,18 +14,17 @@ from .source_s3 import source_s3
from .source_txt2img import source_txt2img
from .source_url import source_url
from .upscale_bsrgan import upscale_bsrgan
from .upscale_highres import upscale_highres
from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion
from .upscale_swinir import upscale_swinir
CHAIN_STAGES = {
"blend-controlnet": blend_controlnet,
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"blend-linear": blend_linear,
"blend-mask": blend_mask,
"blend-pix2pix": blend_pix2pix,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
@@ -39,6 +36,7 @@ CHAIN_STAGES = {
"source-txt2img": source_txt2img,
"source-url": source_url,
"upscale-bsrgan": upscale_bsrgan,
"upscale-highres": upscale_highres,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,

View File

@@ -88,7 +88,7 @@ class ChainPipeline:
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Image.Image,
source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> Image.Image:
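
Making source optional is what lets txt2img jobs share this code path: a chain that begins with a source stage can be invoked without an input image. A sketch of the two call styles, matching the calls added in run.py below:

# Sketch: the two call styles the optional source enables.
def run_from_scratch(chain, job, server, params):
    # no input image; the first stage (e.g. source_txt2img) produces one
    return chain(job, server, params, None)

def run_from_image(chain, job, server, params, source):
    # an existing image flows through the stages
    return chain(job, server, params, source)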

View File

@@ -1,54 +0,0 @@
from logging import getLogger
from typing import Optional
import numpy as np
from PIL import Image
from ..diffusers.load import load_pipeline
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_controlnet(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
callback: Optional[ProgressCallback] = None,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using ControlNet, %s steps: %s", params.steps, params.prompt
)
pipe = load_pipeline(
server,
params,
"controlnet",
job.get_device(),
)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength, # TODO: ControlNet strength
callback=callback,
)
output = result.images[0]
logger.info("final output image size: %sx%s", output.width, output.height)
return output
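
Note: deleting blend_controlnet here (and blend_pix2pix below) does not drop those code paths; blend_img2img now covers them by selecting the pipeline through params.get_valid_pipeline and remapping strength per pipeline type, as the next file shows.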

View File

@@ -6,6 +6,7 @@ import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.utils import encode_prompt, parse_prompt
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
@@ -20,8 +21,9 @@ def blend_img2img(
params: ImageParams,
source: Image.Image,
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Image.Image,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
@@ -30,14 +32,28 @@ def blend_img2img(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)
pipe_type = "lpw" if params.lpw() else "img2img"
prompt_pairs, loras, inversions = parse_prompt(params)
pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: add LoRAs and TIs
inversions=inversions,
loras=loras,
)
pipe_params = {}
if pipe_type == "controlnet":
pipe_params["controlnet_conditioning_scale"] = strength
elif pipe_type == "img2img":
pipe_params["strength"] = strength
elif pipe_type == "panorama":
pipe_params["strength"] = strength
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
@@ -50,8 +66,13 @@ def blend_img2img(
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
@@ -62,6 +83,7 @@ def blend_img2img(
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
**pipe_params,
)
output = result.images[0]
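
Because the same user-facing strength value drives different pipelines, the new keyword-only strength argument is remapped to whichever kwarg the loaded pipeline expects. The if/elif ladder above is equivalent to a small lookup table; a sketch using the kwarg names from this diff, where lpw has no entry because that path passes strength directly:

# Equivalent table form of the pipe_params dispatch above.
STRENGTH_KWARG = {
    "controlnet": "controlnet_conditioning_scale",
    "img2img": "strength",
    "panorama": "strength",
    "pix2pix": "image_guidance_scale",
}

def strength_kwargs(pipe_type: str, strength: float) -> dict:
    key = STRENGTH_KWARG.get(pipe_type)
    return {key: strength} if key else {}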

View File

@@ -1,71 +0,0 @@
from logging import getLogger
from typing import Optional
import numpy as np
import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_pix2pix(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
callback: Optional[ProgressCallback] = None,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using instruct pix2pix, %s steps: %s",
params.steps,
params.prompt,
)
pipe = load_pipeline(
server,
params,
"pix2pix",
job.get_device(),
# TODO: add LoRAs and TIs
)
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=params.strength,
callback=callback,
)
output = result.images[0]
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@@ -6,7 +6,7 @@ import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
from ..params import ImageParams, Size, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
@@ -36,14 +36,17 @@ def source_txt2img(
"a source image was passed to a txt2img stage, and will be discarded"
)
prompt_pairs, loras, inversions = parse_prompt(params)
latents = get_latents_from_seed(params.seed, size)
pipe_type = "lpw" if params.lpw() else "txt2img"
pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: add LoRAs and TIs
inversions=inversions,
loras=loras,
)
if params.lpw():
@@ -61,6 +64,10 @@ def source_txt2img(
callback=callback,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
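
Outside of the LPW path, txt2img and img2img stages now share the same prompt handling: parse_prompt pulls the LoRA and Textual Inversion tokens out of the prompt, and the encoded alternative prompts are attached to the patched UNet before the run. Condensed from the diff above, with assumed absolute import paths:

# Condensed non-LPW prompt flow; signatures follow the diff above.
from onnx_web.diffusers.load import load_pipeline                 # assumed path
from onnx_web.diffusers.utils import encode_prompt, parse_prompt  # assumed path

def prepare_pipeline(job, server, params, pipe_type):
    prompt_pairs, loras, inversions = parse_prompt(params)
    pipe = load_pipeline(server, params, pipe_type, job.get_device(),
                         inversions=inversions, loras=loras)
    prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
    pipe.unet.set_prompts(prompt_embeds)  # the patched UNet records the embeddings
    return pipe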

View File

@@ -0,0 +1,108 @@
from logging import getLogger
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.upscale import append_upscale_correction
from ..diffusers.utils import parse_prompt
from ..params import HighresParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
logger = getLogger(__name__)
def upscale_highres(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
highres: HighresParams,
upscale: UpscaleParams,
size: Size,
stage_source: Optional[Image.Image] = None,
pipeline: Optional[Any] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
image = stage_source or source
if highres.scale <= 1:
return image
# load img2img pipeline once
pipe_type = params.get_valid_pipeline("img2img")
logger.debug("using %s pipeline for highres", pipe_type)
_prompt_pairs, loras, inversions = parse_prompt(params)
highres_pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
scaled_size = (source.width * highres.scale, source.height * highres.scale)
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
upscale = append_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
)
source = upscale(
job,
server,
source,
callback=callback,
)
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=callback,
)
return result.images[0]
else:
rng = np.random.RandomState(params.seed)
result = highres_pipe(
params.prompt,
source,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=callback,
)
return result.images[0]
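
The new stage bundles interpolation or model-based upscaling with an img2img refinement pass, so highres can be appended like any other stage. A sketch of wiring it into a chain; the import paths are assumptions, and highres/upscale/size are assumed to come from the parsed request:

# Sketch: appending the new highres stage to a chain.
from onnx_web.chain import upscale_highres  # assumed import path
from onnx_web.params import StageParams     # assumed import path

def add_highres_stage(chain, highres, upscale, size):
    chain.append(
        (
            upscale_highres,
            StageParams(),
            {"highres": highres, "upscale": upscale, "size": size},
        )
    )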

View File

@@ -1,12 +1,10 @@
from logging import getLogger
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional
import numpy as np
import torch
from PIL import Image
from ..chain import blend_mask, upscale_outpaint
from ..chain.utils import process_tile_order
from ..chain import blend_img2img, blend_mask, upscale_highres, upscale_outpaint
from ..chain.base import ChainPipeline
from ..output import save_image
from ..params import (
Border,
@@ -14,213 +12,18 @@ from ..params import (
ImageParams,
Size,
StageParams,
TileOrder,
UpscaleParams,
)
from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .load import load_pipeline
from .upscale import run_upscale_correction
from .utils import encode_prompt, get_latents_from_seed, parse_prompt
from .upscale import append_upscale_correction, split_upscale
from .utils import parse_prompt
logger = getLogger(__name__)
def run_loopback(
job: WorkerContext,
server: ServerContext,
params: ImageParams,
strength: float,
image: Image.Image,
progress: ProgressCallback,
inversions: List[Tuple[str, float]],
loras: List[Tuple[str, float]],
pipeline: Optional[Any] = None,
) -> Image.Image:
if params.loopback == 0:
return image
# load img2img pipeline once
pipe_type = params.get_valid_pipeline("img2img")
if pipe_type == "controlnet":
logger.debug(
"controlnet pipeline cannot be used for loopback, switching to img2img"
)
pipe_type = "img2img"
logger.debug("using %s pipeline for loopback", pipe_type)
pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
def loopback_iteration(source: Image.Image):
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
source,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
for _i in range(params.loopback):
image = loopback_iteration(image)
return image
def run_highres(
job: WorkerContext,
server: ServerContext,
params: ImageParams,
size: Size,
upscale: UpscaleParams,
highres: HighresParams,
image: Image.Image,
progress: ProgressCallback,
inversions: List[Tuple[str, float]],
loras: List[Tuple[str, float]],
pipeline: Optional[Any] = None,
) -> Image.Image:
if highres.scale <= 1:
return image
if upscale.faces and (
upscale.upscale_order == "correction-both"
or upscale.upscale_order == "correction-first"
):
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale.with_args(
scale=1,
outscale=1,
),
callback=progress,
)
# load img2img pipeline once
pipe_type = params.get_valid_pipeline("img2img")
logger.debug("using %s pipeline for highres", pipe_type)
highres_pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
def highres_tile(tile: Image.Image, dims):
scaled_size = (tile.width * highres.scale, tile.height * highres.scale)
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
tile = tile.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
tile = tile.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
tile = run_upscale_correction(
job,
server,
StageParams(),
params,
tile,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
callback=progress,
)
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
tile,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
else:
rng = np.random.RandomState(params.seed)
result = highres_pipe(
params.prompt,
tile,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=highres.steps,
strength=highres.strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
logger.info(
"running highres fix for %s iterations at %s scale",
highres.iterations,
highres.scale,
)
for _i in range(highres.iterations):
image = process_tile_order(
TileOrder.grid,
image,
size.height // highres.scale,
highres.scale,
[highres_tile],
overlap=params.overlap,
)
return image
def run_txt2img_pipeline(
job: WorkerContext,
server: ServerContext,
@@ -230,103 +33,52 @@ def run_txt2img_pipeline(
upscale: UpscaleParams,
highres: HighresParams,
) -> None:
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
prompt_pairs, loras, inversions = parse_prompt(params)
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
chain.append((blend_img2img, stage, None))
pipe_type = params.get_valid_pipeline("txt2img")
logger.debug("using %s pipeline for txt2img", pipe_type)
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage,
params,
upscale=first_upscale,
chain=chain,
)
pipe = load_pipeline(
server,
# apply highres
chain.append((upscale_highres, stage, None))
# apply upscaling and correction, after highres
append_upscale_correction(
StageParams(),
params,
pipe_type,
job.get_device(),
upscale=upscale,
chain=chain,
)
# run and save
image = chain(job, server, params, None)
_prompt_pairs, loras, inversions = parse_prompt(params)
dest = save_image(
server,
outputs[0],
image,
params,
size,
upscale=upscale,
highres=highres,
inversions=inversions,
loras=loras,
)
progress = job.get_progress_callback()
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
params.prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe,
prompt_pairs,
num_images_per_prompt=params.batch,
do_classifier_free_guidance=params.do_cfg(),
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)
image_outputs = list(zip(result.images, outputs))
del result
del pipe
for image, output in image_outputs:
image = run_highres(
job,
server,
params,
size,
upscale,
highres,
image,
progress,
inversions,
loras,
)
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(
server,
output,
image,
params,
size,
upscale=upscale,
highres=highres,
inversions=inversions,
loras=loras,
)
# clean up
run_gc([job.get_device()])
# notify the user
show_system_toast(f"finished txt2img job: {dest}")
logger.info("finished txt2img job: %s", dest)
@@ -342,110 +94,75 @@ def run_img2img_pipeline(
strength: float,
source_filter: Optional[str] = None,
) -> None:
prompt_pairs, loras, inversions = parse_prompt(params)
# filter the source image
# run filter on the source image
if source_filter is not None:
f = get_source_filters().get(source_filter, None)
if f is not None:
logger.debug("running source filter: %s", f.__name__)
source = f(server, source)
pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
# prepare the chain pipeline and first stage
chain = ChainPipeline()
stage = StageParams()
chain.append(
(
blend_img2img,
stage,
{
"strength": strength,
},
)
)
pipe_params = {}
if pipe_type == "controlnet":
pipe_params["controlnet_conditioning_scale"] = strength
elif pipe_type == "img2img":
pipe_params["strength"] = strength
elif pipe_type == "panorama":
pipe_params["strength"] = strength
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength
progress = job.get_progress_callback()
if pipe_type == "lpw":
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
source,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
**pipe_params,
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage,
params,
upscale=first_upscale,
chain=chain,
)
images = result.images
# loopback through multiple img2img iterations
if params.loopback > 0:
for _i in range(params.loopback):
chain.append(
(
blend_img2img,
stage,
{
"strength": strength,
},
)
)
# highres, if selected
if highres.iterations > 0:
for _i in range(highres.iterations):
chain.append((upscale_highres, stage, None))
# apply upscaling and correction, after highres
append_upscale_correction(
stage,
params,
upscale=after_upscale,
chain=chain,
)
# run and append the filtered source
images = [
chain(job, server, params, source),
]
if source_filter is not None and source_filter != "none":
images.append(source)
# save with metadata
_prompt_pairs, loras, inversions = parse_prompt(params)
size = Size(*source.size)
for image, output in zip(images, outputs):
image = run_loopback(
job,
server,
params,
strength,
image,
progress,
inversions,
loras,
)
size = Size(*source.size)
image = run_highres(
job,
server,
params,
size,
upscale,
highres,
image,
progress,
inversions,
loras,
)
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(
server,
output,
@@ -458,7 +175,10 @@ def run_img2img_pipeline(
loras=loras,
)
# clean up
run_gc([job.get_device()])
# notify the user
show_system_toast(f"finished img2img job: {dest}")
logger.info("finished img2img job: %s", dest)
@@ -479,49 +199,48 @@ def run_inpaint_pipeline(
fill_color: str,
tile_order: str,
) -> None:
progress = job.get_progress_callback()
logger.debug("building inpaint pipeline")
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order)
chain.append(
(
upscale_outpaint,
stage,
{
"border": border,
"stage_mask": mask,
"fill_color": fill_color,
"mask_filter": mask_filter,
"noise_source": noise_source,
},
)
)
# apply highres
chain.append(
(
upscale_highres,
stage,
{
"highres": highres,
},
)
)
# apply upscaling and correction
append_upscale_correction(
stage,
params,
upscale=upscale,
chain=chain,
)
# run and save
image = chain(job, server, params, source)
_prompt_pairs, loras, inversions = parse_prompt(params)
logger.debug("applying mask filter and generating noise source")
image = upscale_outpaint(
job,
server,
stage,
params,
source,
border=border,
stage_mask=mask,
fill_color=fill_color,
mask_filter=mask_filter,
noise_source=noise_source,
callback=progress,
)
image = run_highres(
job,
server,
params,
size,
upscale,
highres,
image,
progress,
inversions,
loras,
)
image = run_upscale_correction(
job,
server,
stage,
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(
server,
outputs[0],
@@ -534,9 +253,11 @@ def run_inpaint_pipeline(
loras=loras,
)
# clean up
del image
run_gc([job.get_device()])
# notify the user
show_system_toast(f"finished inpaint job: {dest}")
logger.info("finished inpaint job: %s", dest)
@@ -551,34 +272,50 @@ def run_upscale_pipeline(
highres: HighresParams,
source: Image.Image,
) -> None:
progress = job.get_progress_callback()
# set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline()
stage = StageParams()
_prompt_pairs, loras, inversions = parse_prompt(params)
# apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale)
if first_upscale:
append_upscale_correction(
stage,
params,
upscale=first_upscale,
chain=chain,
)
image = run_upscale_correction(
job, server, stage, params, source, upscale=upscale, callback=progress
# apply highres
chain.append((upscale_highres, stage, None))
# apply upscaling and correction, after highres
append_upscale_correction(
stage,
params,
upscale=after_upscale,
chain=chain,
)
# TODO: should this come first?
image = run_highres(
job,
# run and save
image = chain(job, server, params, source)
_prompt_pairs, loras, inversions = parse_prompt(params)
dest = save_image(
server,
outputs[0],
image,
params,
size,
upscale,
highres,
image,
progress,
inversions,
loras,
upscale=upscale,
inversions=inversions,
loras=loras,
)
dest = save_image(server, outputs[0], image, params, size, upscale=upscale)
# clean up
del image
run_gc([job.get_device()])
# notify the user
show_system_toast(f"finished upscale job: {dest}")
logger.info("finished upscale job: %s", dest)
@@ -594,28 +331,27 @@ def run_blend_pipeline(
sources: List[Image.Image],
mask: Image.Image,
) -> None:
progress = job.get_progress_callback()
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams()
chain.append((blend_mask, stage, {"sources": sources, "stage_mask": mask}))
image = blend_mask(
job,
server,
# apply upscaling and correction
append_upscale_correction(
stage,
params,
sources=sources,
stage_mask=mask,
callback=progress,
)
image = image.convert("RGB")
image = run_upscale_correction(
job, server, stage, params, image, upscale=upscale, callback=progress
upscale=upscale,
chain=chain,
)
# run and save
image = chain(job, server, params, sources[0])
dest = save_image(server, outputs[0], image, params, size, upscale=upscale)
# clean up
del image
run_gc([job.get_device()])
# notify the user
show_system_toast(f"finished blend job: {dest}")
logger.info("finished blend job: %s", dest)

View File

@@ -1,7 +1,5 @@
from logging import getLogger
from typing import List, Optional
from PIL import Image
from typing import List, Optional, Tuple
from ..chain import (
ChainPipeline,
@@ -14,24 +12,42 @@ from ..chain import (
upscale_swinir,
)
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def run_upscale_correction(
job: WorkerContext,
server: ServerContext,
def split_upscale(
upscale: UpscaleParams,
) -> Tuple[Optional[UpscaleParams], UpscaleParams]:
if upscale.faces and (
upscale.upscale_order == "correction-both"
or upscale.upscale_order == "correction-first"
):
return (
upscale.with_args(
scale=1,
outscale=1,
),
upscale.with_args(
upscale_order="correction-last",
),
)
else:
return (
None,
upscale,
)
def append_upscale_correction(
stage: StageParams,
params: ImageParams,
image: Image.Image,
*,
upscale: UpscaleParams,
callback: Optional[ProgressCallback] = None,
chain: Optional[ChainPipeline] = None,
pre_stages: List[PipelineStage] = None,
post_stages: List[PipelineStage] = None,
) -> Image.Image:
) -> ChainPipeline:
"""
This is a convenience method that appends upscaling and correction stages to a
chain pipeline, based on the `upscale` params.
@@ -42,7 +58,9 @@ def run_upscale_correction(
upscale.outscale,
)
chain = ChainPipeline()
if chain is None:
chain = ChainPipeline()
if pre_stages is not None:
for stage, params in pre_stages:
chain.append((stage, params))
@@ -103,12 +121,4 @@
for stage, params in post_stages:
chain.append((stage, params))
return chain(
job,
server,
params,
image,
prompt=params.prompt,
upscale=upscale,
callback=callback,
)
return chain
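
split_upscale cuts the upscale plan in two when face correction should run before highres, and append_upscale_correction can now extend an existing chain instead of running one itself. A condensed sketch of the pattern the run_* pipelines above follow, with assumed import paths:

# Condensed version of the pattern used by the run_*_pipeline functions.
from onnx_web.chain import ChainPipeline, upscale_highres  # assumed paths
from onnx_web.diffusers.upscale import append_upscale_correction, split_upscale
from onnx_web.params import StageParams

def build_image_chain(params, upscale, highres, size):
    chain = ChainPipeline()
    stage = StageParams()

    # correction-first/both: face correction runs before highres
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        append_upscale_correction(stage, params, upscale=first_upscale, chain=chain)

    # highres in the middle, then the remaining upscaling and correction
    chain.append(
        (upscale_highres, stage, {"highres": highres, "upscale": after_upscale, "size": size})
    )
    append_upscale_correction(stage, params, upscale=after_upscale, chain=chain)
    return chain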