1
0
Fork 0

feat(api): make pipeline stages support multiple images

This commit is contained in:
Sean Sube 2023-07-04 13:29:58 -05:00
parent f718087a54
commit 37185252a5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
22 changed files with 482 additions and 434 deletions

View File

@ -90,10 +90,10 @@ class ChainPipeline:
job: WorkerContext, job: WorkerContext,
server: ServerContext, server: ServerContext,
params: ImageParams, params: ImageParams,
source: Optional[Image.Image] = None, source: List[Image.Image],
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**pipeline_kwargs **pipeline_kwargs
) -> Image.Image: ) -> List[Image.Image]:
""" """
DEPRECATED: use `run` instead DEPRECATED: use `run` instead
""" """
@ -101,6 +101,8 @@ class ChainPipeline:
callback = ChainProgress.from_progress(callback) callback = ChainProgress.from_progress(callback)
start = monotonic() start = monotonic()
# TODO: turn this into stage images
image = source image = source
if source is not None: if source is not None:

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
@ -22,15 +22,14 @@ class BlendImg2ImgStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
strength: float, strength: float,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
source = stage_source or source
logger.info( logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt "blending image using img2img, %s steps: %s", params.steps, params.prompt
) )
@ -59,39 +58,40 @@ class BlendImg2ImgStage(BaseStage):
elif pipe_type == "pix2pix": elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength pipe_params["image_guidance_scale"] = strength
if params.lpw(): outputs = []
logger.debug("using LPW pipeline for img2img") for source in sources:
rng = torch.manual_seed(params.seed) if params.lpw():
result = pipe.img2img( logger.debug("using LPW pipeline for img2img")
source, rng = torch.manual_seed(params.seed)
params.prompt, result = pipe.img2img(
generator=rng, source,
guidance_scale=params.cfg, params.prompt,
negative_prompt=params.negative_prompt, generator=rng,
num_inference_steps=params.steps, guidance_scale=params.cfg,
callback=callback, negative_prompt=params.negative_prompt,
**pipe_params, num_inference_steps=params.steps,
) callback=callback,
else: **pipe_params,
# encode and record alternative prompts outside of LPW )
prompt_embeds = encode_prompt( else:
pipe, prompt_pairs, params.batch, params.do_cfg() # encode and record alternative prompts outside of LPW
) prompt_embeds = encode_prompt(
pipe.unet.set_prompts(prompt_embeds) pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
params.prompt, params.prompt,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
image=source, image=source,
negative_prompt=params.negative_prompt, negative_prompt=params.negative_prompt,
num_inference_steps=params.steps, num_inference_steps=params.steps,
callback=callback, callback=callback,
**pipe_params, **pipe_params,
) )
output = result.images[0] outputs.extend(result.images)
logger.info("final output image size: %sx%s", output.width, output.height) return outputs
return output

View File

@ -1,12 +1,12 @@
from logging import getLogger from logging import getLogger
from typing import Callable, Optional, Tuple from typing import Callable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from ..diffusers.load import load_pipeline from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed from ..diffusers.utils import get_latents_from_seed, parse_prompt
from ..image import expand_image, mask_filter_none, noise_source_histogram from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..params import Border, ImageParams, Size, SizeChart, StageParams
@ -26,7 +26,7 @@ class BlendInpaintStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
expand: Border, expand: Border,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
@ -36,95 +36,97 @@ class BlendInpaintStage(BaseStage):
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
expand = expand.with_args(**kwargs) expand = expand.with_args(**kwargs)
source = source or stage_source
logger.info( logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt "blending image using inpaint, %s steps: %s", params.steps, params.prompt
) )
if stage_mask is None: _prompt_pairs, loras, inversions = parse_prompt(params)
# if no mask was provided, keep the full source image pipe_type = params.get_valid_pipeline("inpaint")
stage_mask = Image.new("RGB", source.size, "black")
source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
pipe_type = "lpw" if params.lpw() else "inpaint"
pipe = load_pipeline( pipe = load_pipeline(
server, server,
params, params,
pipe_type, pipe_type,
job.get_device(), job.get_device(),
# TODO: add LoRAs and TIs inversions=inversions,
loras=loras,
) )
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): outputs = []
left, top, tile = dims for source in sources:
size = Size(*tile_source.size) if stage_mask is None:
tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) # if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug(): if is_debug():
save_image(server, "tile-source.png", tile_source) save_image(server, "last-source.png", source)
save_image(server, "tile-mask.png", tile_mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
latents = get_latents_from_seed(params.seed, size) def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
if params.lpw(): left, top, tile = dims
logger.debug("using LPW pipeline for inpaint") size = Size(*tile_source.size)
rng = torch.manual_seed(params.seed) tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
result = pipe.inpaint(
params.prompt, if is_debug():
generator=rng, save_image(server, "tile-source.png", tile_source)
guidance_scale=params.cfg, save_image(server, "tile-mask.png", tile_mask)
height=size.height,
image=tile_source, latents = get_latents_from_seed(params.seed, size)
latents=latents, if params.lpw():
mask_image=tile_mask, logger.debug("using LPW pipeline for inpaint")
negative_prompt=params.negative_prompt, rng = torch.manual_seed(params.seed)
num_inference_steps=params.steps, result = pipe.inpaint(
width=size.width, params.prompt,
eta=params.eta, generator=rng,
callback=callback, guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
return result.images[0]
outputs.append(
process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
) )
else: )
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
return result.images[0]
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
)
logger.info("final output image size: %s", output.size)
return output

View File

@ -18,12 +18,13 @@ class BlendLinearStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
sources: List[Image.Image],
*, *,
alpha: float, alpha: float,
sources: Optional[List[Image.Image]] = None, stage_source: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, _callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
logger.info("blending image using linear interpolation") logger.info("blending source images using linear interpolation")
return Image.blend(sources[1], sources[0], alpha) return [Image.blend(source, stage_source, alpha) for source in sources]

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from PIL import Image from PIL import Image
@ -20,13 +20,13 @@ class BlendMaskStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None, _callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
logger.info("blending image using mask") logger.info("blending image using mask")
mult_mask = Image.new("RGBA", stage_mask.size, color="black") mult_mask = Image.new("RGBA", stage_mask.size, color="black")
@ -37,4 +37,4 @@ class BlendMaskStage(BaseStage):
save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-mult-mask.png", mult_mask)
return Image.composite(stage_source, source, mult_mask) return [Image.composite(stage_source, source, mult_mask) for source in sources]

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from PIL import Image from PIL import Image
@ -18,20 +18,18 @@ class CorrectCodeformerStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
# must be within the load function for patch to take effect # must be within the load function for patch to take effect
# TODO: rewrite and remove # TODO: rewrite and remove
from codeformer import CodeFormer from codeformer import CodeFormer
source = stage_source or source
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
device = job.get_device() device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(source) return [pipe(source) for source in sources]

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -57,31 +57,32 @@ class CorrectGFPGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.correction_model is None: if upscale.correction_model is None:
logger.warn("no face model given, skipping") logger.warn("no face model given, skipping")
return source return sources
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device() device = job.get_device()
gfpgan = self.load(server, stage, upscale, device) gfpgan = self.load(server, stage, upscale, device)
output = np.array(source) outputs = []
_, _, output = gfpgan.enhance( for source in sources:
output, output = np.array(source)
has_aligned=False, _, _, output = gfpgan.enhance(
only_center_face=False, output,
paste_back=True, has_aligned=False,
weight=upscale.face_strength, only_center_face=False,
) paste_back=True,
output = Image.fromarray(output, "RGB") weight=upscale.face_strength,
)
outputs.append(Image.fromarray(output, "RGB"))
return output return outputs

View File

@ -1,4 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List
from PIL import Image from PIL import Image
@ -18,14 +19,15 @@ class PersistDiskStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
output: str, output: str,
stage_source: Image.Image, stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source for source in sources:
# TODO: append index to output name
dest = save_image(server, output, source, params=params)
logger.info("saved image to %s", dest)
dest = save_image(server, output, source, params=params) return sources
logger.info("saved image to %s", dest)
return source

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from boto3 import Session from boto3 import Session
from PIL import Image from PIL import Image
@ -20,7 +20,7 @@ class PersistS3Stage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
output: str, output: str,
bucket: str, bucket: str,
@ -28,20 +28,19 @@ class PersistS3Stage(BaseStage):
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
data = BytesIO() for source in sources:
source.save(data, format=server.image_format) data = BytesIO()
data.seek(0) source.save(data, format=server.image_format)
data.seek(0)
try: try:
s3.upload_fileobj(data, bucket, output) s3.upload_fileobj(data, bucket, output)
logger.info("saved image to s3://%s/%s", bucket, output) logger.info("saved image to s3://%s/%s", bucket, output)
except Exception: except Exception:
logger.exception("error saving image to S3") logger.exception("error saving image to S3")
return source return sources

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from PIL import Image from PIL import Image
@ -18,17 +18,20 @@ class ReduceCropStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
origin: Size, origin: Size,
size: Size, size: Size,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source outputs = []
image = source.crop((origin.width, origin.height, size.width, size.height)) for source in sources:
logger.info( image = source.crop((origin.width, origin.height, size.width, size.height))
"created thumbnail with dimensions: %sx%s", image.width, image.height logger.info(
) "created thumbnail with dimensions: %sx%s", image.width, image.height
return image )
outputs.append(image)
return outputs

View File

@ -1,4 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import List
from PIL import Image from PIL import Image
@ -17,18 +18,23 @@ class ReduceThumbnailStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
size: Size, size: Size,
stage_source: Image.Image, stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source outputs = []
image = source.copy()
image = image.thumbnail((size.width, size.height)) for source in sources:
image = source.copy()
logger.info( image = image.thumbnail((size.width, size.height))
"created thumbnail with dimensions: %sx%s", image.width, image.height
) logger.info(
return image "created thumbnail with dimensions: %sx%s", image.width, image.height
)
outputs.append(image)
return outputs

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Callable from typing import Callable, List
from PIL import Image from PIL import Image
@ -18,22 +18,25 @@ class SourceNoiseStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
size: Size, size: Size,
noise_source: Callable, noise_source: Callable,
stage_source: Image.Image, stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source
logger.info("generating image from noise source") logger.info("generating image from noise source")
if source is not None: if len(sources) > 0:
logger.warn( logger.warn(
"a source image was passed to a noise stage, but will be discarded" "source images were passed to a noise stage and will be discarded"
) )
output = noise_source(source, (size.width, size.height), (0, 0)) outputs = []
for source in sources:
output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output outputs.append(output)
return outputs

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from boto3 import Session from boto3 import Session
from PIL import Image from PIL import Image
@ -17,29 +17,30 @@ class SourceS3Stage(BaseStage):
def run( def run(
self, self,
_job: WorkerContext, _job: WorkerContext,
server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, _sources: List[Image.Image],
*, *,
source_key: str, source_keys: List[str],
bucket: str, bucket: str,
endpoint_url: Optional[str] = None, endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None, profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url) s3 = session.client("s3", endpoint_url=endpoint_url)
try: outputs = []
logger.info("loading image from s3://%s/%s", bucket, source_key) for key in source_keys:
data = BytesIO() try:
s3.download_fileobj(bucket, source_key, data) logger.info("loading image from s3://%s/%s", bucket, key)
data = BytesIO()
s3.download_fileobj(bucket, key, data)
data.seek(0) data.seek(0)
return Image.open(data) outputs.append(Image.open(data))
except Exception: except Exception:
logger.exception("error loading image from S3") logger.exception("error loading image from S3")
return outputs

View File

@ -1,5 +1,6 @@
from io import BytesIO from io import BytesIO
from logging import getLogger from logging import getLogger
from typing import List
import requests import requests
from PIL import Image from PIL import Image
@ -19,22 +20,25 @@ class SourceURLStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
source_url: str, source_urls: List[str],
stage_source: Image.Image, stage_source: Image.Image,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source
logger.info("loading image from URL source") logger.info("loading image from URL source")
if source is not None: if len(sources) > 0:
logger.warn( logger.warn(
"a source image was passed to a source stage, and will be discarded" "a source image was passed to a source stage, and will be discarded"
) )
response = requests.get(source_url) outputs = []
output = Image.open(BytesIO(response.content)) for url in source_urls:
response = requests.get(url)
output = Image.open(BytesIO(response.content))
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output outputs.append(output)
return outputs

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import List, Optional
from PIL import Image from PIL import Image
@ -16,11 +16,11 @@ class BaseStage:
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*args, *args,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
raise NotImplementedError() raise NotImplementedError()
def steps( def steps(

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -54,48 +54,52 @@ class UpscaleBSRGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None: if upscale.upscale_model is None:
logger.warn("no upscaling model given, skipping") logger.warn("no upscaling model given, skipping")
return source return sources
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model) logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device() device = job.get_device()
bsrgan = self.load(server, stage, upscale, device) bsrgan = self.load(server, stage, upscale, device)
image = np.array(source) / 255.0 outputs = []
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) for source in sources:
image = np.expand_dims(image, axis=0) image = np.array(source) / 255.0
logger.trace("BSRGAN input shape: %s", image.shape) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale scale = upscale.outscale
dest = np.zeros( dest = np.zeros(
( (
image.shape[0], image.shape[0],
image.shape[1], image.shape[1],
image.shape[2] * scale, image.shape[2] * scale,
image.shape[3] * scale, image.shape[3] * scale,
)
) )
) logger.trace("BSRGAN output shape: %s", dest.shape)
logger.trace("BSRGAN output shape: %s", dest.shape)
dest = bsrgan(image) dest = bsrgan(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8) dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB") output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height) logger.debug("output image size: %s x %s", output.width, output.height)
return output
outputs.append(output)
return outputs
def steps( def steps(
self, self,

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from PIL import Image from PIL import Image
@ -20,25 +20,26 @@ class UpscaleHighresStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*args, *args,
highres: HighresParams, highres: HighresParams,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source
if highres.scale <= 1: if highres.scale <= 1:
return source return sources
chain = stage_highres(stage, params, highres, upscale) chain = stage_highres(stage, params, highres, upscale)
return chain( return [
job, chain(
server, job,
params, server,
source, params,
callback=callback, source,
) callback=callback,
)
for source in sources
]

View File

@ -1,12 +1,12 @@
from logging import getLogger from logging import getLogger
from typing import Callable, Optional, Tuple from typing import Callable, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image, ImageDraw, ImageOps from PIL import Image, ImageDraw, ImageOps
from ..diffusers.load import load_pipeline from ..diffusers.load import load_pipeline
from ..diffusers.utils import get_latents_from_seed, get_tile_latents from ..diffusers.utils import get_latents_from_seed, get_tile_latents, parse_prompt
from ..image import expand_image, mask_filter_none, noise_source_histogram from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..params import Border, ImageParams, Size, SizeChart, StageParams
@ -26,7 +26,7 @@ class UpscaleOutpaintStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
border: Border, border: Border,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
@ -36,123 +36,130 @@ class UpscaleOutpaintStage(BaseStage):
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source _prompt_pairs, loras, inversions = parse_prompt(params)
logger.info(
"upscaling %s x %s image by expanding borders: %s", pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
source.width, pipe = load_pipeline(
source.height, server,
border, params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
) )
margin_x = float(max(border.left, border.right)) outputs = []
margin_y = float(max(border.top, border.bottom)) for source in sources:
overlap = min(margin_x / source.width, margin_y / source.height) logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
)
if stage_mask is None: margin_x = float(max(border.left, border.right))
# if no mask was provided, keep the full source image margin_y = float(max(border.top, border.bottom))
stage_mask = Image.new("RGB", source.size, "black") overlap = min(margin_x / source.width, margin_y / source.height)
# masks start as 512x512, resize to cover the source, then trim the extra if stage_mask is None:
mask_max = max(source.width, source.height) # if no mask was provided, keep the full source image
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max)) stage_mask = Image.new("RGB", source.size, "black")
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
source, stage_mask, noise, full_size = expand_image( # masks start as 512x512, resize to cover the source, then trim the extra
source, mask_max = max(source.width, source.height)
stage_mask, stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
border, stage_mask = stage_mask.crop((0, 0, source.width, source.height))
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
full_latents = get_latents_from_seed(params.seed, Size(*full_size)) source, stage_mask, noise, full_size = expand_image(
source,
stage_mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
draw_mask = ImageDraw.Draw(stage_mask) full_latents = get_latents_from_seed(params.seed, Size(*full_size))
if is_debug(): draw_mask = ImageDraw.Draw(stage_mask)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
tile_mask = complete_tile(tile_mask, tile)
if is_debug(): if is_debug():
save_image(server, "tile-source.png", tile_source) save_image(server, "last-source.png", source)
save_image(server, "tile-mask.png", tile_mask) save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
latents = get_tile_latents(full_latents, dims, size) def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline) left, top, tile = dims
pipe = load_pipeline( size = Size(*tile_source.size)
server, tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
params, tile_mask = complete_tile(tile_mask, tile)
pipe_type,
job.get_device(), if is_debug():
# TODO: load LoRAs and TIs save_image(server, "tile-source.png", tile_source)
) save_image(server, "tile-mask.png", tile_mask)
if params.lpw():
logger.debug("using LPW pipeline for inpaint") latents = get_tile_latents(full_latents, dims, size)
rng = torch.manual_seed(params.seed) if params.lpw():
result = pipe.inpaint( logger.debug("using LPW pipeline for inpaint")
tile_source, rng = torch.manual_seed(params.seed)
tile_mask, result = pipe.inpaint(
params.prompt, tile_source,
generator=rng, tile_mask,
guidance_scale=params.cfg, params.prompt,
height=size.height, generator=rng,
latents=latents, guidance_scale=params.cfg,
negative_prompt=params.negative_prompt, height=size.height,
num_inference_steps=params.steps, latents=latents,
width=size.width, negative_prompt=params.negative_prompt,
callback=callback, num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
tile_source,
tile_mask,
height=size.height,
width=size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
generator=rng,
latents=latents,
callback=callback,
)
# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
) )
else: else:
rng = np.random.RandomState(params.seed) logger.debug("outpainting with an uneven border, using grid tiling")
result = pipe( output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
params.prompt,
tile_source,
tile_mask,
height=size.height,
width=size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
generator=rng,
latents=latents,
callback=callback,
)
# once part of the image has been drawn, keep it logger.info("final output image size: %sx%s", output.width, output.height)
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black") outputs.append(output)
return result.images[0]
if params.pipeline == "panorama": return outputs
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -77,20 +77,25 @@ class UpscaleRealESRGANStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source) outputs = []
upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size) for source in sources:
output = np.array(source)
upsampler = self.load(
server, upscale, job.get_device(), tile=stage.tile_size
)
output, _ = upsampler.enhance(output, outscale=upscale.outscale) output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, "RGB") output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height) logger.info("final output image size: %sx%s", output.width, output.height)
return output outputs.append(output)
return outputs

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from PIL import Image from PIL import Image
@ -18,30 +18,32 @@ class UpscaleSimpleStage(BaseStage):
_server: ServerContext, _server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
method: str, method: str,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
source = stage_source or source
if upscale.scale <= 1: if upscale.scale <= 1:
logger.debug( logger.debug(
"simple upscale stage run with scale of %s, skipping", upscale.scale "simple upscale stage run with scale of %s, skipping", upscale.scale
) )
return source return sources
scaled_size = (source.width * upscale.scale, source.height * upscale.scale) outputs = []
for source in sources:
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
if method == "bilinear": if method == "bilinear":
logger.debug("using bilinear interpolation for highres") logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR) source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif method == "lanczos": elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres") logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else: else:
logger.warning("unknown upscaling method: %s", method) logger.warning("unknown upscaling method: %s", method)
return source outputs.append(source)
return outputs

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Optional from typing import List, Optional
import torch import torch
from PIL import Image from PIL import Image
@ -22,16 +22,15 @@ class UpscaleStableDiffusionStage(BaseStage):
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
source = stage_source or source
logger.info( logger.info(
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt "upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
) )
@ -55,14 +54,19 @@ class UpscaleStableDiffusionStage(BaseStage):
) )
pipeline.unet.set_prompts(prompt_embeds) pipeline.unet.set_prompts(prompt_embeds)
return pipeline( outputs = []
params.prompt, for source in sources:
source, result = pipeline(
generator=generator, params.prompt,
guidance_scale=params.cfg, source,
negative_prompt=params.negative_prompt, generator=generator,
num_inference_steps=params.steps, guidance_scale=params.cfg,
eta=params.eta, negative_prompt=params.negative_prompt,
noise_level=upscale.denoise, num_inference_steps=params.steps,
callback=callback, eta=params.eta,
).images[0] noise_level=upscale.denoise,
callback=callback,
)
outputs.extend(result.image)
return outputs

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -54,45 +54,48 @@ class UpscaleSwinIRStage(BaseStage):
server: ServerContext, server: ServerContext,
stage: StageParams, stage: StageParams,
_params: ImageParams, _params: ImageParams,
source: Image.Image, sources: List[Image.Image],
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> List[Image.Image]:
upscale = upscale.with_args(**kwargs) upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None: if upscale.upscale_model is None:
logger.warn("no correction model given, skipping") logger.warn("no correction model given, skipping")
return source return sources
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model) logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
device = job.get_device() device = job.get_device()
swinir = self.load(server, stage, upscale, device) swinir = self.load(server, stage, upscale, device)
# TODO: add support for grayscale (1-channel) images outputs = []
image = np.array(source) / 255.0 for source in sources:
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) # TODO: add support for grayscale (1-channel) images
image = np.expand_dims(image, axis=0) image = np.array(source) / 255.0
logger.trace("SwinIR input shape: %s", image.shape) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale scale = upscale.outscale
dest = np.zeros( dest = np.zeros(
( (
image.shape[0], image.shape[0],
image.shape[1], image.shape[1],
image.shape[2] * scale, image.shape[2] * scale,
image.shape[3] * scale, image.shape[3] * scale,
)
) )
) logger.trace("SwinIR output shape: %s", dest.shape)
logger.trace("SwinIR output shape: %s", dest.shape)
dest = swinir(image) dest = swinir(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8) dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB") output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height) logger.info("output image size: %s x %s", output.width, output.height)
return output outputs.append(output)
return outputs