Merge branch 'ssube:main' into main
This commit is contained in:
commit
9b630f97ee
|
@ -1,6 +1,5 @@
|
|||
from .base import ChainPipeline, PipelineStage, StageParams
|
||||
from .blend_img2img import BlendImg2ImgStage
|
||||
from .blend_inpaint import BlendInpaintStage
|
||||
from .blend_linear import BlendLinearStage
|
||||
from .blend_mask import BlendMaskStage
|
||||
from .correct_codeformer import CorrectCodeformerStage
|
||||
|
@ -23,7 +22,7 @@ from .upscale_swinir import UpscaleSwinIRStage
|
|||
|
||||
CHAIN_STAGES = {
|
||||
"blend-img2img": BlendImg2ImgStage,
|
||||
"blend-inpaint": BlendInpaintStage,
|
||||
"blend-inpaint": UpscaleOutpaintStage,
|
||||
"blend-linear": BlendLinearStage,
|
||||
"blend-mask": BlendMaskStage,
|
||||
"correct-codeformer": CorrectCodeformerStage,
|
||||
|
|
|
@ -146,7 +146,9 @@ class ChainPipeline:
|
|||
)
|
||||
|
||||
def stage_tile(
|
||||
source_tile: Image.Image, tile_mask: Image.Image, _dims
|
||||
source_tile: Image.Image,
|
||||
tile_mask: Image.Image,
|
||||
dims: Tuple[int, int, int],
|
||||
) -> Image.Image:
|
||||
output_tile = stage_pipe.run(
|
||||
job,
|
||||
|
@ -156,6 +158,7 @@ class ChainPipeline:
|
|||
[source_tile],
|
||||
tile_mask=tile_mask,
|
||||
callback=callback,
|
||||
dims=dims,
|
||||
**kwargs,
|
||||
)[0]
|
||||
|
||||
|
|
|
@ -1,140 +0,0 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..utils import is_debug
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .stage import BaseStage
|
||||
from .tile import process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class BlendInpaintStage(BaseStage):
    """
    Chain stage that runs an inpaint pipeline over each source image, one tile at a time.
    """

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        sources: List[Image.Image],
        *,
        expand: Border,
        stage_source: Optional[Image.Image] = None,
        stage_mask: Optional[Image.Image] = None,
        fill_color: str = "white",
        mask_filter: Callable = mask_filter_none,
        noise_source: Callable = noise_source_histogram,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> List[Image.Image]:
        """
        Inpaint every image in `sources` and return one output image per source.

        Each source and its mask are passed through `expand_image` with the `expand`
        border, then handed to `process_tile_order`, which calls the nested `outpaint`
        function for each tile. When `stage_mask` is None, an all-black mask with the
        same size as the source is used instead.

        `params` and `expand` are both updated from `kwargs` through `with_args` before
        use. `stage_source` is accepted for signature compatibility and is not read here.
        """
        params = params.with_args(**kwargs)
        expand = expand.with_args(**kwargs)
        logger.info(
            "blending image using inpaint, %s steps: %s", params.steps, params.prompt
        )

        prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
            params
        )
        pipe_type = params.get_valid_pipeline("inpaint")
        pipe = load_pipeline(
            server,
            params,
            pipe_type,
            job.get_device(),
            inversions=inversions,
            loras=loras,
        )

        outputs = []
        for source in sources:
            # start from the caller's mask for every source, so the expanded mask from
            # one source is never expanded a second time for the next one
            source_mask = stage_mask
            if source_mask is None:
                # if no mask was provided, keep the full source image
                source_mask = Image.new("RGB", source.size, "black")

            source, source_mask, noise, _full_dims = expand_image(
                source,
                source_mask,
                expand,
                fill=fill_color,
                noise_source=noise_source,
                mask_filter=mask_filter,
            )

            if is_debug():
                save_image(server, "last-source.png", source)
                save_image(server, "last-mask.png", source_mask)
                save_image(server, "last-noise.png", noise)

            # NOTE: this closure reads source_mask from the current iteration; it is only
            # invoked by the process_tile_order call below, before the loop advances
            def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
                left, top, tile = dims
                size = Size(*tile_source.size)
                # cut the square region of the mask that corresponds to this tile
                tile_mask = source_mask.crop((left, top, left + tile, top + tile))

                if is_debug():
                    save_image(server, "tile-source.png", tile_source)
                    save_image(server, "tile-mask.png", tile_mask)

                latents = get_latents_from_seed(params.seed, size)
                if params.lpw():
                    logger.debug("using LPW pipeline for inpaint")
                    rng = torch.manual_seed(params.seed)
                    result = pipe.inpaint(
                        prompt,
                        generator=rng,
                        guidance_scale=params.cfg,
                        height=size.height,
                        image=tile_source,
                        latents=latents,
                        mask_image=tile_mask,
                        negative_prompt=negative_prompt,
                        num_inference_steps=params.steps,
                        width=size.width,
                        eta=params.eta,
                        callback=callback,
                    )
                else:
                    # encode and record alternative prompts outside of LPW
                    prompt_embeds = encode_prompt(
                        pipe, prompt_pairs, params.batch, params.do_cfg()
                    )
                    pipe.unet.set_prompts(prompt_embeds)

                    rng = np.random.RandomState(params.seed)
                    result = pipe(
                        prompt,
                        generator=rng,
                        guidance_scale=params.cfg,
                        height=size.height,
                        image=tile_source,
                        latents=latents,
                        # use the cropped tile mask, matching the LPW branch; the image
                        # passed here is the tile, not the full expanded source
                        mask_image=tile_mask,
                        negative_prompt=negative_prompt,
                        num_inference_steps=params.steps,
                        width=size.width,
                        eta=params.eta,
                        callback=callback,
                    )

                return result.images[0]

            outputs.append(
                process_tile_order(
                    stage.tile_order,
                    source,
                    SizeChart.auto,
                    1,
                    [outpaint],
                    overlap=params.overlap,
                )
            )

        # the declared return type is a list of images; hand the results back to the caller
        return outputs
|
|
@ -1,12 +1,17 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
||||
from ..diffusers.utils import (
|
||||
encode_prompt,
|
||||
get_latents_from_seed,
|
||||
get_tile_latents,
|
||||
parse_prompt,
|
||||
)
|
||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
@ -26,8 +31,10 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
params: ImageParams,
|
||||
_source: Image.Image,
|
||||
*,
|
||||
dims: Tuple[int, int, int],
|
||||
size: Size,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
params = params.with_args(**kwargs)
|
||||
|
@ -47,15 +54,13 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
)
|
||||
|
||||
tile_size = params.tiles
|
||||
if max(size) > tile_size:
|
||||
latent_size = Size(tile_size, tile_size)
|
||||
latent_size = size.min(tile_size, tile_size)
|
||||
|
||||
# generate new latents or slice existing
|
||||
if latents is None:
|
||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||
pipe_width = pipe_height = tile_size
|
||||
else:
|
||||
latent_size = Size(size.width, size.height)
|
||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||
pipe_width = size.width
|
||||
pipe_height = size.height
|
||||
latents = get_tile_latents(latents, dims, latent_size)
|
||||
|
||||
pipe_type = params.get_valid_pipeline("txt2img")
|
||||
pipe = load_pipeline(
|
||||
|
@ -72,8 +77,8 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.text2img(
|
||||
prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
|
@ -93,8 +98,8 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
latents=latents,
|
||||
|
|
|
@ -1,12 +1,17 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
|
||||
from ..diffusers.utils import (
|
||||
encode_prompt,
|
||||
get_latents_from_seed,
|
||||
get_tile_latents,
|
||||
parse_prompt,
|
||||
)
|
||||
from ..image import mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
|
@ -28,15 +33,17 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
tile_mask: Image.Image,
|
||||
*,
|
||||
border: Border,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
stage_mask: Optional[Image.Image] = None,
|
||||
dims: Tuple[int, int, int],
|
||||
tile_mask: Image.Image,
|
||||
fill_color: str = "white",
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
latents: Optional[np.ndarray] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
stage_mask: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> List[Image.Image]:
|
||||
prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
|
||||
|
@ -64,18 +71,15 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
outputs.append(source)
|
||||
continue
|
||||
|
||||
source_width, source_height = source.size
|
||||
source_size = Size(source_width, source_height)
|
||||
tile_size = params.tiles
|
||||
if max(source_size) > tile_size:
|
||||
latent_size = Size(tile_size, tile_size)
|
||||
latents = get_latents_from_seed(params.seed, latent_size)
|
||||
pipe_width = pipe_height = tile_size
|
||||
size = Size(*source.size)
|
||||
latent_size = size.min(tile_size, tile_size)
|
||||
|
||||
# generate new latents or slice existing
|
||||
if latents is None:
|
||||
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
|
||||
else:
|
||||
latent_size = Size(source_size.width, source_size.height)
|
||||
latents = get_latents_from_seed(params.seed, latent_size)
|
||||
pipe_width = source_size.width
|
||||
pipe_height = source_size.height
|
||||
latents = get_tile_latents(latents, dims, latent_size)
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
|
@ -85,8 +89,8 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
tile_mask,
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
num_inference_steps=params.steps,
|
||||
guidance_scale=params.cfg,
|
||||
generator=rng,
|
||||
|
@ -106,8 +110,8 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
source,
|
||||
tile_mask,
|
||||
negative_prompt=negative_prompt,
|
||||
height=pipe_height,
|
||||
width=pipe_width,
|
||||
height=latent_size.height,
|
||||
width=latent_size.width,
|
||||
num_inference_steps=params.steps,
|
||||
guidance_scale=params.cfg,
|
||||
generator=rng,
|
||||
|
|
|
@ -28,7 +28,7 @@ from ..server import ServerContext
|
|||
from ..server.load import get_source_filters
|
||||
from ..utils import is_debug, run_gc, show_system_toast
|
||||
from ..worker import WorkerContext
|
||||
from .utils import parse_prompt
|
||||
from .utils import get_latents_from_seed, parse_prompt
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -82,8 +82,9 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
|
||||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = job.get_progress_callback()
|
||||
images = chain(job, server, params, [], callback=progress)
|
||||
images = chain.run(job, server, params, [], callback=progress, latents=latents)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
|
||||
|
@ -361,8 +362,9 @@ def run_inpaint_pipeline(
|
|||
)
|
||||
|
||||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = job.get_progress_callback()
|
||||
images = chain(job, server, params, [source], callback=progress)
|
||||
images = chain(job, server, params, [source], callback=progress, latents=latents)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
for image, output in zip(images, outputs):
|
||||
|
|
|
@ -273,8 +273,8 @@ def get_tile_latents(
|
|||
) -> np.ndarray:
|
||||
x, y, tile = dims
|
||||
t = tile // LATENT_FACTOR
|
||||
x = x // LATENT_FACTOR
|
||||
y = y // LATENT_FACTOR
|
||||
x = max(0, x // LATENT_FACTOR)
|
||||
y = max(0, y // LATENT_FACTOR)
|
||||
xt = x + t
|
||||
yt = y + t
|
||||
|
||||
|
|
|
@ -86,6 +86,9 @@ class Size:
|
|||
border.top + self.height + border.bottom,
|
||||
)
|
||||
|
||||
def min(self, width: int, height: int):
    """
    Build a new Size whose width and height are each capped at the given limits.
    """
    # keep this size's own dimension unless the limit is strictly smaller
    capped_width = width if width < self.width else self.width
    capped_height = height if height < self.height else self.height
    return Size(capped_width, capped_height)
|
||||
|
||||
def round_to_tile(self, tile=512):
|
||||
return Size(
|
||||
ceil(self.width / tile) * tile,
|
||||
|
|
|
@ -17,7 +17,7 @@ from ..diffusers.run import (
|
|||
)
|
||||
from ..diffusers.utils import replace_wildcards
|
||||
from ..output import json_params, make_output_name
|
||||
from ..params import Border, StageParams, TileOrder, UpscaleParams
|
||||
from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
|
||||
from ..transformers.run import run_txt2txt_pipeline
|
||||
from ..utils import (
|
||||
base_join,
|
||||
|
@ -163,8 +163,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return error_reply("source image is required")
|
||||
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
size = Size(source.width, source.height)
|
||||
|
||||
device, params, size = pipeline_from_request(server, "img2img")
|
||||
device, params, _size = pipeline_from_request(server, "img2img")
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
source_filter = get_from_list(
|
||||
|
@ -249,12 +250,14 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return error_reply("mask image is required")
|
||||
|
||||
source = Image.open(BytesIO(source_file.read())).convert("RGB")
|
||||
size = Size(source.width, source.height)
|
||||
|
||||
mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||
mask = Image.new("RGBA", mask_top_layer.size, color=(0, 0, 0, 255))
|
||||
mask.alpha_composite(mask_top_layer)
|
||||
mask.convert(mode="L")
|
||||
|
||||
device, params, size = pipeline_from_request(server, "inpaint")
|
||||
device, params, _size = pipeline_from_request(server, "inpaint")
|
||||
expand = border_from_request()
|
||||
upscale = upscale_from_request()
|
||||
highres = highres_from_request()
|
||||
|
|
Loading…
Reference in New Issue