fix(api): generate latents before tiling
This commit is contained in:
parent
c15f750821
commit
60aa8ab4c0
|
@ -1,6 +1,5 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageParams
|
from .base import ChainPipeline, PipelineStage, StageParams
|
||||||
from .blend_img2img import BlendImg2ImgStage
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
from .blend_inpaint import BlendInpaintStage
|
|
||||||
from .blend_linear import BlendLinearStage
|
from .blend_linear import BlendLinearStage
|
||||||
from .blend_mask import BlendMaskStage
|
from .blend_mask import BlendMaskStage
|
||||||
from .correct_codeformer import CorrectCodeformerStage
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
|
@ -23,7 +22,7 @@ from .upscale_swinir import UpscaleSwinIRStage
|
||||||
|
|
||||||
CHAIN_STAGES = {
|
CHAIN_STAGES = {
|
||||||
"blend-img2img": BlendImg2ImgStage,
|
"blend-img2img": BlendImg2ImgStage,
|
||||||
"blend-inpaint": BlendInpaintStage,
|
"blend-inpaint": UpscaleOutpaintStage,
|
||||||
"blend-linear": BlendLinearStage,
|
"blend-linear": BlendLinearStage,
|
||||||
"blend-mask": BlendMaskStage,
|
"blend-mask": BlendMaskStage,
|
||||||
"correct-codeformer": CorrectCodeformerStage,
|
"correct-codeformer": CorrectCodeformerStage,
|
||||||
|
|
|
@ -146,7 +146,7 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
def stage_tile(
|
def stage_tile(
|
||||||
source_tile: Image.Image, tile_mask: Image.Image, _dims
|
source_tile: Image.Image, tile_mask: Image.Image, dims: Tuple[int, int, int]
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
output_tile = stage_pipe.run(
|
output_tile = stage_pipe.run(
|
||||||
job,
|
job,
|
||||||
|
@ -156,6 +156,7 @@ class ChainPipeline:
|
||||||
[source_tile],
|
[source_tile],
|
||||||
tile_mask=tile_mask,
|
tile_mask=tile_mask,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
|
dims=dims,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
|
|
|
@ -1,140 +0,0 @@
|
||||||
from logging import getLogger
|
|
||||||
from typing import Callable, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..diffusers.load import load_pipeline
|
|
||||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
|
||||||
from ..output import save_image
|
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
|
||||||
from ..server import ServerContext
|
|
||||||
from ..utils import is_debug
|
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
|
||||||
from .stage import BaseStage
|
|
||||||
from .tile import process_tile_order
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BlendInpaintStage(BaseStage):
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
job: WorkerContext,
|
|
||||||
server: ServerContext,
|
|
||||||
stage: StageParams,
|
|
||||||
params: ImageParams,
|
|
||||||
sources: List[Image.Image],
|
|
||||||
*,
|
|
||||||
expand: Border,
|
|
||||||
stage_source: Optional[Image.Image] = None,
|
|
||||||
stage_mask: Optional[Image.Image] = None,
|
|
||||||
fill_color: str = "white",
|
|
||||||
mask_filter: Callable = mask_filter_none,
|
|
||||||
noise_source: Callable = noise_source_histogram,
|
|
||||||
callback: Optional[ProgressCallback] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> List[Image.Image]:
|
|
||||||
params = params.with_args(**kwargs)
|
|
||||||
expand = expand.with_args(**kwargs)
|
|
||||||
logger.info(
|
|
||||||
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
|
||||||
params
|
|
||||||
)
|
|
||||||
pipe_type = params.get_valid_pipeline("inpaint")
|
|
||||||
pipe = load_pipeline(
|
|
||||||
server,
|
|
||||||
params,
|
|
||||||
pipe_type,
|
|
||||||
job.get_device(),
|
|
||||||
inversions=inversions,
|
|
||||||
loras=loras,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
for source in sources:
|
|
||||||
if stage_mask is None:
|
|
||||||
# if no mask was provided, keep the full source image
|
|
||||||
stage_mask = Image.new("RGB", source.size, "black")
|
|
||||||
|
|
||||||
source, stage_mask, noise, _full_dims = expand_image(
|
|
||||||
source,
|
|
||||||
stage_mask,
|
|
||||||
expand,
|
|
||||||
fill=fill_color,
|
|
||||||
noise_source=noise_source,
|
|
||||||
mask_filter=mask_filter,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-source.png", source)
|
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
|
||||||
save_image(server, "last-noise.png", noise)
|
|
||||||
|
|
||||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
|
||||||
left, top, tile = dims
|
|
||||||
size = Size(*tile_source.size)
|
|
||||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "tile-source.png", tile_source)
|
|
||||||
save_image(server, "tile-mask.png", tile_mask)
|
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
|
||||||
if params.lpw():
|
|
||||||
logger.debug("using LPW pipeline for inpaint")
|
|
||||||
rng = torch.manual_seed(params.seed)
|
|
||||||
result = pipe.inpaint(
|
|
||||||
prompt,
|
|
||||||
generator=rng,
|
|
||||||
guidance_scale=params.cfg,
|
|
||||||
height=size.height,
|
|
||||||
image=tile_source,
|
|
||||||
latents=latents,
|
|
||||||
mask_image=tile_mask,
|
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
num_inference_steps=params.steps,
|
|
||||||
width=size.width,
|
|
||||||
eta=params.eta,
|
|
||||||
callback=callback,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# encode and record alternative prompts outside of LPW
|
|
||||||
prompt_embeds = encode_prompt(
|
|
||||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
|
||||||
)
|
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
|
||||||
result = pipe(
|
|
||||||
prompt,
|
|
||||||
generator=rng,
|
|
||||||
guidance_scale=params.cfg,
|
|
||||||
height=size.height,
|
|
||||||
image=tile_source,
|
|
||||||
latents=latents,
|
|
||||||
mask_image=stage_mask,
|
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
num_inference_steps=params.steps,
|
|
||||||
width=size.width,
|
|
||||||
eta=params.eta,
|
|
||||||
callback=callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result.images[0]
|
|
||||||
|
|
||||||
outputs.append(
|
|
||||||
process_tile_order(
|
|
||||||
stage.tile_order,
|
|
||||||
source,
|
|
||||||
SizeChart.auto,
|
|
||||||
1,
|
|
||||||
[outpaint],
|
|
||||||
overlap=params.overlap,
|
|
||||||
)
|
|
||||||
)
|
|
|
@ -1,12 +1,12 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..diffusers.load import load_pipeline
|
from ..diffusers.load import load_pipeline
|
||||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
@ -26,8 +26,10 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
_source: Image.Image,
|
_source: Image.Image,
|
||||||
*,
|
*,
|
||||||
|
dims: Tuple[int, int, int],
|
||||||
size: Size,
|
size: Size,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
|
latents: Optional[np.ndarray] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
|
@ -47,16 +49,24 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
)
|
)
|
||||||
|
|
||||||
tile_size = params.tiles
|
tile_size = params.tiles
|
||||||
|
|
||||||
|
# generate new latents or slice existing
|
||||||
|
if latents is None:
|
||||||
if max(size) > tile_size:
|
if max(size) > tile_size:
|
||||||
latent_size = Size(tile_size, tile_size)
|
latent_size = Size(tile_size, tile_size)
|
||||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
|
||||||
pipe_width = pipe_height = tile_size
|
pipe_width = pipe_height = tile_size
|
||||||
else:
|
else:
|
||||||
latent_size = Size(size.width, size.height)
|
latent_size = Size(size.width, size.height)
|
||||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
|
||||||
pipe_width = size.width
|
pipe_width = size.width
|
||||||
pipe_height = size.height
|
pipe_height = size.height
|
||||||
|
|
||||||
|
# generate new latents
|
||||||
|
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||||
|
else:
|
||||||
|
# slice existing latents
|
||||||
|
latents = get_tile_latents(latents, dims, size)
|
||||||
|
pipe_width, pipe_height, _tile_size = dims
|
||||||
|
|
||||||
pipe_type = params.get_valid_pipeline("txt2img")
|
pipe_type = params.get_valid_pipeline("txt2img")
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
server,
|
server,
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..diffusers.load import load_pipeline
|
from ..diffusers.load import load_pipeline
|
||||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt
|
||||||
from ..image import mask_filter_none, noise_source_histogram
|
from ..image import mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
|
@ -28,15 +28,17 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: List[Image.Image],
|
sources: List[Image.Image],
|
||||||
tile_mask: Image.Image,
|
|
||||||
*,
|
*,
|
||||||
border: Border,
|
border: Border,
|
||||||
stage_source: Optional[Image.Image] = None,
|
dims: Tuple[int, int, int],
|
||||||
stage_mask: Optional[Image.Image] = None,
|
tile_mask: Image.Image,
|
||||||
fill_color: str = "white",
|
fill_color: str = "white",
|
||||||
mask_filter: Callable = mask_filter_none,
|
mask_filter: Callable = mask_filter_none,
|
||||||
noise_source: Callable = noise_source_histogram,
|
noise_source: Callable = noise_source_histogram,
|
||||||
|
latents: Optional[np.ndarray] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
stage_mask: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Image.Image]:
|
) -> List[Image.Image]:
|
||||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
||||||
|
@ -64,18 +66,25 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
outputs.append(source)
|
outputs.append(source)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
source_width, source_height = source.size
|
size = Size(*source.size)
|
||||||
source_size = Size(source_width, source_height)
|
|
||||||
tile_size = params.tiles
|
tile_size = params.tiles
|
||||||
if max(source_size) > tile_size:
|
|
||||||
|
# generate new latents or slice existing
|
||||||
|
if latents is None:
|
||||||
|
if max(size) > tile_size:
|
||||||
latent_size = Size(tile_size, tile_size)
|
latent_size = Size(tile_size, tile_size)
|
||||||
latents = get_latents_from_seed(params.seed, latent_size)
|
|
||||||
pipe_width = pipe_height = tile_size
|
pipe_width = pipe_height = tile_size
|
||||||
else:
|
else:
|
||||||
latent_size = Size(source_size.width, source_size.height)
|
latent_size = Size(size.width, size.height)
|
||||||
latents = get_latents_from_seed(params.seed, latent_size)
|
pipe_width = size.width
|
||||||
pipe_width = source_size.width
|
pipe_height = size.height
|
||||||
pipe_height = source_size.height
|
|
||||||
|
# generate new latents
|
||||||
|
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||||
|
else:
|
||||||
|
# slice existing latents
|
||||||
|
latents = get_tile_latents(latents, dims, size)
|
||||||
|
pipe_width, pipe_height, _tile_size = dims
|
||||||
|
|
||||||
if params.lpw():
|
if params.lpw():
|
||||||
logger.debug("using LPW pipeline for inpaint")
|
logger.debug("using LPW pipeline for inpaint")
|
||||||
|
|
|
@ -27,7 +27,7 @@ from ..server import ServerContext
|
||||||
from ..server.load import get_source_filters
|
from ..server.load import get_source_filters
|
||||||
from ..utils import is_debug, run_gc, show_system_toast
|
from ..utils import is_debug, run_gc, show_system_toast
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .utils import parse_prompt
|
from .utils import get_latents_from_seed, parse_prompt
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -81,8 +81,9 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
images = chain(job, server, params, [], callback=progress)
|
images = chain.run(job, server, params, [], callback=progress, latents=latents)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
|
|
||||||
|
@ -287,8 +288,9 @@ def run_inpaint_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
images = chain(job, server, params, [source], callback=progress)
|
images = chain(job, server, params, [source], callback=progress, latents=latents)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
|
|
@ -17,7 +17,7 @@ from ..diffusers.run import (
|
||||||
)
|
)
|
||||||
from ..diffusers.utils import replace_wildcards
|
from ..diffusers.utils import replace_wildcards
|
||||||
from ..output import json_params, make_output_name
|
from ..output import json_params, make_output_name
|
||||||
from ..params import Border, StageParams, TileOrder, UpscaleParams
|
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
|
||||||
from ..transformers.run import run_txt2txt_pipeline
|
from ..transformers.run import run_txt2txt_pipeline
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
base_join,
|
base_join,
|
||||||
|
@ -163,8 +163,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return error_reply("source image is required")
|
return error_reply("source image is required")
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
size = Size(source.width, source.height)
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server, "img2img")
|
device, params, _size = pipeline_from_request(server, "img2img")
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
highres = highres_from_request()
|
highres = highres_from_request()
|
||||||
source_filter = get_from_list(
|
source_filter = get_from_list(
|
||||||
|
@ -249,12 +250,14 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return error_reply("mask image is required")
|
return error_reply("mask image is required")
|
||||||
|
|
||||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||||
|
size = Size(source.width, source.height)
|
||||||
|
|
||||||
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||||
mask = Image.new("RGBA", mask_top_layer.size, color=(0, 0, 0, 255))
|
mask = Image.new("RGBA", mask_top_layer.size, color=(0, 0, 0, 255))
|
||||||
mask.alpha_composite(mask_top_layer)
|
mask.alpha_composite(mask_top_layer)
|
||||||
mask.convert(mode="L")
|
mask.convert(mode="L")
|
||||||
|
|
||||||
device, params, size = pipeline_from_request(server, "inpaint")
|
device, params, _size = pipeline_from_request(server, "inpaint")
|
||||||
expand = border_from_request()
|
expand = border_from_request()
|
||||||
upscale = upscale_from_request()
|
upscale = upscale_from_request()
|
||||||
highres = highres_from_request()
|
highres = highres_from_request()
|
||||||
|
|
Loading…
Reference in New Issue