
Merge branch 'ssube:main' into main

HoopyFreud 2023-07-10 20:02:32 -04:00 committed by GitHub
commit 9b630f97ee
9 changed files with 62 additions and 183 deletions

View File

@@ -1,6 +1,5 @@
 from .base import ChainPipeline, PipelineStage, StageParams
 from .blend_img2img import BlendImg2ImgStage
-from .blend_inpaint import BlendInpaintStage
 from .blend_linear import BlendLinearStage
 from .blend_mask import BlendMaskStage
 from .correct_codeformer import CorrectCodeformerStage
@@ -23,7 +22,7 @@ from .upscale_swinir import UpscaleSwinIRStage

 CHAIN_STAGES = {
     "blend-img2img": BlendImg2ImgStage,
-    "blend-inpaint": BlendInpaintStage,
+    "blend-inpaint": UpscaleOutpaintStage,
     "blend-linear": BlendLinearStage,
     "blend-mask": BlendMaskStage,
     "correct-codeformer": CorrectCodeformerStage,

View File

@@ -146,7 +146,9 @@ class ChainPipeline:
                )

                def stage_tile(
-                    source_tile: Image.Image, tile_mask: Image.Image, _dims
+                    source_tile: Image.Image,
+                    tile_mask: Image.Image,
+                    dims: Tuple[int, int, int],
                ) -> Image.Image:
                    output_tile = stage_pipe.run(
                        job,
@@ -156,6 +158,7 @@ class ChainPipeline:
                        [source_tile],
                        tile_mask=tile_mask,
                        callback=callback,
+                        dims=dims,
                        **kwargs,
                    )[0]
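
The tile callback now forwards the tile's position instead of discarding it: dims is the (left, top, tile) triple produced by the tiler, which is what lets a stage locate its tile inside latents shared across the whole image. A small sketch of that contract, with illustrative values (the tuple layout follows the stages in this commit):

# sketch of the per-tile dims contract assumed by the stages: (left, top, tile)
from typing import Tuple

def describe_tile(dims: Tuple[int, int, int]) -> str:
    left, top, tile = dims
    return f"{tile}px tile at ({left}, {top})"

print(describe_tile((512, 0, 512)))  # "512px tile at (512, 0)"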

View File

@@ -1,140 +0,0 @@
-from logging import getLogger
-from typing import Callable, List, Optional, Tuple
-
-import numpy as np
-import torch
-from PIL import Image
-
-from ..diffusers.load import load_pipeline
-from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
-from ..image import expand_image, mask_filter_none, noise_source_histogram
-from ..output import save_image
-from ..params import Border, ImageParams, Size, SizeChart, StageParams
-from ..server import ServerContext
-from ..utils import is_debug
-from ..worker import ProgressCallback, WorkerContext
-from .stage import BaseStage
-from .tile import process_tile_order
-
-logger = getLogger(__name__)
-
-
-class BlendInpaintStage(BaseStage):
-    def run(
-        self,
-        job: WorkerContext,
-        server: ServerContext,
-        stage: StageParams,
-        params: ImageParams,
-        sources: List[Image.Image],
-        *,
-        expand: Border,
-        stage_source: Optional[Image.Image] = None,
-        stage_mask: Optional[Image.Image] = None,
-        fill_color: str = "white",
-        mask_filter: Callable = mask_filter_none,
-        noise_source: Callable = noise_source_histogram,
-        callback: Optional[ProgressCallback] = None,
-        **kwargs,
-    ) -> List[Image.Image]:
-        params = params.with_args(**kwargs)
-        expand = expand.with_args(**kwargs)
-        logger.info(
-            "blending image using inpaint, %s steps: %s", params.steps, params.prompt
-        )
-
-        prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
-            params
-        )
-
-        pipe_type = params.get_valid_pipeline("inpaint")
-        pipe = load_pipeline(
-            server,
-            params,
-            pipe_type,
-            job.get_device(),
-            inversions=inversions,
-            loras=loras,
-        )
-
-        outputs = []
-        for source in sources:
-            if stage_mask is None:
-                # if no mask was provided, keep the full source image
-                stage_mask = Image.new("RGB", source.size, "black")
-
-            source, stage_mask, noise, _full_dims = expand_image(
-                source,
-                stage_mask,
-                expand,
-                fill=fill_color,
-                noise_source=noise_source,
-                mask_filter=mask_filter,
-            )
-
-            if is_debug():
-                save_image(server, "last-source.png", source)
-                save_image(server, "last-mask.png", stage_mask)
-                save_image(server, "last-noise.png", noise)
-
-            def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
-                left, top, tile = dims
-                size = Size(*tile_source.size)
-                tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
-
-                if is_debug():
-                    save_image(server, "tile-source.png", tile_source)
-                    save_image(server, "tile-mask.png", tile_mask)
-
-                latents = get_latents_from_seed(params.seed, size)
-                if params.lpw():
-                    logger.debug("using LPW pipeline for inpaint")
-                    rng = torch.manual_seed(params.seed)
-                    result = pipe.inpaint(
-                        prompt,
-                        generator=rng,
-                        guidance_scale=params.cfg,
-                        height=size.height,
-                        image=tile_source,
-                        latents=latents,
-                        mask_image=tile_mask,
-                        negative_prompt=negative_prompt,
-                        num_inference_steps=params.steps,
-                        width=size.width,
-                        eta=params.eta,
-                        callback=callback,
-                    )
-                else:
-                    # encode and record alternative prompts outside of LPW
-                    prompt_embeds = encode_prompt(
-                        pipe, prompt_pairs, params.batch, params.do_cfg()
-                    )
-                    pipe.unet.set_prompts(prompt_embeds)
-
-                    rng = np.random.RandomState(params.seed)
-                    result = pipe(
-                        prompt,
-                        generator=rng,
-                        guidance_scale=params.cfg,
-                        height=size.height,
-                        image=tile_source,
-                        latents=latents,
-                        mask_image=stage_mask,
-                        negative_prompt=negative_prompt,
-                        num_inference_steps=params.steps,
-                        width=size.width,
-                        eta=params.eta,
-                        callback=callback,
-                    )
-
-                return result.images[0]
-
-            outputs.append(
-                process_tile_order(
-                    stage.tile_order,
-                    source,
-                    SizeChart.auto,
-                    1,
-                    [outpaint],
-                    overlap=params.overlap,
-                )
-            )

View File

@@ -1,12 +1,17 @@
 from logging import getLogger
-from typing import Optional
+from typing import Optional, Tuple

 import numpy as np
 import torch
 from PIL import Image

 from ..diffusers.load import load_pipeline
-from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
+from ..diffusers.utils import (
+    encode_prompt,
+    get_latents_from_seed,
+    get_tile_latents,
+    parse_prompt,
+)
 from ..params import ImageParams, Size, SizeChart, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
@@ -26,8 +31,10 @@ class SourceTxt2ImgStage(BaseStage):
         params: ImageParams,
         _source: Image.Image,
         *,
+        dims: Tuple[int, int, int],
         size: Size,
         callback: Optional[ProgressCallback] = None,
+        latents: Optional[np.ndarray] = None,
         **kwargs,
     ) -> Image.Image:
         params = params.with_args(**kwargs)
@@ -47,15 +54,13 @@ class SourceTxt2ImgStage(BaseStage):
         )

         tile_size = params.tiles
-        if max(size) > tile_size:
-            latent_size = Size(tile_size, tile_size)
+        latent_size = size.min(tile_size, tile_size)
+
+        # generate new latents or slice existing
+        if latents is None:
             latents = get_latents_from_seed(params.seed, latent_size, params.batch)
-            pipe_width = pipe_height = tile_size
         else:
-            latent_size = Size(size.width, size.height)
-            latents = get_latents_from_seed(params.seed, latent_size, params.batch)
-            pipe_width = size.width
-            pipe_height = size.height
+            latents = get_tile_latents(latents, dims, latent_size)

         pipe_type = params.get_valid_pipeline("txt2img")
         pipe = load_pipeline(
@@ -72,8 +77,8 @@ class SourceTxt2ImgStage(BaseStage):
             rng = torch.manual_seed(params.seed)
             result = pipe.text2img(
                 prompt,
-                height=pipe_height,
-                width=pipe_width,
+                height=latent_size.height,
+                width=latent_size.width,
                 generator=rng,
                 guidance_scale=params.cfg,
                 latents=latents,
@@ -93,8 +98,8 @@ class SourceTxt2ImgStage(BaseStage):
             rng = np.random.RandomState(params.seed)
             result = pipe(
                 prompt,
-                height=pipe_height,
-                width=pipe_width,
+                height=latent_size.height,
+                width=latent_size.width,
                 generator=rng,
                 guidance_scale=params.cfg,
                 latents=latents,
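
The stage now follows a "generate new latents or slice existing" pattern: when the chain passes no latents, the stage seeds its own noise for the clamped latent size; when full-image latents are handed down, it cuts out the window for its tile. A rough standalone sketch of that branch, reusing get_latents_from_seed and get_tile_latents from onnx_web.diffusers.utils (the wrapper function itself is illustrative):

# sketch: choose latents for one tile, mirroring the branch added above;
# the helper is illustrative, the imported functions are the repo's own
from typing import Optional, Tuple

import numpy as np

from onnx_web.diffusers.utils import get_latents_from_seed, get_tile_latents
from onnx_web.params import Size


def latents_for_tile(
    seed: int,
    batch: int,
    latent_size: Size,
    dims: Tuple[int, int, int],
    latents: Optional[np.ndarray] = None,
) -> np.ndarray:
    if latents is None:
        # no shared noise was provided: generate fresh latents for this tile
        return get_latents_from_seed(seed, latent_size, batch)

    # shared noise was provided: slice this tile's window out of it
    return get_tile_latents(latents, dims, latent_size)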

View File

@@ -1,12 +1,17 @@
 from logging import getLogger
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple

 import numpy as np
 import torch
 from PIL import Image

 from ..diffusers.load import load_pipeline
-from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
+from ..diffusers.utils import (
+    encode_prompt,
+    get_latents_from_seed,
+    get_tile_latents,
+    parse_prompt,
+)
 from ..image import mask_filter_none, noise_source_histogram
 from ..output import save_image
 from ..params import Border, ImageParams, Size, SizeChart, StageParams
@@ -28,15 +33,17 @@ class UpscaleOutpaintStage(BaseStage):
         stage: StageParams,
         params: ImageParams,
         sources: List[Image.Image],
-        tile_mask: Image.Image,
         *,
         border: Border,
-        stage_source: Optional[Image.Image] = None,
-        stage_mask: Optional[Image.Image] = None,
+        dims: Tuple[int, int, int],
+        tile_mask: Image.Image,
         fill_color: str = "white",
         mask_filter: Callable = mask_filter_none,
         noise_source: Callable = noise_source_histogram,
+        latents: Optional[np.ndarray] = None,
         callback: Optional[ProgressCallback] = None,
+        stage_source: Optional[Image.Image] = None,
+        stage_mask: Optional[Image.Image] = None,
         **kwargs,
     ) -> List[Image.Image]:
         prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt(
@@ -64,18 +71,15 @@ class UpscaleOutpaintStage(BaseStage):
                 outputs.append(source)
                 continue

-            source_width, source_height = source.size
-            source_size = Size(source_width, source_height)
             tile_size = params.tiles
-            if max(source_size) > tile_size:
-                latent_size = Size(tile_size, tile_size)
-                latents = get_latents_from_seed(params.seed, latent_size)
-                pipe_width = pipe_height = tile_size
+            size = Size(*source.size)
+            latent_size = size.min(tile_size, tile_size)
+
+            # generate new latents or slice existing
+            if latents is None:
+                latents = get_latents_from_seed(params.seed, latent_size, params.batch)
             else:
-                latent_size = Size(source_size.width, source_size.height)
-                latents = get_latents_from_seed(params.seed, latent_size)
-                pipe_width = source_size.width
-                pipe_height = source_size.height
+                latents = get_tile_latents(latents, dims, latent_size)

             if params.lpw():
                 logger.debug("using LPW pipeline for inpaint")
@@ -85,8 +89,8 @@ class UpscaleOutpaintStage(BaseStage):
                     tile_mask,
                     prompt,
                     negative_prompt=negative_prompt,
-                    height=pipe_height,
-                    width=pipe_width,
+                    height=latent_size.height,
+                    width=latent_size.width,
                     num_inference_steps=params.steps,
                     guidance_scale=params.cfg,
                     generator=rng,
@@ -106,8 +110,8 @@ class UpscaleOutpaintStage(BaseStage):
                     source,
                     tile_mask,
                     negative_prompt=negative_prompt,
-                    height=pipe_height,
-                    width=pipe_width,
+                    height=latent_size.height,
+                    width=latent_size.width,
                     num_inference_steps=params.steps,
                     guidance_scale=params.cfg,
                     generator=rng,

View File

@@ -28,7 +28,7 @@ from ..server import ServerContext
 from ..server.load import get_source_filters
 from ..utils import is_debug, run_gc, show_system_toast
 from ..worker import WorkerContext
-from .utils import parse_prompt
+from .utils import get_latents_from_seed, parse_prompt

 logger = getLogger(__name__)
@@ -82,8 +82,9 @@ def run_txt2img_pipeline(
     )

     # run and save
+    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     progress = job.get_progress_callback()
-    images = chain(job, server, params, [], callback=progress)
+    images = chain.run(job, server, params, [], callback=progress, latents=latents)

     _pairs, loras, inversions, _rest = parse_prompt(params)
@@ -361,8 +362,9 @@ def run_inpaint_pipeline(
     )

     # run and save
+    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     progress = job.get_progress_callback()
-    images = chain(job, server, params, [source], callback=progress)
+    images = chain(job, server, params, [source], callback=progress, latents=latents)

     _pairs, loras, inversions, _rest = parse_prompt(params)
     for image, output in zip(images, outputs):
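
Both pipeline runners now draw one set of latents for the full output size and pass it into the chain, presumably so that every tiled stage samples from the same noise instead of reseeding per tile. A toy numpy illustration of why that matters for seams (shapes and values are made up, not code from the repo):

# toy illustration: draw noise once for the whole canvas and give each tile a
# view into it, so overlapping tiles agree exactly
import numpy as np

rng = np.random.RandomState(5)
full = rng.rand(1, 4, 128, 128)   # stand-in for full-image latents

tile_a = full[:, :, :, 0:64]      # left tile
tile_b = full[:, :, :, 32:96]     # overlapping tile to its right

# the 32-column overlap is bit-for-bit identical, so blended seams line up
assert np.array_equal(tile_a[:, :, :, 32:], tile_b[:, :, :, :32])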

View File

@@ -273,8 +273,8 @@ def get_tile_latents(
 ) -> np.ndarray:
     x, y, tile = dims
     t = tile // LATENT_FACTOR
-    x = x // LATENT_FACTOR
-    y = y // LATENT_FACTOR
+    x = max(0, x // LATENT_FACTOR)
+    y = max(0, y // LATENT_FACTOR)
     xt = x + t
     yt = y + t
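
get_tile_latents works in latent space, where spatial coordinates are pixel coordinates divided by LATENT_FACTOR; the new max(0, ...) clamp keeps the window inside the array when a tile starts at a negative offset, as outpaint borders can. An illustrative re-implementation of the slicing, assuming NCHW latents and LATENT_FACTOR = 8 (this helper is a sketch, not the repository function):

# illustrative tile slicing over latents shaped (batch, channels, h/8, w/8);
# LATENT_FACTOR = 8 and the NCHW slice order are assumptions of this sketch
import numpy as np

LATENT_FACTOR = 8


def slice_tile(latents: np.ndarray, dims: tuple) -> np.ndarray:
    x, y, tile = dims
    t = tile // LATENT_FACTOR
    # clamp negative offsets (outpaint borders) back to the corner
    x = max(0, x // LATENT_FACTOR)
    y = max(0, y // LATENT_FACTOR)
    return latents[:, :, y : y + t, x : x + t]


full = np.random.RandomState(42).rand(1, 4, 128, 128)  # 1024x1024 image in latent space
tile = slice_tile(full, (512, 0, 512))                 # shape (1, 4, 64, 64)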

View File

@@ -86,6 +86,9 @@ class Size:
             border.top + self.height + border.bottom,
         )

+    def min(self, width: int, height: int):
+        return Size(min(self.width, width), min(self.height, height))
+
     def round_to_tile(self, tile=512):
         return Size(
             ceil(self.width / tile) * tile,
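
Size.min clamps a size to the tile budget per axis, which is how both stages above derive latent_size from params.tiles. A quick usage sketch (values are illustrative):

# sketch: clamping an image size to the tile size with the new Size.min helper
from onnx_web.params import Size

latent_size = Size(1536, 640).min(512, 512)  # -> Size(512, 512), both axes capped
small = Size(384, 256).min(512, 512)         # -> Size(384, 256), already within the tile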

View File

@@ -17,7 +17,7 @@ from ..diffusers.run import (
 )
 from ..diffusers.utils import replace_wildcards
 from ..output import json_params, make_output_name
-from ..params import Border, StageParams, TileOrder, UpscaleParams
+from ..params import Border, Size, StageParams, TileOrder, UpscaleParams
 from ..transformers.run import run_txt2txt_pipeline
 from ..utils import (
     base_join,
@@ -163,8 +163,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
         return error_reply("source image is required")

     source = Image.open(BytesIO(source_file.read())).convert("RGB")
+    size = Size(source.width, source.height)

-    device, params, size = pipeline_from_request(server, "img2img")
+    device, params, _size = pipeline_from_request(server, "img2img")
     upscale = upscale_from_request()
     highres = highres_from_request()
     source_filter = get_from_list(
@@ -249,12 +250,14 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
         return error_reply("mask image is required")

     source = Image.open(BytesIO(source_file.read())).convert("RGB")
+    size = Size(source.width, source.height)
+
     mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA")
     mask = Image.new("RGBA", mask_top_layer.size, color=(0, 0, 0, 255))
     mask.alpha_composite(mask_top_layer)
     mask.convert(mode="L")

-    device, params, size = pipeline_from_request(server, "inpaint")
+    device, params, _size = pipeline_from_request(server, "inpaint")
     expand = border_from_request()
     upscale = upscale_from_request()
     highres = highres_from_request()