1
0
Fork 0

lint(api): use consistent name for source/mask images to avoid conflict with kwargs

This commit is contained in:
Sean Sube 2023-02-18 16:35:57 -06:00
parent 7b8ced0f68
commit b4f7973c1e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
14 changed files with 101 additions and 101 deletions

View File

@ -85,7 +85,7 @@ def blend_inpaint(
height=size.height,
image=image,
latents=latents,
mask_image=mask,
mask=mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
@ -100,7 +100,7 @@ def blend_inpaint(
height=size.height,
image=image,
latents=latents,
mask_image=mask,
mask=mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,

View File

@ -50,7 +50,7 @@ def correct_gfpgan(
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
upscale: UpscaleParams,
**kwargs,
@ -59,13 +59,13 @@ def correct_gfpgan(
if upscale.correction_model is None:
logger.warn("no face model given, skipping")
return source_image
return source
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
gfpgan = load_gfpgan(server, stage, upscale, device)
output = np.array(source_image)
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
has_aligned=False,

View File

@ -15,11 +15,11 @@ def persist_disk(
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
output: str,
**kwargs,
) -> Image.Image:
dest = save_image(server, output, source_image)
dest = save_image(server, output, source)
logger.info("saved image to %s", dest)
return source_image
return source

View File

@ -16,7 +16,7 @@ def persist_s3(
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
output: str,
bucket: str,
@ -28,7 +28,7 @@ def persist_s3(
s3 = session.client("s3", endpoint_url=endpoint_url)
data = BytesIO()
source_image.save(data, format=server.image_format)
source.save(data, format=server.image_format)
data.seek(0)
try:
@ -37,4 +37,4 @@ def persist_s3(
except Exception as err:
logger.error("error saving image to S3: %s", err)
return source_image
return source

View File

@ -14,12 +14,12 @@ def reduce_crop(
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
origin: Size,
size: Size,
**kwargs,
) -> Image.Image:
image = source_image.crop((origin.width, origin.height, size.width, size.height))
image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image

View File

@ -14,12 +14,12 @@ def reduce_thumbnail(
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
size: Size,
**kwargs,
) -> Image.Image:
image = source_image.copy()
image = source.copy()
# TODO: should use a call to valid_image
image = image.thumbnail((size.width, size.height))

View File

@ -15,7 +15,7 @@ def source_noise(
_server: ServerContext,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
size: Size,
noise_source: Callable,
@ -23,10 +23,10 @@ def source_noise(
) -> Image.Image:
logger.info("generating image from noise source")
if source_image is not None:
if source is not None:
logger.warn("a source image was passed to a noise stage, but will be discarded")
output = noise_source(source_image, (size.width, size.height), (0, 0))
output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -18,7 +18,7 @@ def source_txt2img(
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
size: Size,
callback: ProgressCallback = None,
@ -28,7 +28,7 @@ def source_txt2img(
size = size.with_args(**kwargs)
logger.info("generating image using txt2img, %s steps: %s", params.steps, params.prompt)
if source_image is not None:
if source is not None:
logger.warn(
"a source image was passed to a txt2img stage, but will be discarded"
)

View File

@ -22,11 +22,11 @@ def upscale_outpaint(
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
border: Border,
prompt: str = None,
mask_image: Image.Image = None,
mask: Image.Image = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
@ -38,34 +38,34 @@ def upscale_outpaint(
margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source_image.width, margin_y / source_image.height)
overlap = min(margin_x / source.width, margin_y / source.height)
if mask_image is None:
if mask is None:
# if no mask was provided, keep the full source image
mask_image = Image.new("RGB", source_image.size, "black")
mask = Image.new("RGB", source.size, "black")
source_image, mask_image, noise_image, full_dims = expand_image(
source_image,
mask_image,
source, mask, noise, full_dims = expand_image(
source,
mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
draw_mask = ImageDraw.Draw(mask_image)
draw_mask = ImageDraw.Draw(mask)
full_size = Size(*full_dims)
full_latents = get_latents_from_seed(params.seed, full_size)
if is_debug():
save_image(server, "last-source.png", source_image)
save_image(server, "last-mask.png", mask_image)
save_image(server, "last-noise.png", noise_image)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", mask)
save_image(server, "last-noise.png", noise)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*image.size)
mask = mask_image.crop((left, top, left + tile, top + tile))
mask = mask.crop((left, top, left + tile, top + tile))
if is_debug():
save_image(server, "tile-source.png", image)
@ -105,7 +105,7 @@ def upscale_outpaint(
guidance_scale=params.cfg,
height=size.height,
latents=latents,
mask_image=mask,
mask=mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
@ -118,7 +118,7 @@ def upscale_outpaint(
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
@ -126,7 +126,7 @@ def upscale_outpaint(
)
output = process_tile_order(
stage.tile_order,
source_image,
source,
SizeChart.auto,
1,
[outpaint],
@ -134,7 +134,7 @@ def upscale_outpaint(
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -100,14 +100,14 @@ def upscale_resrgan(
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
source: Image.Image,
*,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source_image)
output = np.array(source)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)

View File

@ -657,7 +657,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
image: Union[np.ndarray, PIL.Image.Image] = None,
mask_image: Union[np.ndarray, PIL.Image.Image] = None,
mask: Union[np.ndarray, PIL.Image.Image] = None,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
@ -687,9 +687,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
mask_image (`np.ndarray` or `PIL.Image.Image`):
mask (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to 512):
@ -782,10 +782,10 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
image = preprocess_image(image)
if image is not None:
image = image.astype(dtype)
if isinstance(mask_image, PIL.Image.Image):
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
if mask_image is not None:
mask = mask_image.astype(dtype)
if isinstance(mask, PIL.Image.Image):
mask = preprocess_mask(mask, self.vae_scale_factor)
if mask is not None:
mask = mask.astype(dtype)
mask = np.concatenate([mask] * batch_size * num_images_per_prompt)
else:
mask = None
@ -1057,7 +1057,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
def inpaint(
self,
image: Union[np.ndarray, PIL.Image.Image],
mask_image: Union[np.ndarray, PIL.Image.Image],
mask: Union[np.ndarray, PIL.Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
strength: float = 0.8,
@ -1079,9 +1079,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`np.ndarray` or `PIL.Image.Image`):
mask (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
prompt (`str` or `List[str]`):
@ -1136,7 +1136,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
prompt=prompt,
negative_prompt=negative_prompt,
image=image,
mask_image=mask_image,
mask=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
strength=strength,

View File

@ -96,7 +96,7 @@ def run_img2img_pipeline(
params: ImageParams,
output: str,
upscale: UpscaleParams,
source_image: Image.Image,
source: Image.Image,
strength: float,
) -> None:
pipe = load_pipeline(
@ -112,7 +112,7 @@ def run_img2img_pipeline(
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source_image,
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
@ -125,7 +125,7 @@ def run_img2img_pipeline(
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
source_image,
source,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
@ -146,7 +146,7 @@ def run_img2img_pipeline(
)
dest = save_image(server, output, image)
size = Size(*source_image.size)
size = Size(*source.size)
save_params(server, output, params, size, upscale=upscale)
del pipe
@ -165,8 +165,8 @@ def run_inpaint_pipeline(
size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image.Image,
mask_image: Image.Image,
source: Image.Image,
mask: Image.Image,
border: Border,
noise_source: Any,
mask_filter: Any,
@ -187,9 +187,9 @@ def run_inpaint_pipeline(
server,
stage,
params,
source_image,
source,
border=border,
mask_image=mask_image,
mask=mask,
fill_color=fill_color,
mask_filter=mask_filter,
noise_source=noise_source,
@ -217,14 +217,14 @@ def run_upscale_pipeline(
size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image.Image,
source: Image.Image,
) -> None:
# device = job.get_device()
progress = job.get_progress_callback()
stage = StageParams()
image = run_upscale_correction(
job, server, stage, params, source_image, upscale=upscale, callback=progress
job, server, stage, params, source, upscale=upscale, callback=progress
)
dest = save_image(server, output, image)

View File

@ -12,23 +12,23 @@ def get_pixel_index(x: int, y: int, width: int) -> int:
def mask_filter_none(
mask_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
mask: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
width, height = dims
noise = Image.new("RGB", (width, height), fill)
noise.paste(mask_image, origin)
noise.paste(mask, origin)
return noise
def mask_filter_gaussian_multiply(
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur with multiply, source image centered on white canvas.
"""
noise = mask_filter_none(mask_image, dims, origin)
noise = mask_filter_none(mask, dims, origin)
for i in range(rounds):
blur = noise.filter(ImageFilter.GaussianBlur(5))
@ -38,12 +38,12 @@ def mask_filter_gaussian_multiply(
def mask_filter_gaussian_screen(
mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur, source image centered on white canvas.
"""
noise = mask_filter_none(mask_image, dims, origin)
noise = mask_filter_none(mask, dims, origin)
for i in range(rounds):
blur = noise.filter(ImageFilter.GaussianBlur(5))
@ -53,7 +53,7 @@ def mask_filter_gaussian_screen(
def noise_source_fill_edge(
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
source: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
"""
Identity transform, source image centered on white canvas.
@ -61,13 +61,13 @@ def noise_source_fill_edge(
width, height = dims
noise = Image.new("RGB", (width, height), fill)
noise.paste(source_image, origin)
noise.paste(source, origin)
return noise
def noise_source_fill_mask(
source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw
source: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
"""
Fill the whole canvas, no source or noise.
@ -80,13 +80,13 @@ def noise_source_fill_mask(
def noise_source_gaussian(
source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw
source: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
"""
Gaussian blur, source image centered on white canvas.
"""
noise = noise_source_uniform(source_image, dims, origin)
noise.paste(source_image, origin)
noise = noise_source_uniform(source, dims, origin)
noise.paste(source, origin)
for i in range(rounds):
noise = noise.filter(ImageFilter.GaussianBlur(5))
@ -95,7 +95,7 @@ def noise_source_gaussian(
def noise_source_uniform(
source_image: Image.Image, dims: Point, origin: Point, **kw
source: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
@ -115,7 +115,7 @@ def noise_source_uniform(
def noise_source_normal(
source_image: Image.Image, dims: Point, origin: Point, **kw
source: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
width, height = dims
size = width * height
@ -135,9 +135,9 @@ def noise_source_normal(
def noise_source_histogram(
source_image: Image.Image, dims: Point, origin: Point, **kw
source: Image.Image, dims: Point, origin: Point, **kw
) -> Image.Image:
r, g, b = source_image.split()
r, g, b = source.split()
width, height = dims
size = width * height
@ -167,25 +167,25 @@ def noise_source_histogram(
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
def expand_image(
source_image: Image.Image,
mask_image: Image.Image,
source: Image.Image,
mask: Image.Image,
expand: Border,
fill="white",
noise_source=noise_source_histogram,
mask_filter=mask_filter_none,
):
full_width = expand.left + source_image.width + expand.right
full_height = expand.top + source_image.height + expand.bottom
full_width = expand.left + source.width + expand.right
full_height = expand.top + source.height + expand.bottom
dims = (full_width, full_height)
origin = (expand.left, expand.top)
full_source = Image.new("RGB", dims, fill)
full_source.paste(source_image, origin)
full_source.paste(source, origin)
# new mask pixels need to be filled with white so they will be replaced
full_mask = mask_filter(mask_image, dims, origin, fill="white")
full_noise = noise_source(source_image, dims, origin, fill=fill)
full_mask = mask_filter(mask, dims, origin, fill="white")
full_noise = noise_source(source, dims, origin, fill=fill)
full_noise = ImageChops.multiply(full_noise, full_mask)
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))

View File

@ -514,7 +514,7 @@ def img2img():
return error_reply("source image is required")
source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
@ -530,7 +530,7 @@ def img2img():
output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output)
source_image = valid_image(source_image, min_dims=size, max_dims=size)
source = valid_image(source, min_dims=size, max_dims=size)
executor.submit(
output,
run_img2img_pipeline,
@ -538,7 +538,7 @@ def img2img():
params,
output,
upscale,
source_image,
source,
strength,
needs_device=device,
)
@ -577,10 +577,10 @@ def inpaint():
return error_reply("mask image is required")
source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
mask_file = request.files.get("mask")
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
expand = border_from_request()
@ -619,8 +619,8 @@ def inpaint():
)
logger.info("inpaint job queued for: %s", output)
source_image = valid_image(source_image, min_dims=size, max_dims=size)
mask_image = valid_image(mask_image, min_dims=size, max_dims=size)
source = valid_image(source, min_dims=size, max_dims=size)
mask = valid_image(mask, min_dims=size, max_dims=size)
executor.submit(
output,
run_inpaint_pipeline,
@ -629,8 +629,8 @@ def inpaint():
size,
output,
upscale,
source_image,
mask_image,
source,
mask,
expand,
noise_source,
mask_filter,
@ -649,7 +649,7 @@ def upscale():
return error_reply("source image is required")
source_file = request.files.get("source")
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source = Image.open(BytesIO(source_file.read())).convert("RGB")
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
@ -657,7 +657,7 @@ def upscale():
output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
source_image = valid_image(source_image, min_dims=size, max_dims=size)
source = valid_image(source, min_dims=size, max_dims=size)
executor.submit(
output,
run_upscale_pipeline,
@ -666,7 +666,7 @@ def upscale():
size,
output,
upscale,
source_image,
source,
needs_device=device,
)
@ -723,9 +723,9 @@ def chain():
stage.name,
)
source_file = request.files.get(stage_source_name)
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source_image = valid_image(source_image, max_dims=(size.width, size.height))
kwargs["source_image"] = source_image
source = Image.open(BytesIO(source_file.read())).convert("RGB")
source = valid_image(source, max_dims=(size.width, size.height))
kwargs["stage_source"] = source
if stage_mask_name in request.files:
logger.debug(
@ -734,9 +734,9 @@ def chain():
stage.name,
)
mask_file = request.files.get(stage_mask_name)
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_image = valid_image(mask_image, max_dims=(size.width, size.height))
kwargs["mask_image"] = mask_image
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask = valid_image(mask, max_dims=(size.width, size.height))
kwargs["stage_mask"] = mask
pipeline.append((callback, stage, kwargs))