1
0
Fork 0

feat(api): make pipeline stages support multiple images

This commit is contained in:
Sean Sube 2023-07-04 13:29:58 -05:00
parent f718087a54
commit 37185252a5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
22 changed files with 482 additions and 434 deletions

View File

@ -90,10 +90,10 @@ class ChainPipeline:
job: WorkerContext,
server: ServerContext,
params: ImageParams,
source: Optional[Image.Image] = None,
source: List[Image.Image],
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
) -> Image.Image:
) -> List[Image.Image]:
"""
DEPRECATED: use `run` instead
"""
@ -101,6 +101,8 @@ class ChainPipeline:
callback = ChainProgress.from_progress(callback)
start = monotonic()
# TODO: turn this into stage images
image = source
if source is not None:

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
import numpy as np
import torch
@ -22,15 +22,14 @@ class BlendImg2ImgStage(BaseStage):
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)
@ -59,39 +58,40 @@ class BlendImg2ImgStage(BaseStage):
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
outputs = []
for source in sources:
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
output = result.images[0]
outputs.extend(result.images)
logger.info("final output image size: %sx%s", output.width, output.height)
return output
return outputs

View File

@ -1,12 +1,12 @@
from logging import getLogger
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
from PIL import Image
from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed
from ..diffusers.utils import get_latents_from_seed, parse_prompt
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
@ -26,7 +26,7 @@ class BlendInpaintStage(BaseStage):
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
expand: Border,
stage_source: Optional[Image.Image] = None,
@ -36,95 +36,97 @@ class BlendInpaintStage(BaseStage):
noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
params = params.with_args(**kwargs)
expand = expand.with_args(**kwargs)
source = source or stage_source
logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
pipe_type = "lpw" if params.lpw() else "inpaint"
_prompt_pairs, loras, inversions = parse_prompt(params)
pipe_type = params.get_valid_pipeline("inpaint")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: add LoRAs and TIs
inversions=inversions,
loras=loras,
)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
outputs = []
for source in sources:
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
latents = get_latents_from_seed(params.seed, size)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
latents = get_latents_from_seed(params.seed, size)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
return result.images[0]
outputs.append(
process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
return result.images[0]
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
)
logger.info("final output image size: %s", output.size)
return output
)

View File

@ -18,12 +18,13 @@ class BlendLinearStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
sources: List[Image.Image],
*,
alpha: float,
sources: Optional[List[Image.Image]] = None,
stage_source: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using linear interpolation")
) -> List[Image.Image]:
logger.info("blending source images using linear interpolation")
return Image.blend(sources[1], sources[0], alpha)
return [Image.blend(source, stage_source, alpha) for source in sources]

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
from PIL import Image
@ -20,13 +20,13 @@ class BlendMaskStage(BaseStage):
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
logger.info("blending image using mask")
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
@ -37,4 +37,4 @@ class BlendMaskStage(BaseStage):
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)
return Image.composite(stage_source, source, mult_mask)
return [Image.composite(stage_source, source, mult_mask) for source in sources]

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
from PIL import Image
@ -18,20 +18,18 @@ class CorrectCodeformerStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from codeformer import CodeFormer
source = stage_source or source
upscale = upscale.with_args(**kwargs)
device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(source)
return [pipe(source) for source in sources]

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import Optional
from typing import List, Optional
import numpy as np
from PIL import Image
@ -57,31 +57,32 @@ class CorrectGFPGANStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.correction_model is None:
logger.warn("no face model given, skipping")
return source
return sources
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
gfpgan = self.load(server, stage, upscale, device)
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
)
output = Image.fromarray(output, "RGB")
outputs = []
for source in sources:
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
)
outputs.append(Image.fromarray(output, "RGB"))
return output
return outputs

View File

@ -1,4 +1,5 @@
from logging import getLogger
from typing import List
from PIL import Image
@ -18,14 +19,15 @@ class PersistDiskStage(BaseStage):
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
output: str,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
for source in sources:
# TODO: append index to output name
dest = save_image(server, output, source, params=params)
logger.info("saved image to %s", dest)
dest = save_image(server, output, source, params=params)
logger.info("saved image to %s", dest)
return source
return sources

View File

@ -1,6 +1,6 @@
from io import BytesIO
from logging import getLogger
from typing import Optional
from typing import List, Optional
from boto3 import Session
from PIL import Image
@ -20,7 +20,7 @@ class PersistS3Stage(BaseStage):
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
output: str,
bucket: str,
@ -28,20 +28,19 @@ class PersistS3Stage(BaseStage):
profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
data = BytesIO()
source.save(data, format=server.image_format)
data.seek(0)
for source in sources:
data = BytesIO()
source.save(data, format=server.image_format)
data.seek(0)
try:
s3.upload_fileobj(data, bucket, output)
logger.info("saved image to s3://%s/%s", bucket, output)
except Exception:
logger.exception("error saving image to S3")
try:
s3.upload_fileobj(data, bucket, output)
logger.info("saved image to s3://%s/%s", bucket, output)
except Exception:
logger.exception("error saving image to S3")
return source
return sources

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
from PIL import Image
@ -18,17 +18,20 @@ class ReduceCropStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
origin: Size,
size: Size,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
outputs = []
image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
return image
for source in sources:
image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
outputs.append(image)
return outputs

View File

@ -1,4 +1,5 @@
from logging import getLogger
from typing import List
from PIL import Image
@ -17,18 +18,23 @@ class ReduceThumbnailStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
size: Size,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
image = source.copy()
) -> List[Image.Image]:
outputs = []
image = image.thumbnail((size.width, size.height))
for source in sources:
image = source.copy()
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
return image
image = image.thumbnail((size.width, size.height))
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
outputs.append(image)
return outputs

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Callable
from typing import Callable, List
from PIL import Image
@ -18,22 +18,25 @@ class SourceNoiseStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
size: Size,
noise_source: Callable,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
logger.info("generating image from noise source")
if source is not None:
if len(sources) > 0:
logger.warn(
"a source image was passed to a noise stage, but will be discarded"
"source images were passed to a noise stage and will be discarded"
)
output = noise_source(source, (size.width, size.height), (0, 0))
outputs = []
for source in sources:
output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height)
return output
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
return outputs

View File

@ -1,6 +1,6 @@
from io import BytesIO
from logging import getLogger
from typing import Optional
from typing import List, Optional
from boto3 import Session
from PIL import Image
@ -17,29 +17,30 @@ class SourceS3Stage(BaseStage):
def run(
self,
_job: WorkerContext,
server: ServerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
_sources: List[Image.Image],
*,
source_key: str,
source_keys: List[str],
bucket: str,
endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
try:
logger.info("loading image from s3://%s/%s", bucket, source_key)
data = BytesIO()
s3.download_fileobj(bucket, source_key, data)
outputs = []
for key in source_keys:
try:
logger.info("loading image from s3://%s/%s", bucket, key)
data = BytesIO()
s3.download_fileobj(bucket, key, data)
data.seek(0)
return Image.open(data)
except Exception:
logger.exception("error loading image from S3")
data.seek(0)
outputs.append(Image.open(data))
except Exception:
logger.exception("error loading image from S3")
return outputs

View File

@ -1,5 +1,6 @@
from io import BytesIO
from logging import getLogger
from typing import List
import requests
from PIL import Image
@ -19,22 +20,25 @@ class SourceURLStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
source_url: str,
source_urls: List[str],
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
logger.info("loading image from URL source")
if source is not None:
if len(sources) > 0:
logger.warn(
"a source image was passed to a source stage, and will be discarded"
)
response = requests.get(source_url)
output = Image.open(BytesIO(response.content))
outputs = []
for url in source_urls:
response = requests.get(url)
output = Image.open(BytesIO(response.content))
logger.info("final output image size: %sx%s", output.width, output.height)
return output
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
return outputs

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional
from PIL import Image
@ -16,11 +16,11 @@ class BaseStage:
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*args,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
raise NotImplementedError()
def steps(

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import Optional
from typing import List, Optional
import numpy as np
from PIL import Image
@ -54,48 +54,52 @@ class UpscaleBSRGANStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None:
logger.warn("no upscaling model given, skipping")
return source
return sources
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
bsrgan = self.load(server, stage, upscale, device)
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape)
outputs = []
for source in sources:
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
scale = upscale.outscale
dest = np.zeros(
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
)
)
)
logger.trace("BSRGAN output shape: %s", dest.shape)
logger.trace("BSRGAN output shape: %s", dest.shape)
dest = bsrgan(image)
dest = bsrgan(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
return output
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
outputs.append(output)
return outputs
def steps(
self,

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
from PIL import Image
@ -20,25 +20,26 @@ class UpscaleHighresStage(BaseStage):
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*args,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
if highres.scale <= 1:
return source
return sources
chain = stage_highres(stage, params, highres, upscale)
return chain(
job,
server,
params,
source,
callback=callback,
)
return [
chain(
job,
server,
params,
source,
callback=callback,
)
for source in sources
]

View File

@ -1,12 +1,12 @@
from logging import getLogger
from typing import Callable, Optional, Tuple
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageOps
from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed, get_tile_latents
from ..diffusers.utils import get_latents_from_seed, get_tile_latents, parse_prompt
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
@ -26,7 +26,7 @@ class UpscaleOutpaintStage(BaseStage):
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
border: Border,
stage_source: Optional[Image.Image] = None,
@ -36,123 +36,130 @@ class UpscaleOutpaintStage(BaseStage):
noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
) -> List[Image.Image]:
_prompt_pairs, loras, inversions = parse_prompt(params)
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source.width, margin_y / source.height)
outputs = []
for source in sources:
logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
)
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source.width, margin_y / source.height)
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
source, stage_mask, noise, full_size = expand_image(
source,
stage_mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
source, stage_mask, noise, full_size = expand_image(
source,
stage_mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
draw_mask = ImageDraw.Draw(stage_mask)
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
tile_mask = complete_tile(tile_mask, tile)
draw_mask = ImageDraw.Draw(stage_mask)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
latents = get_tile_latents(full_latents, dims, size)
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: load LoRAs and TIs
)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
tile_source,
tile_mask,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
tile_mask = complete_tile(tile_mask, tile)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
latents = get_tile_latents(full_latents, dims, size)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
tile_source,
tile_mask,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
tile_source,
tile_mask,
height=size.height,
width=size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
generator=rng,
latents=latents,
callback=callback,
)
# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
tile_source,
tile_mask,
height=size.height,
width=size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
generator=rng,
latents=latents,
callback=callback,
)
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
logger.info("final output image size: %sx%s", output.width, output.height)
return output
return outputs

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import Optional
from typing import List, Optional
import numpy as np
from PIL import Image
@ -77,20 +77,25 @@ class UpscaleRealESRGANStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source)
upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size)
outputs = []
for source in sources:
output = np.array(source)
upsampler = self.load(
server, upscale, job.get_device(), tile=stage.tile_size
)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
return output
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
outputs.append(output)
return outputs

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
from PIL import Image
@ -18,30 +18,32 @@ class UpscaleSimpleStage(BaseStage):
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
method: str,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
) -> List[Image.Image]:
if upscale.scale <= 1:
logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale
)
return source
return sources
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
outputs = []
for source in sources:
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
if method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.warning("unknown upscaling method: %s", method)
if method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.warning("unknown upscaling method: %s", method)
return source
outputs.append(source)
return outputs

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import Optional
from typing import List, Optional
import torch
from PIL import Image
@ -22,16 +22,15 @@ class UpscaleStableDiffusionStage(BaseStage):
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs)
source = stage_source or source
logger.info(
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
)
@ -55,14 +54,19 @@ class UpscaleStableDiffusionStage(BaseStage):
)
pipeline.unet.set_prompts(prompt_embeds)
return pipeline(
params.prompt,
source,
generator=generator,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
eta=params.eta,
noise_level=upscale.denoise,
callback=callback,
).images[0]
outputs = []
for source in sources:
result = pipeline(
params.prompt,
source,
generator=generator,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
eta=params.eta,
noise_level=upscale.denoise,
callback=callback,
)
outputs.extend(result.image)
return outputs

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import Optional
from typing import List, Optional
import numpy as np
from PIL import Image
@ -54,45 +54,48 @@ class UpscaleSwinIRStage(BaseStage):
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
sources: List[Image.Image],
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
) -> List[Image.Image]:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None:
logger.warn("no correction model given, skipping")
return source
return sources
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
device = job.get_device()
swinir = self.load(server, stage, upscale, device)
# TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("SwinIR input shape: %s", image.shape)
outputs = []
for source in sources:
# TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
scale = upscale.outscale
dest = np.zeros(
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
)
)
)
logger.trace("SwinIR output shape: %s", dest.shape)
logger.trace("SwinIR output shape: %s", dest.shape)
dest = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
dest = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
return output
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
outputs.append(output)
return outputs