1
0
Fork 0

lint(api): use consistent name for source/mask images to avoid conflict with kwargs

This commit is contained in:
Sean Sube 2023-02-18 16:35:57 -06:00
parent 7b8ced0f68
commit b4f7973c1e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
14 changed files with 101 additions and 101 deletions

View File

@ -85,7 +85,7 @@ def blend_inpaint(
height=size.height, height=size.height,
image=image, image=image,
latents=latents, latents=latents,
mask_image=mask, mask=mask,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
@ -100,7 +100,7 @@ def blend_inpaint(
height=size.height, height=size.height,
image=image, image=image,
latents=latents, latents=latents,
mask_image=mask, mask=mask,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,

View File

@ -50,7 +50,7 @@ def correct_gfpgan(
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs, **kwargs,
@ -59,13 +59,13 @@ def correct_gfpgan(
if upscale.correction_model is None: if upscale.correction_model is None:
logger.warn("no face model given, skipping") logger.warn("no face model given, skipping")
return source_image return source
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device() device = job.get_device()
gfpgan = load_gfpgan(server, stage, upscale, device) gfpgan = load_gfpgan(server, stage, upscale, device)
output = np.array(source_image) output = np.array(source)
_, _, output = gfpgan.enhance( _, _, output = gfpgan.enhance(
output, output,
has_aligned=False, has_aligned=False,

View File

@ -15,11 +15,11 @@ def persist_disk(
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
output: str, output: str,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
dest = save_image(server, output, source_image) dest = save_image(server, output, source)
logger.info("saved image to %s", dest) logger.info("saved image to %s", dest)
return source_image return source

View File

@ -16,7 +16,7 @@ def persist_s3(
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
output: str, output: str,
bucket: str, bucket: str,
@ -28,7 +28,7 @@ def persist_s3(
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
data = BytesIO() data = BytesIO()
source_image.save(data, format=server.image_format) source.save(data, format=server.image_format)
data.seek(0) data.seek(0)
try: try:
@ -37,4 +37,4 @@ def persist_s3(
except Exception as err: except Exception as err:
logger.error("error saving image to S3: %s", err) logger.error("error saving image to S3: %s", err)
return source_image return source

View File

@ -14,12 +14,12 @@ def reduce_crop(
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
origin: Size, origin: Size,
size: Size, size: Size,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
image = source_image.crop((origin.width, origin.height, size.width, size.height)) image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image return image

View File

@ -14,12 +14,12 @@ def reduce_thumbnail(
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
size: Size, size: Size,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
image = source_image.copy() image = source.copy()
# TODO: should use a call to valid_image # TODO: should use a call to valid_image
image = image.thumbnail((size.width, size.height)) image = image.thumbnail((size.width, size.height))

View File

@ -15,7 +15,7 @@ def source_noise(
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
size: Size, size: Size,
noise_source: Callable, noise_source: Callable,
@ -23,10 +23,10 @@ def source_noise(
) -> Image.Image: ) -> Image.Image:
logger.info("generating image from noise source") logger.info("generating image from noise source")
if source_image is not None: if source is not None:
logger.warn("a source image was passed to a noise stage, but will be discarded") logger.warn("a source image was passed to a noise stage, but will be discarded")
output = noise_source(source_image, (size.width, size.height), (0, 0)) output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output return output

View File

@ -18,7 +18,7 @@ def source_txt2img(
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
size: Size, size: Size,
callback: ProgressCallback = None, callback: ProgressCallback = None,
@ -28,7 +28,7 @@ def source_txt2img(
size = size.with_args(**kwargs) size = size.with_args(**kwargs)
logger.info("generating image using txt2img, %s steps: %s", params.steps, params.prompt) logger.info("generating image using txt2img, %s steps: %s", params.steps, params.prompt)
if source_image is not None: if source is not None:
logger.warn( logger.warn(
"a source image was passed to a txt2img stage, but will be discarded" "a source image was passed to a txt2img stage, but will be discarded"
) )

View File

@ -22,11 +22,11 @@ def upscale_outpaint(
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
border: Border, border: Border,
prompt: str = None, prompt: str = None,
mask_image: Image.Image = None, mask: Image.Image = None,
fill_color: str = "white", fill_color: str = "white",
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
@ -38,34 +38,34 @@ def upscale_outpaint(
margin_x = float(max(border.left, border.right)) margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom)) margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source_image.width, margin_y / source_image.height) overlap = min(margin_x / source.width, margin_y / source.height)
if mask_image is None: if mask is None:
# if no mask was provided, keep the full source image # if no mask was provided, keep the full source image
mask_image = Image.new("RGB", source_image.size, "black") mask = Image.new("RGB", source.size, "black")
source_image, mask_image, noise_image, full_dims = expand_image( source, mask, noise, full_dims = expand_image(
source_image, source,
mask_image, mask,
border, border,
fill=fill_color, fill=fill_color,
noise_source=noise_source, noise_source=noise_source,
mask_filter=mask_filter, mask_filter=mask_filter,
) )
draw_mask = ImageDraw.Draw(mask_image) draw_mask = ImageDraw.Draw(mask)
full_size = Size(*full_dims) full_size = Size(*full_dims)
full_latents = get_latents_from_seed(params.seed, full_size) full_latents = get_latents_from_seed(params.seed, full_size)
if is_debug(): if is_debug():
save_image(server, "last-source.png", source_image) save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", mask_image) save_image(server, "last-mask.png", mask)
save_image(server, "last-noise.png", noise_image) save_image(server, "last-noise.png", noise)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]): def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims left, top, tile = dims
size = Size(*image.size) size = Size(*image.size)
mask = mask_image.crop((left, top, left + tile, top + tile)) mask = mask.crop((left, top, left + tile, top + tile))
if is_debug(): if is_debug():
save_image(server, "tile-source.png", image) save_image(server, "tile-source.png", image)
@ -105,7 +105,7 @@ def upscale_outpaint(
guidance_scale=params.cfg, guidance_scale=params.cfg,
height=size.height, height=size.height,
latents=latents, latents=latents,
mask_image=mask, mask=mask,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
width=size.width, width=size.width,
@ -118,7 +118,7 @@ def upscale_outpaint(
if overlap == 0: if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling") logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom: elif border.left == border.right and border.top == border.bottom:
logger.debug( logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap", "outpainting with an even border, using spiral tiling with %s overlap",
@ -126,7 +126,7 @@ def upscale_outpaint(
) )
output = process_tile_order( output = process_tile_order(
stage.tile_order, stage.tile_order,
source_image, source,
SizeChart.auto, SizeChart.auto,
1, 1,
[outpaint], [outpaint],
@ -134,7 +134,7 @@ def upscale_outpaint(
) )
else: else:
logger.debug("outpainting with an uneven border, using grid tiling") logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output return output

View File

@ -100,14 +100,14 @@ def upscale_resrgan(
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source_image: Image.Image, source: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source_image) output = np.array(source)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size) upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale) output, _ = upsampler.enhance(output, outscale=upscale.outscale)

View File

@ -657,7 +657,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None, image: Union[np.ndarray, PIL.Image.Image] = None,
mask_image: Union[np.ndarray, PIL.Image.Image] = None, mask: Union[np.ndarray, PIL.Image.Image] = None,
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
num_inference_steps: int = 50, num_inference_steps: int = 50,
@ -687,9 +687,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
image (`np.ndarray` or `PIL.Image.Image`): image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the `Image`, or tensor representing an image batch, that will be used as the starting point for the
process. process.
mask_image (`np.ndarray` or `PIL.Image.Image`): mask (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a replaced by noise and therefore repainted, while black pixels will be preserved. If `mask` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to 512): height (`int`, *optional*, defaults to 512):
@ -782,10 +782,10 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
image = preprocess_image(image) image = preprocess_image(image)
if image is not None: if image is not None:
image = image.astype(dtype) image = image.astype(dtype)
if isinstance(mask_image, PIL.Image.Image): if isinstance(mask, PIL.Image.Image):
mask_image = preprocess_mask(mask_image, self.vae_scale_factor) mask = preprocess_mask(mask, self.vae_scale_factor)
if mask_image is not None: if mask is not None:
mask = mask_image.astype(dtype) mask = mask.astype(dtype)
mask = np.concatenate([mask] * batch_size * num_images_per_prompt) mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
else: else:
mask = None mask = None
@ -1057,7 +1057,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
def inpaint( def inpaint(
self, self,
image: Union[np.ndarray, PIL.Image.Image], image: Union[np.ndarray, PIL.Image.Image],
mask_image: Union[np.ndarray, PIL.Image.Image], mask: Union[np.ndarray, PIL.Image.Image],
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
strength: float = 0.8, strength: float = 0.8,
@ -1079,9 +1079,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
image (`np.ndarray` or `PIL.Image.Image`): image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the `Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted. process. This is the image whose masked region will be inpainted.
mask_image (`np.ndarray` or `PIL.Image.Image`): mask (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a replaced by noise and therefore repainted, while black pixels will be preserved. If `mask` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
prompt (`str` or `List[str]`): prompt (`str` or `List[str]`):
@ -1136,7 +1136,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
image=image, image=image,
mask_image=mask_image, mask=mask,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale, guidance_scale=guidance_scale,
strength=strength, strength=strength,

View File

@ -96,7 +96,7 @@ def run_img2img_pipeline(
params: ImageParams, params: ImageParams,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image.Image, source: Image.Image,
strength: float, strength: float,
) -> None: ) -> None:
pipe = load_pipeline( pipe = load_pipeline(
@ -112,7 +112,7 @@ def run_img2img_pipeline(
logger.debug("using LPW pipeline for img2img") logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed) rng = torch.manual_seed(params.seed)
result = pipe.img2img( result = pipe.img2img(
source_image, source,
params.prompt, params.prompt,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
@ -125,7 +125,7 @@ def run_img2img_pipeline(
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
params.prompt, params.prompt,
source_image, source,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
@ -146,7 +146,7 @@ def run_img2img_pipeline(
) )
dest = save_image(server, output, image) dest = save_image(server, output, image)
size = Size(*source_image.size) size = Size(*source.size)
save_params(server, output, params, size, upscale=upscale) save_params(server, output, params, size, upscale=upscale)
del pipe del pipe
@ -165,8 +165,8 @@ def run_inpaint_pipeline(
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image.Image, source: Image.Image,
mask_image: Image.Image, mask: Image.Image,
border: Border, border: Border,
noise_source: Any, noise_source: Any,
mask_filter: Any, mask_filter: Any,
@ -187,9 +187,9 @@ def run_inpaint_pipeline(
server, server,
stage, stage,
params, params,
source_image, source,
border=border, border=border,
mask_image=mask_image, mask=mask,
fill_color=fill_color, fill_color=fill_color,
mask_filter=mask_filter, mask_filter=mask_filter,
noise_source=noise_source, noise_source=noise_source,
@ -217,14 +217,14 @@ def run_upscale_pipeline(
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image.Image, source: Image.Image,
) -> None: ) -> None:
# device = job.get_device() # device = job.get_device()
progress = job.get_progress_callback() progress = job.get_progress_callback()
stage = StageParams() stage = StageParams()
image = run_upscale_correction( image = run_upscale_correction(
job, server, stage, params, source_image, upscale=upscale, callback=progress job, server, stage, params, source, upscale=upscale, callback=progress
) )
dest = save_image(server, output, image) dest = save_image(server, output, image)

View File

@ -12,23 +12,23 @@ def get_pixel_index(x: int, y: int, width: int) -> int:
def mask_filter_none( def mask_filter_none(
mask_image: Image.Image, dims: Point, origin: Point, fill="white", **kw mask: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image: ) -> Image.Image:
width, height = dims width, height = dims
noise = Image.new("RGB", (width, height), fill) noise = Image.new("RGB", (width, height), fill)
noise.paste(mask_image, origin) noise.paste(mask, origin)
return noise return noise
def mask_filter_gaussian_multiply( def mask_filter_gaussian_multiply(
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image: ) -> Image.Image:
""" """
Gaussian blur with multiply, source image centered on white canvas. Gaussian blur with multiply, source image centered on white canvas.
""" """
noise = mask_filter_none(mask_image, dims, origin) noise = mask_filter_none(mask, dims, origin)
for i in range(rounds): for i in range(rounds):
blur = noise.filter(ImageFilter.GaussianBlur(5)) blur = noise.filter(ImageFilter.GaussianBlur(5))
@ -38,12 +38,12 @@ def mask_filter_gaussian_multiply(
def mask_filter_gaussian_screen( def mask_filter_gaussian_screen(
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image: ) -> Image.Image:
""" """
Gaussian blur, source image centered on white canvas. Gaussian blur, source image centered on white canvas.
""" """
noise = mask_filter_none(mask_image, dims, origin) noise = mask_filter_none(mask, dims, origin)
for i in range(rounds): for i in range(rounds):
blur = noise.filter(ImageFilter.GaussianBlur(5)) blur = noise.filter(ImageFilter.GaussianBlur(5))
@ -53,7 +53,7 @@ def mask_filter_gaussian_screen(
def noise_source_fill_edge( def noise_source_fill_edge(
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw source: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image: ) -> Image.Image:
""" """
Identity transform, source image centered on white canvas. Identity transform, source image centered on white canvas.
@ -61,13 +61,13 @@ def noise_source_fill_edge(
width, height = dims width, height = dims
noise = Image.new("RGB", (width, height), fill) noise = Image.new("RGB", (width, height), fill)
noise.paste(source_image, origin) noise.paste(source, origin)
return noise return noise
def noise_source_fill_mask( def noise_source_fill_mask(
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw source: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image: ) -> Image.Image:
""" """
Fill the whole canvas, no source or noise. Fill the whole canvas, no source or noise.
@ -80,13 +80,13 @@ def noise_source_fill_mask(
def noise_source_gaussian( def noise_source_gaussian(
source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw source: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image: ) -> Image.Image:
""" """
Gaussian blur, source image centered on white canvas. Gaussian blur, source image centered on white canvas.
""" """
noise = noise_source_uniform(source_image, dims, origin) noise = noise_source_uniform(source, dims, origin)
noise.paste(source_image, origin) noise.paste(source, origin)
for i in range(rounds): for i in range(rounds):
noise = noise.filter(ImageFilter.GaussianBlur(5)) noise = noise.filter(ImageFilter.GaussianBlur(5))
@ -95,7 +95,7 @@ def noise_source_gaussian(
def noise_source_uniform( def noise_source_uniform(
source_image: Image.Image, dims: Point, origin: Point, **kw source: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image: ) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -115,7 +115,7 @@ def noise_source_uniform(
def noise_source_normal( def noise_source_normal(
source_image: Image.Image, dims: Point, origin: Point, **kw source: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image: ) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -135,9 +135,9 @@ def noise_source_normal(
def noise_source_histogram( def noise_source_histogram(
source_image: Image.Image, dims: Point, origin: Point, **kw source: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image: ) -> Image.Image:
r, g, b = source_image.split() r, g, b = source.split()
width, height = dims width, height = dims
size = width * height size = width * height
@ -167,25 +167,25 @@ def noise_source_histogram(
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232 # very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
def expand_image( def expand_image(
source_image: Image.Image, source: Image.Image,
mask_image: Image.Image, mask: Image.Image,
expand: Border, expand: Border,
fill="white", fill="white",
noise_source=noise_source_histogram, noise_source=noise_source_histogram,
mask_filter=mask_filter_none, mask_filter=mask_filter_none,
): ):
full_width = expand.left + source_image.width + expand.right full_width = expand.left + source.width + expand.right
full_height = expand.top + source_image.height + expand.bottom full_height = expand.top + source.height + expand.bottom
dims = (full_width, full_height) dims = (full_width, full_height)
origin = (expand.left, expand.top) origin = (expand.left, expand.top)
full_source = Image.new("RGB", dims, fill) full_source = Image.new("RGB", dims, fill)
full_source.paste(source_image, origin) full_source.paste(source, origin)
# new mask pixels need to be filled with white so they will be replaced # new mask pixels need to be filled with white so they will be replaced
full_mask = mask_filter(mask_image, dims, origin, fill="white") full_mask = mask_filter(mask, dims, origin, fill="white")
full_noise = noise_source(source_image, dims, origin, fill=fill) full_noise = noise_source(source, dims, origin, fill=fill)
full_noise = ImageChops.multiply(full_noise, full_mask) full_noise = ImageChops.multiply(full_noise, full_mask)
full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))

View File

@ -514,7 +514,7 @@ def img2img():
return error_reply("source image is required") return error_reply("source image is required")
source_file = request.files.get("source") source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
upscale = upscale_from_request() upscale = upscale_from_request()
@ -530,7 +530,7 @@ def img2img():
output = make_output_name(context, "img2img", params, size, extras=(strength,)) output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output) logger.info("img2img job queued for: %s", output)
source_image = valid_image(source_image, min_dims=size, max_dims=size) source = valid_image(source, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_img2img_pipeline, run_img2img_pipeline,
@ -538,7 +538,7 @@ def img2img():
params, params,
output, output,
upscale, upscale,
source_image, source,
strength, strength,
needs_device=device, needs_device=device,
) )
@ -577,10 +577,10 @@ def inpaint():
return error_reply("mask image is required") return error_reply("mask image is required")
source_file = request.files.get("source") source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask_file = request.files.get("mask") mask_file = request.files.get("mask")
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB") mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
expand = border_from_request() expand = border_from_request()
@ -619,8 +619,8 @@ def inpaint():
) )
logger.info("inpaint job queued for: %s", output) logger.info("inpaint job queued for: %s", output)
source_image = valid_image(source_image, min_dims=size, max_dims=size) source = valid_image(source, min_dims=size, max_dims=size)
mask_image = valid_image(mask_image, min_dims=size, max_dims=size) mask = valid_image(mask, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_inpaint_pipeline, run_inpaint_pipeline,
@ -629,8 +629,8 @@ def inpaint():
size, size,
output, output,
upscale, upscale,
source_image, source,
mask_image, mask,
expand, expand,
noise_source, noise_source,
mask_filter, mask_filter,
@ -649,7 +649,7 @@ def upscale():
return error_reply("source image is required") return error_reply("source image is required")
source_file = request.files.get("source") source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request() device, params, size = pipeline_from_request()
upscale = upscale_from_request() upscale = upscale_from_request()
@ -657,7 +657,7 @@ def upscale():
output = make_output_name(context, "upscale", params, size) output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output) logger.info("upscale job queued for: %s", output)
source_image = valid_image(source_image, min_dims=size, max_dims=size) source = valid_image(source, min_dims=size, max_dims=size)
executor.submit( executor.submit(
output, output,
run_upscale_pipeline, run_upscale_pipeline,
@ -666,7 +666,7 @@ def upscale():
size, size,
output, output,
upscale, upscale,
source_image, source,
needs_device=device, needs_device=device,
) )
@ -723,9 +723,9 @@ def chain():
stage.name, stage.name,
) )
source_file = request.files.get(stage_source_name) source_file = request.files.get(stage_source_name)
source_image = Image.open(BytesIO(source_file.read())).convert("RGB") source = Image.open(BytesIO(source_file.read())).convert("RGB")
source_image = valid_image(source_image, max_dims=(size.width, size.height)) source = valid_image(source, max_dims=(size.width, size.height))
kwargs["source_image"] = source_image kwargs["stage_source"] = source
if stage_mask_name in request.files: if stage_mask_name in request.files:
logger.debug( logger.debug(
@ -734,9 +734,9 @@ def chain():
stage.name, stage.name,
) )
mask_file = request.files.get(stage_mask_name) mask_file = request.files.get(stage_mask_name)
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB") mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_image = valid_image(mask_image, max_dims=(size.width, size.height)) mask = valid_image(mask, max_dims=(size.width, size.height))
kwargs["mask_image"] = mask_image kwargs["stage_mask"] = mask
pipeline.append((callback, stage, kwargs)) pipeline.append((callback, stage, kwargs))