1
0
Fork 0

fix(api): use stage source when available

This commit is contained in:
Sean Sube 2023-02-18 22:11:44 -06:00
parent 25c56c7d5c
commit ac1f7449bb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
15 changed files with 55 additions and 30 deletions

View File

@ -20,9 +20,11 @@ def blend_img2img(
source: Image.Image, source: Image.Image,
*, *,
callback: ProgressCallback = None, callback: ProgressCallback = None,
stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
source = stage_source or source
logger.info( logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt "blending image using img2img, %s steps: %s", params.steps, params.prompt
) )

View File

@ -25,7 +25,8 @@ def blend_inpaint(
source: Image.Image, source: Image.Image,
*, *,
expand: Border, expand: Border,
mask: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
fill_color: str = "white", fill_color: str = "white",
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
@ -34,17 +35,18 @@ def blend_inpaint(
) -> Image.Image: ) -> Image.Image:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
expand = expand.with_args(**kwargs) expand = expand.with_args(**kwargs)
source = source or stage_source
logger.info( logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt "blending image using inpaint, %s steps: %s", params.steps, params.prompt
) )
if mask is None: if stage_mask is None:
# if no mask was provided, keep the full source image # if no mask was provided, keep the full source image
mask = Image.new("RGB", source.size, "black") stage_mask = Image.new("RGB", source.size, "black")
source, mask, noise, _full_dims = expand_image( source, stage_mask, noise, _full_dims = expand_image(
source, source,
mask, stage_mask,
expand, expand,
fill=fill_color, fill=fill_color,
noise_source=noise_source, noise_source=noise_source,
@ -53,13 +55,13 @@ def blend_inpaint(
if is_debug(): if is_debug():
save_image(server, "last-source.png", source) save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise) save_image(server, "last-noise.png", noise)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims left, top, tile = dims
size = Size(*tile_source.size) size = Size(*tile_source.size)
tile_mask = mask.crop((left, top, left + tile, top + tile)) tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
if is_debug(): if is_debug():
save_image(server, "tile-source.png", tile_source) save_image(server, "tile-source.png", tile_source)
@ -100,7 +102,7 @@ def blend_inpaint(
height=size.height, height=size.height,
image=tile_source, image=tile_source,
latents=latents, latents=latents,
mask_image=mask, mask_image=stage_mask,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,

View File

@ -20,18 +20,18 @@ def blend_mask(
_params: ImageParams, _params: ImageParams,
*, *,
sources: Optional[List[Image.Image]] = None, sources: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None, _callback: ProgressCallback = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info("blending image using mask") logger.info("blending image using mask")
mult_mask = Image.new("RGBA", mask.size, color="black") mult_mask = Image.new("RGBA", stage_mask.size, color="black")
mult_mask.alpha_composite(mask) mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L") mult_mask = mult_mask.convert("L")
if is_debug(): if is_debug():
save_image(server, "last-mask.png", mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-mult-mask.png", mult_mask)
resized = [ resized = [

View File

@ -24,8 +24,10 @@ def correct_codeformer(
# must be within the load function for patch to take effect # must be within the load function for patch to take effect
from codeformer import CodeFormer from codeformer import CodeFormer
source = stage_source or source
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
device = job.get_device() device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(stage_source or source) return pipe(source)

View File

@ -53,9 +53,11 @@ def correct_gfpgan(
source: Image.Image, source: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.correction_model is None: if upscale.correction_model is None:
logger.warn("no face model given, skipping") logger.warn("no face model given, skipping")

View File

@ -17,8 +17,11 @@ def persist_disk(
source: Image.Image, source: Image.Image,
*, *,
output: str, output: str,
stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
source = stage_source or source
dest = save_image(server, output, source) dest = save_image(server, output, source)
logger.info("saved image to %s", dest) logger.info("saved image to %s", dest)
return source return source

View File

@ -21,8 +21,11 @@ def persist_s3(
bucket: str, bucket: str,
endpoint_url: str = None, endpoint_url: str = None,
profile_name: str = None, profile_name: str = None,
stage_source: Image.Image = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
source = stage_source or source
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)

View File

@ -17,8 +17,11 @@ def reduce_crop(
*, *,
origin: Size, origin: Size,
size: Size, size: Size,
stage_source: Image.Image = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
source = stage_source or source
image = source.crop((origin.width, origin.height, size.width, size.height)) image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image return image

View File

@ -16,8 +16,10 @@ def reduce_thumbnail(
source: Image.Image, source: Image.Image,
*, *,
size: Size, size: Size,
stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
source = stage_source or source
image = source.copy() image = source.copy()
# TODO: should use a call to valid_image # TODO: should use a call to valid_image

View File

@ -18,8 +18,10 @@ def source_noise(
*, *,
size: Size, size: Size,
noise_source: Callable, noise_source: Callable,
stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
source = stage_source or source
logger.info("generating image from noise source") logger.info("generating image from noise source")
if source is not None: if source is not None:

View File

@ -17,7 +17,7 @@ def source_txt2img(
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, _source: Image.Image,
*, *,
size: Size, size: Size,
callback: ProgressCallback = None, callback: ProgressCallback = None,
@ -29,7 +29,7 @@ def source_txt2img(
"generating image using txt2img, %s steps: %s", params.steps, params.prompt "generating image using txt2img, %s steps: %s", params.steps, params.prompt
) )
if source is not None: if "stage_source" in kwargs:
logger.warn( logger.warn(
"a source image was passed to a txt2img stage, but will be discarded" "a source image was passed to a txt2img stage, but will be discarded"
) )

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Callable, Tuple from typing import Callable, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -25,47 +25,47 @@ def upscale_outpaint(
source: Image.Image, source: Image.Image,
*, *,
border: Border, border: Border,
prompt: str = None, stage_source: Optional[Image.Image] = None,
mask: Image.Image = None, stage_mask: Optional[Image.Image] = None,
fill_color: str = "white", fill_color: str = "white",
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
callback: ProgressCallback = None, callback: ProgressCallback = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
prompt = prompt or params.prompt source = stage_source or source
logger.info("upscaling image by expanding borders: %s", border) logger.info("upscaling image by expanding borders: %s", border)
margin_x = float(max(border.left, border.right)) margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom)) margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source.width, margin_y / source.height) overlap = min(margin_x / source.width, margin_y / source.height)
if mask is None: if stage_mask is None:
# if no mask was provided, keep the full source image # if no mask was provided, keep the full source image
mask = Image.new("RGB", source.size, "black") stage_mask = Image.new("RGB", source.size, "black")
source, mask, noise, full_dims = expand_image( source, stage_mask, noise, full_dims = expand_image(
source, source,
mask, stage_mask,
border, border,
fill=fill_color, fill=fill_color,
noise_source=noise_source, noise_source=noise_source,
mask_filter=mask_filter, mask_filter=mask_filter,
) )
draw_mask = ImageDraw.Draw(mask) draw_mask = ImageDraw.Draw(stage_mask)
full_size = Size(*full_dims) full_size = Size(*full_dims)
full_latents = get_latents_from_seed(params.seed, full_size) full_latents = get_latents_from_seed(params.seed, full_size)
if is_debug(): if is_debug():
save_image(server, "last-source.png", source) save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise) save_image(server, "last-noise.png", noise)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims left, top, tile = dims
size = Size(*tile_source.size) size = Size(*tile_source.size)
tile_mask = mask.crop((left, top, left + tile, top + tile)) tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
if is_debug(): if is_debug():
save_image(server, "tile-source.png", tile_source) save_image(server, "tile-source.png", tile_source)
@ -86,7 +86,7 @@ def upscale_outpaint(
result = pipe.inpaint( result = pipe.inpaint(
tile_source, tile_source,
tile_mask, tile_mask,
prompt, params.prompt,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
height=size.height, height=size.height,
@ -99,7 +99,7 @@ def upscale_outpaint(
else: else:
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
prompt, params.prompt,
tile_source, tile_source,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,

View File

@ -103,8 +103,10 @@ def upscale_resrgan(
source: Image.Image, source: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Image.Image = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
source = stage_source or source
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source) output = np.array(source)

View File

@ -69,11 +69,13 @@ def upscale_stable_diffusion(
source: Image.Image, source: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Image.Image = None,
callback: ProgressCallback = None, callback: ProgressCallback = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
source = stage_source or source
logger.info( logger.info(
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt "upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
) )

View File

@ -188,7 +188,7 @@ def run_inpaint_pipeline(
params, params,
source, source,
border=border, border=border,
mask=mask, stage_mask=mask,
fill_color=fill_color, fill_color=fill_color,
mask_filter=mask_filter, mask_filter=mask_filter,
noise_source=noise_source, noise_source=noise_source,
@ -255,7 +255,7 @@ def run_blend_pipeline(
stage, stage,
params, params,
sources=sources, sources=sources,
mask=mask, stage_mask=mask,
callback=progress, callback=progress,
) )
image = image.convert("RGB") image = image.convert("RGB")