
fix(api): generate latents before tiling

Sean Sube 2023-07-09 22:19:02 -05:00
parent c15f750821
commit 60aa8ab4c0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 56 additions and 172 deletions
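Note on the change: instead of each tiled stage seeding its own latents, the top-level pipelines now generate latents for the full output once and pass them down the chain, and each stage either slices the window matching its tile out of that array (get_tile_latents) or falls back to generating new latents when none were provided. Below is a rough sketch of the slicing step, assuming the usual (batch, 4, height // 8, width // 8) latent layout and dims given as (left, top, tile) in pixels, as in the tile callbacks in this diff; the helper name matches the diff, but the body and the tuple-based size parameter are illustrative assumptions, not the project's implementation.

import numpy as np

LATENT_FACTOR = 8  # assumption: one latent cell per 8 output pixels


def get_tile_latents(full_latents: np.ndarray, dims, size):
    # dims is (left, top, tile) in pixels, matching the tile callbacks in this commit;
    # size is assumed here to be the (width, height) of the tile being rendered
    left, top, _tile = dims
    width, height = size
    x, y = left // LATENT_FACTOR, top // LATENT_FACTOR
    w, h = width // LATENT_FACTOR, height // LATENT_FACTOR
    # slice the tile's window out of the full-image latents
    return full_latents[:, :, y : y + h, x : x + w]


# usage: seed once for the whole output, then slice per tile
rng = np.random.default_rng(42)
full = rng.standard_normal((1, 4, 1024 // LATENT_FACTOR, 1024 // LATENT_FACTOR))
tile = get_tile_latents(full, (512, 0, 512), (512, 512))  # -> shape (1, 4, 64, 64)

Seeding once and slicing per tile keeps every tile consistent with the same full-image noise, which appears to be the point of generating the latents before tiling.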

View File

@@ -1,6 +1,5 @@
 from .base import ChainPipeline, PipelineStage, StageParams
 from .blend_img2img import BlendImg2ImgStage
-from .blend_inpaint import BlendInpaintStage
 from .blend_linear import BlendLinearStage
 from .blend_mask import BlendMaskStage
 from .correct_codeformer import CorrectCodeformerStage
@@ -23,7 +22,7 @@ from .upscale_swinir import UpscaleSwinIRStage

 CHAIN_STAGES = {
     "blend-img2img": BlendImg2ImgStage,
-    "blend-inpaint": BlendInpaintStage,
+    "blend-inpaint": UpscaleOutpaintStage,
     "blend-linear": BlendLinearStage,
     "blend-mask": BlendMaskStage,
     "correct-codeformer": CorrectCodeformerStage,

View File

@@ -146,7 +146,7 @@ class ChainPipeline:
                )

                def stage_tile(
-                    source_tile: Image.Image, tile_mask: Image.Image, _dims
+                    source_tile: Image.Image, tile_mask: Image.Image, dims: Tuple[int, int, int]
                ) -> Image.Image:
                    output_tile = stage_pipe.run(
                        job,
@@ -156,6 +156,7 @@ class ChainPipeline:
                        [source_tile],
                        tile_mask=tile_mask,
                        callback=callback,
+                        dims=dims,
                        **kwargs,
                    )[0]
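
For reference, dims comes from the tiling helper: it walks the source in tile-sized windows and hands each callback the window's offsets, and stage_tile now forwards those offsets into the stage so it can slice the matching latents. A simplified sketch of such a loop follows; the real process_tile_order also handles tile order and overlap, and the name, signature, and body here are assumptions for illustration only.

from typing import Callable, List, Tuple

from PIL import Image

TileCallback = Callable[[Image.Image, Tuple[int, int, int]], Image.Image]


def process_tiles(
    source: Image.Image, tile: int, callbacks: List[TileCallback]
) -> Image.Image:
    # walk the source in tile-sized windows, passing (left, top, tile) to each callback
    output = source.copy()
    for top in range(0, source.height, tile):
        for left in range(0, source.width, tile):
            window = output.crop((left, top, left + tile, top + tile))
            for callback in callbacks:
                window = callback(window, (left, top, tile))
            output.paste(window, (left, top))
    return output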

View File

@@ -1,140 +0,0 @@
-from logging import getLogger
-from typing import Callable, List, Optional, Tuple
-
-import numpy as np
-import torch
-from PIL import Image
-
-from ..diffusers.load import load_pipeline
-from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
-from ..image import expand_image, mask_filter_none, noise_source_histogram
-from ..output import save_image
-from ..params import Border, ImageParams, Size, SizeChart, StageParams
-from ..server import ServerContext
-from ..utils import is_debug
-from ..worker import ProgressCallback, WorkerContext
-from .stage import BaseStage
-from .tile import process_tile_order
-
-logger = getLogger(__name__)
-
-
-class BlendInpaintStage(BaseStage):
-    def run(
-        self,
-        job: WorkerContext,
-        server: ServerContext,
-        stage: StageParams,
-        params: ImageParams,
-        sources: List[Image.Image],
-        *,
-        expand: Border,
-        stage_source: Optional[Image.Image] = None,
-        stage_mask: Optional[Image.Image] = None,
-        fill_color: str = "white",
-        mask_filter: Callable = mask_filter_none,
-        noise_source: Callable = noise_source_histogram,
-        callback: Optional[ProgressCallback] = None,
-        **kwargs,
-    ) -> List[Image.Image]:
-        params = params.with_args(**kwargs)
-        expand = expand.with_args(**kwargs)
-        logger.info(
-            "blending image using inpaint, %s steps: %s", params.steps, params.prompt
-        )
-
-        prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
-            params
-        )
-
-        pipe_type = params.get_valid_pipeline("inpaint")
-        pipe = load_pipeline(
-            server,
-            params,
-            pipe_type,
-            job.get_device(),
-            inversions=inversions,
-            loras=loras,
-        )
-
-        outputs = []
-        for source in sources:
-            if stage_mask is None:
-                # if no mask was provided, keep the full source image
-                stage_mask = Image.new("RGB", source.size, "black")
-
-            source, stage_mask, noise, _full_dims = expand_image(
-                source,
-                stage_mask,
-                expand,
-                fill=fill_color,
-                noise_source=noise_source,
-                mask_filter=mask_filter,
-            )
-
-            if is_debug():
-                save_image(server, "last-source.png", source)
-                save_image(server, "last-mask.png", stage_mask)
-                save_image(server, "last-noise.png", noise)
-
-            def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
-                left, top, tile = dims
-                size = Size(*tile_source.size)
-                tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
-
-                if is_debug():
-                    save_image(server, "tile-source.png", tile_source)
-                    save_image(server, "tile-mask.png", tile_mask)
-
-                latents = get_latents_from_seed(params.seed, size)
-                if params.lpw():
-                    logger.debug("using LPW pipeline for inpaint")
-                    rng = torch.manual_seed(params.seed)
-                    result = pipe.inpaint(
-                        prompt,
-                        generator=rng,
-                        guidance_scale=params.cfg,
-                        height=size.height,
-                        image=tile_source,
-                        latents=latents,
-                        mask_image=tile_mask,
-                        negative_prompt=negative_prompt,
-                        num_inference_steps=params.steps,
-                        width=size.width,
-                        eta=params.eta,
-                        callback=callback,
-                    )
-                else:
-                    # encode and record alternative prompts outside of LPW
-                    prompt_embeds = encode_prompt(
-                        pipe, prompt_pairs, params.batch, params.do_cfg()
-                    )
-                    pipe.unet.set_prompts(prompt_embeds)
-
-                    rng = np.random.RandomState(params.seed)
-                    result = pipe(
-                        prompt,
-                        generator=rng,
-                        guidance_scale=params.cfg,
-                        height=size.height,
-                        image=tile_source,
-                        latents=latents,
-                        mask_image=stage_mask,
-                        negative_prompt=negative_prompt,
-                        num_inference_steps=params.steps,
-                        width=size.width,
-                        eta=params.eta,
-                        callback=callback,
-                    )
-
-                return result.images[0]
-
-            outputs.append(
-                process_tile_order(
-                    stage.tile_order,
-                    source,
-                    SizeChart.auto,
-                    1,
-                    [outpaint],
-                    overlap=params.overlap,
-                )
-            )

View File

@@ -1,12 +1,12 @@
 from logging import getLogger
-from typing import Optional
+from typing import Optional, Tuple

 import numpy as np
 import torch
 from PIL import Image

 from ..diffusers.load import load_pipeline
-from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
+from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt
 from ..params import ImageParams, Size, SizeChart, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
@@ -26,8 +26,10 @@ class SourceTxt2ImgStage(BaseStage):
        params: ImageParams,
        _source: Image.Image,
        *,
+        dims: Tuple[int, int, int],
        size: Size,
        callback: Optional[ProgressCallback] = None,
+        latents: Optional[np.ndarray] = None,
        **kwargs,
    ) -> Image.Image:
        params = params.with_args(**kwargs)
@@ -47,15 +49,23 @@ class SourceTxt2ImgStage(BaseStage):
        )

        tile_size = params.tiles

-        if max(size) > tile_size:
-            latent_size = Size(tile_size, tile_size)
-            latents = get_latents_from_seed(params.seed, latent_size, params.batch)
-            pipe_width = pipe_height = tile_size
-        else:
-            latent_size = Size(size.width, size.height)
-            latents = get_latents_from_seed(params.seed, latent_size, params.batch)
-            pipe_width = size.width
-            pipe_height = size.height
+        # generate new latents or slice existing
+        if latents is None:
+            if max(size) > tile_size:
+                latent_size = Size(tile_size, tile_size)
+                pipe_width = pipe_height = tile_size
+            else:
+                latent_size = Size(size.width, size.height)
+                pipe_width = size.width
+                pipe_height = size.height
+
+            # generate new latents
+            latents = get_latents_from_seed(params.seed, latent_size, params.batch)
+        else:
+            # slice existing latents
+            latents = get_tile_latents(latents, dims, size)
+            pipe_width, pipe_height, _tile_size = dims

        pipe_type = params.get_valid_pipeline("txt2img")
        pipe = load_pipeline(

View File

@@ -6,7 +6,7 @@ import torch
 from PIL import Image
 from ..diffusers.load import load_pipeline
-from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
+from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt
 from ..image import mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
@@ -28,15 +28,17 @@ class UpscaleOutpaintStage(BaseStage):
        stage: StageParams,
        params: ImageParams,
        sources: List[Image.Image],
-        tile_mask: Image.Image,
        *,
        border: Border,
-        stage_source: Optional[Image.Image] = None,
-        stage_mask: Optional[Image.Image] = None,
+        dims: Tuple[int, int, int],
+        tile_mask: Image.Image,
        fill_color: str = "white",
        mask_filter: Callable = mask_filter_none,
        noise_source: Callable = noise_source_histogram,
+        latents: Optional[np.ndarray] = None,
        callback: Optional[ProgressCallback] = None,
+        stage_source: Optional[Image.Image] = None,
+        stage_mask: Optional[Image.Image] = None,
        **kwargs,
    ) -> List[Image.Image]:
        prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
@@ -64,18 +66,25 @@ class UpscaleOutpaintStage(BaseStage):
                outputs.append(source)
                continue

-            source_width, source_height = source.size
-            source_size = Size(source_width, source_height)
+            size = Size(*source.size)
            tile_size = params.tiles

-            if max(source_size) > tile_size:
-                latent_size = Size(tile_size, tile_size)
-                latents = get_latents_from_seed(params.seed, latent_size)
-                pipe_width = pipe_height = tile_size
+            # generate new latents or slice existing
+            if latents is None:
+                if max(size) > tile_size:
+                    latent_size = Size(tile_size, tile_size)
+                    pipe_width = pipe_height = tile_size
+                else:
+                    latent_size = Size(size.width, size.height)
+                    pipe_width = size.width
+                    pipe_height = size.height
+
+                # generate new latents
+                latents = get_latents_from_seed(params.seed, latent_size, params.batch)
            else:
-                latent_size = Size(source_size.width, source_size.height)
-                latents = get_latents_from_seed(params.seed, latent_size)
-                pipe_width = source_size.width
-                pipe_height = source_size.height
+                # slice existing latents
+                latents = get_tile_latents(latents, dims, size)
+                pipe_width, pipe_height, _tile_size = dims

            if params.lpw():
                logger.debug("using LPW pipeline for inpaint")

View File

@@ -27,7 +27,7 @@ from ..server import ServerContext
 from ..server.load import get_source_filters
 from ..utils import is_debug, run_gc, show_system_toast
 from ..worker import WorkerContext
-from .utils import parse_prompt
+from .utils import get_latents_from_seed, parse_prompt

 logger = getLogger(__name__)
@@ -81,8 +81,9 @@ def run_txt2img_pipeline(
    )

    # run and save
+    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
    progress = job.get_progress_callback()
-    images = chain(job, server, params, [], callback=progress)
+    images = chain.run(job, server, params, [], callback=progress, latents=latents)

    _pairs, loras, inversions, _rest = parse_prompt(params)
@@ -287,8 +288,9 @@ def run_inpaint_pipeline(
    )

    # run and save
+    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
    progress = job.get_progress_callback()
-    images = chain(job, server, params, [source], callback=progress)
+    images = chain(job, server, params, [source], callback=progress, latents=latents)

    _pairs, loras, inversions, _rest = parse_prompt(params)
    for image, output in zip(images, outputs):

View File

@@ -17,7 +17,7 @@ from ..diffusers.run import (
 )
 from ..diffusers.utils import replace_wildcards
 from ..output import json_params, make_output_name
-from ..params import Border, StageParams, TileOrder, UpscaleParams
+from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
 from ..transformers.run import run_txt2txt_pipeline
 from ..utils import (
     base_join,
@@ -163,8 +163,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
        return error_reply("source image is required")

    source = Image.open(BytesIO(source_file.read())).convert("RGB")
+    size = Size(source.width, source.height)

-    device, params, size = pipeline_from_request(server, "img2img")
+    device, params, _size = pipeline_from_request(server, "img2img")
    upscale = upscale_from_request()
    highres = highres_from_request()
    source_filter = get_from_list(
@@ -249,12 +250,14 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
        return error_reply("mask image is required")

    source = Image.open(BytesIO(source_file.read())).convert("RGB")
+    size = Size(source.width, source.height)
+
    mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
    mask = Image.new("RGBA", mask_top_layer.size, color=(0, 0, 0, 255))
    mask.alpha_composite(mask_top_layer)
    mask.convert(mode="L")

-    device, params, size = pipeline_from_request(server, "inpaint")
+    device, params, _size = pipeline_from_request(server, "inpaint")
    expand = border_from_request()
    upscale = upscale_from_request()
    highres = highres_from_request()