feat(api): make pipeline stages support multiple images
This commit is contained in:
parent
f718087a54
commit
37185252a5
|
@ -90,10 +90,10 @@ class ChainPipeline:
|
|||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
source: Optional[Image.Image] = None,
|
||||
source: List[Image.Image],
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**pipeline_kwargs
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
"""
|
||||
DEPRECATED: use `run` instead
|
||||
"""
|
||||
|
@ -101,6 +101,8 @@ class ChainPipeline:
|
|||
callback = ChainProgress.from_progress(callback)
|
||||
|
||||
start = monotonic()
|
||||
|
||||
# TODO: turn this into stage images
|
||||
image = source
|
||||
|
||||
if source is not None:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -22,15 +22,14 @@ class BlendImg2ImgStage(BaseStage):
|
|||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
strength: float,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
params = params.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
logger.info(
|
||||
"blending image using img2img, %s steps: %s", params.steps, params.prompt
|
||||
)
|
||||
|
@ -59,39 +58,40 @@ class BlendImg2ImgStage(BaseStage):
|
|||
elif pipe_type == "pix2pix":
|
||||
pipe_params["image_guidance_scale"] = strength
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
source,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=callback,
|
||||
**pipe_params,
|
||||
)
|
||||
else:
|
||||
# encode and record alternative prompts outside of LPW
|
||||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
outputs = []
|
||||
for source in sources:
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
source,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=callback,
|
||||
**pipe_params,
|
||||
)
|
||||
else:
|
||||
# encode and record alternative prompts outside of LPW
|
||||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
)
|
||||
pipe.unet.set_prompts(prompt_embeds)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=callback,
|
||||
**pipe_params,
|
||||
)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=callback,
|
||||
**pipe_params,
|
||||
)
|
||||
|
||||
output = result.images[0]
|
||||
outputs.extend(result.images)
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
return outputs
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.utils import get_latents_from_seed
|
||||
from ..diffusers.utils import get_latents_from_seed, parse_prompt
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
|
@ -26,7 +26,7 @@ class BlendInpaintStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
expand: Border,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
|
@ -36,95 +36,97 @@ class BlendInpaintStage(BaseStage):
|
|||
noise_source: Callable = noise_source_histogram,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
params = params.with_args(**kwargs)
|
||||
expand = expand.with_args(**kwargs)
|
||||
source = source or stage_source
|
||||
logger.info(
|
||||
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
|
||||
)
|
||||
|
||||
if stage_mask is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
stage_mask = Image.new("RGB", source.size, "black")
|
||||
|
||||
source, stage_mask, noise, _full_dims = expand_image(
|
||||
source,
|
||||
stage_mask,
|
||||
expand,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "last-source.png", source)
|
||||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-noise.png", noise)
|
||||
|
||||
pipe_type = "lpw" if params.lpw() else "inpaint"
|
||||
_prompt_pairs, loras, inversions = parse_prompt(params)
|
||||
pipe_type = params.get_valid_pipeline("inpaint")
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
job.get_device(),
|
||||
# TODO: add LoRAs and TIs
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
|
||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
size = Size(*tile_source.size)
|
||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
||||
outputs = []
|
||||
for source in sources:
|
||||
if stage_mask is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
stage_mask = Image.new("RGB", source.size, "black")
|
||||
|
||||
source, stage_mask, noise, _full_dims = expand_image(
|
||||
source,
|
||||
stage_mask,
|
||||
expand,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "tile-source.png", tile_source)
|
||||
save_image(server, "tile-mask.png", tile_mask)
|
||||
save_image(server, "last-source.png", source)
|
||||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-noise.png", noise)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=tile_source,
|
||||
latents=latents,
|
||||
mask_image=tile_mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
eta=params.eta,
|
||||
callback=callback,
|
||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
size = Size(*tile_source.size)
|
||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "tile-source.png", tile_source)
|
||||
save_image(server, "tile-mask.png", tile_mask)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=tile_source,
|
||||
latents=latents,
|
||||
mask_image=tile_mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
eta=params.eta,
|
||||
callback=callback,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=tile_source,
|
||||
latents=latents,
|
||||
mask_image=stage_mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
eta=params.eta,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
return result.images[0]
|
||||
|
||||
outputs.append(
|
||||
process_tile_order(
|
||||
stage.tile_order,
|
||||
source,
|
||||
SizeChart.auto,
|
||||
1,
|
||||
[outpaint],
|
||||
overlap=params.overlap,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=tile_source,
|
||||
latents=latents,
|
||||
mask_image=stage_mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
eta=params.eta,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
return result.images[0]
|
||||
|
||||
output = process_tile_order(
|
||||
stage.tile_order,
|
||||
source,
|
||||
SizeChart.auto,
|
||||
1,
|
||||
[outpaint],
|
||||
overlap=params.overlap,
|
||||
)
|
||||
|
||||
logger.info("final output image size: %s", output.size)
|
||||
return output
|
||||
)
|
||||
|
|
|
@ -18,12 +18,13 @@ class BlendLinearStage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
alpha: float,
|
||||
sources: Optional[List[Image.Image]] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
_callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info("blending image using linear interpolation")
|
||||
) -> List[Image.Image]:
|
||||
logger.info("blending source images using linear interpolation")
|
||||
|
||||
return Image.blend(sources[1], sources[0], alpha)
|
||||
return [Image.blend(source, stage_source, alpha) for source in sources]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -20,13 +20,13 @@ class BlendMaskStage(BaseStage):
|
|||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
stage_mask: Optional[Image.Image] = None,
|
||||
_callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
logger.info("blending image using mask")
|
||||
|
||||
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
|
||||
|
@ -37,4 +37,4 @@ class BlendMaskStage(BaseStage):
|
|||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-mult-mask.png", mult_mask)
|
||||
|
||||
return Image.composite(stage_source, source, mult_mask)
|
||||
return [Image.composite(stage_source, source, mult_mask) for source in sources]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -18,20 +18,18 @@ class CorrectCodeformerStage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
# must be within the load function for patch to take effect
|
||||
# TODO: rewrite and remove
|
||||
from codeformer import CodeFormer
|
||||
|
||||
source = stage_source or source
|
||||
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
|
||||
device = job.get_device()
|
||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||
return pipe(source)
|
||||
return [pipe(source) for source in sources]
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -57,31 +57,32 @@ class CorrectGFPGANStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
|
||||
if upscale.correction_model is None:
|
||||
logger.warn("no face model given, skipping")
|
||||
return source
|
||||
return sources
|
||||
|
||||
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||
device = job.get_device()
|
||||
gfpgan = self.load(server, stage, upscale, device)
|
||||
|
||||
output = np.array(source)
|
||||
_, _, output = gfpgan.enhance(
|
||||
output,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
weight=upscale.face_strength,
|
||||
)
|
||||
output = Image.fromarray(output, "RGB")
|
||||
outputs = []
|
||||
for source in sources:
|
||||
output = np.array(source)
|
||||
_, _, output = gfpgan.enhance(
|
||||
output,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
weight=upscale.face_strength,
|
||||
)
|
||||
outputs.append(Image.fromarray(output, "RGB"))
|
||||
|
||||
return output
|
||||
return outputs
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -18,14 +19,15 @@ class PersistDiskStage(BaseStage):
|
|||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
output: str,
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
) -> List[Image.Image]:
|
||||
for source in sources:
|
||||
# TODO: append index to output name
|
||||
dest = save_image(server, output, source, params=params)
|
||||
logger.info("saved image to %s", dest)
|
||||
|
||||
dest = save_image(server, output, source, params=params)
|
||||
logger.info("saved image to %s", dest)
|
||||
return source
|
||||
return sources
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from boto3 import Session
|
||||
from PIL import Image
|
||||
|
@ -20,7 +20,7 @@ class PersistS3Stage(BaseStage):
|
|||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
output: str,
|
||||
bucket: str,
|
||||
|
@ -28,20 +28,19 @@ class PersistS3Stage(BaseStage):
|
|||
profile_name: Optional[str] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
||||
) -> List[Image.Image]:
|
||||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||
|
||||
data = BytesIO()
|
||||
source.save(data, format=server.image_format)
|
||||
data.seek(0)
|
||||
for source in sources:
|
||||
data = BytesIO()
|
||||
source.save(data, format=server.image_format)
|
||||
data.seek(0)
|
||||
|
||||
try:
|
||||
s3.upload_fileobj(data, bucket, output)
|
||||
logger.info("saved image to s3://%s/%s", bucket, output)
|
||||
except Exception:
|
||||
logger.exception("error saving image to S3")
|
||||
try:
|
||||
s3.upload_fileobj(data, bucket, output)
|
||||
logger.info("saved image to s3://%s/%s", bucket, output)
|
||||
except Exception:
|
||||
logger.exception("error saving image to S3")
|
||||
|
||||
return source
|
||||
return sources
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -18,17 +18,20 @@ class ReduceCropStage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
origin: Size,
|
||||
size: Size,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
) -> List[Image.Image]:
|
||||
outputs = []
|
||||
|
||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||
logger.info(
|
||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||
)
|
||||
return image
|
||||
for source in sources:
|
||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||
logger.info(
|
||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||
)
|
||||
outputs.append(image)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -17,18 +18,23 @@ class ReduceThumbnailStage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
size: Size,
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
image = source.copy()
|
||||
) -> List[Image.Image]:
|
||||
outputs = []
|
||||
|
||||
image = image.thumbnail((size.width, size.height))
|
||||
for source in sources:
|
||||
image = source.copy()
|
||||
|
||||
logger.info(
|
||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||
)
|
||||
return image
|
||||
image = image.thumbnail((size.width, size.height))
|
||||
|
||||
logger.info(
|
||||
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||
)
|
||||
|
||||
outputs.append(image)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable
|
||||
from typing import Callable, List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -18,22 +18,25 @@ class SourceNoiseStage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
size: Size,
|
||||
noise_source: Callable,
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
) -> List[Image.Image]:
|
||||
logger.info("generating image from noise source")
|
||||
|
||||
if source is not None:
|
||||
if len(sources) > 0:
|
||||
logger.warn(
|
||||
"a source image was passed to a noise stage, but will be discarded"
|
||||
"source images were passed to a noise stage and will be discarded"
|
||||
)
|
||||
|
||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||
outputs = []
|
||||
for source in sources:
|
||||
output = noise_source(source, (size.width, size.height), (0, 0))
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from boto3 import Session
|
||||
from PIL import Image
|
||||
|
@ -17,29 +17,30 @@ class SourceS3Stage(BaseStage):
|
|||
def run(
|
||||
self,
|
||||
_job: WorkerContext,
|
||||
server: ServerContext,
|
||||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
_sources: List[Image.Image],
|
||||
*,
|
||||
source_key: str,
|
||||
source_keys: List[str],
|
||||
bucket: str,
|
||||
endpoint_url: Optional[str] = None,
|
||||
profile_name: Optional[str] = None,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
||||
) -> List[Image.Image]:
|
||||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||
|
||||
try:
|
||||
logger.info("loading image from s3://%s/%s", bucket, source_key)
|
||||
data = BytesIO()
|
||||
s3.download_fileobj(bucket, source_key, data)
|
||||
outputs = []
|
||||
for key in source_keys:
|
||||
try:
|
||||
logger.info("loading image from s3://%s/%s", bucket, key)
|
||||
data = BytesIO()
|
||||
s3.download_fileobj(bucket, key, data)
|
||||
|
||||
data.seek(0)
|
||||
return Image.open(data)
|
||||
except Exception:
|
||||
logger.exception("error loading image from S3")
|
||||
data.seek(0)
|
||||
outputs.append(Image.open(data))
|
||||
except Exception:
|
||||
logger.exception("error loading image from S3")
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
@ -19,22 +20,25 @@ class SourceURLStage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
source_url: str,
|
||||
source_urls: List[str],
|
||||
stage_source: Image.Image,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
) -> List[Image.Image]:
|
||||
logger.info("loading image from URL source")
|
||||
|
||||
if source is not None:
|
||||
if len(sources) > 0:
|
||||
logger.warn(
|
||||
"a source image was passed to a source stage, and will be discarded"
|
||||
)
|
||||
|
||||
response = requests.get(source_url)
|
||||
output = Image.open(BytesIO(response.content))
|
||||
outputs = []
|
||||
for url in source_urls:
|
||||
response = requests.get(url)
|
||||
output = Image.open(BytesIO(response.content))
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -16,11 +16,11 @@ class BaseStage:
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*args,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def steps(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -54,48 +54,52 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
|
||||
if upscale.upscale_model is None:
|
||||
logger.warn("no upscaling model given, skipping")
|
||||
return source
|
||||
return sources
|
||||
|
||||
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
|
||||
device = job.get_device()
|
||||
bsrgan = self.load(server, stage, upscale, device)
|
||||
|
||||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||
outputs = []
|
||||
for source in sources:
|
||||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
)
|
||||
)
|
||||
)
|
||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
||||
|
||||
dest = bsrgan(image)
|
||||
dest = bsrgan(image)
|
||||
|
||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||
dest = (dest * 255.0).round().astype(np.uint8)
|
||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||
dest = (dest * 255.0).round().astype(np.uint8)
|
||||
|
||||
output = Image.fromarray(dest, "RGB")
|
||||
logger.debug("output image size: %s x %s", output.width, output.height)
|
||||
return output
|
||||
output = Image.fromarray(dest, "RGB")
|
||||
logger.debug("output image size: %s x %s", output.width, output.height)
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
||||
def steps(
|
||||
self,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -20,25 +20,26 @@ class UpscaleHighresStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*args,
|
||||
highres: HighresParams,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
||||
) -> List[Image.Image]:
|
||||
if highres.scale <= 1:
|
||||
return source
|
||||
return sources
|
||||
|
||||
chain = stage_highres(stage, params, highres, upscale)
|
||||
|
||||
return chain(
|
||||
job,
|
||||
server,
|
||||
params,
|
||||
source,
|
||||
callback=callback,
|
||||
)
|
||||
return [
|
||||
chain(
|
||||
job,
|
||||
server,
|
||||
params,
|
||||
source,
|
||||
callback=callback,
|
||||
)
|
||||
for source in sources
|
||||
]
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageOps
|
||||
|
||||
from ..diffusers.load import load_pipeline
|
||||
from ..diffusers.utils import get_latents_from_seed, get_tile_latents
|
||||
from ..diffusers.utils import get_latents_from_seed, get_tile_latents, parse_prompt
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
|
@ -26,7 +26,7 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
border: Border,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
|
@ -36,123 +36,130 @@ class UpscaleOutpaintStage(BaseStage):
|
|||
noise_source: Callable = noise_source_histogram,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
logger.info(
|
||||
"upscaling %s x %s image by expanding borders: %s",
|
||||
source.width,
|
||||
source.height,
|
||||
border,
|
||||
) -> List[Image.Image]:
|
||||
_prompt_pairs, loras, inversions = parse_prompt(params)
|
||||
|
||||
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
job.get_device(),
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
|
||||
margin_x = float(max(border.left, border.right))
|
||||
margin_y = float(max(border.top, border.bottom))
|
||||
overlap = min(margin_x / source.width, margin_y / source.height)
|
||||
outputs = []
|
||||
for source in sources:
|
||||
logger.info(
|
||||
"upscaling %s x %s image by expanding borders: %s",
|
||||
source.width,
|
||||
source.height,
|
||||
border,
|
||||
)
|
||||
|
||||
if stage_mask is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
stage_mask = Image.new("RGB", source.size, "black")
|
||||
margin_x = float(max(border.left, border.right))
|
||||
margin_y = float(max(border.top, border.bottom))
|
||||
overlap = min(margin_x / source.width, margin_y / source.height)
|
||||
|
||||
# masks start as 512x512, resize to cover the source, then trim the extra
|
||||
mask_max = max(source.width, source.height)
|
||||
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
|
||||
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
|
||||
if stage_mask is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
stage_mask = Image.new("RGB", source.size, "black")
|
||||
|
||||
source, stage_mask, noise, full_size = expand_image(
|
||||
source,
|
||||
stage_mask,
|
||||
border,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
# masks start as 512x512, resize to cover the source, then trim the extra
|
||||
mask_max = max(source.width, source.height)
|
||||
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
|
||||
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
|
||||
|
||||
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
|
||||
source, stage_mask, noise, full_size = expand_image(
|
||||
source,
|
||||
stage_mask,
|
||||
border,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter,
|
||||
)
|
||||
|
||||
draw_mask = ImageDraw.Draw(stage_mask)
|
||||
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "last-source.png", source)
|
||||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-noise.png", noise)
|
||||
|
||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
size = Size(*tile_source.size)
|
||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
||||
tile_mask = complete_tile(tile_mask, tile)
|
||||
draw_mask = ImageDraw.Draw(stage_mask)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "tile-source.png", tile_source)
|
||||
save_image(server, "tile-mask.png", tile_mask)
|
||||
save_image(server, "last-source.png", source)
|
||||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-noise.png", noise)
|
||||
|
||||
latents = get_tile_latents(full_latents, dims, size)
|
||||
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params,
|
||||
pipe_type,
|
||||
job.get_device(),
|
||||
# TODO: load LoRAs and TIs
|
||||
)
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
tile_source,
|
||||
tile_mask,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
callback=callback,
|
||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
size = Size(*tile_source.size)
|
||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
||||
tile_mask = complete_tile(tile_mask, tile)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "tile-source.png", tile_source)
|
||||
save_image(server, "tile-mask.png", tile_mask)
|
||||
|
||||
latents = get_tile_latents(full_latents, dims, size)
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for inpaint")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.inpaint(
|
||||
tile_source,
|
||||
tile_mask,
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
callback=callback,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
tile_source,
|
||||
tile_mask,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
num_inference_steps=params.steps,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
generator=rng,
|
||||
latents=latents,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# once part of the image has been drawn, keep it
|
||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
||||
return result.images[0]
|
||||
|
||||
if params.pipeline == "panorama":
|
||||
logger.debug("outpainting with one shot panorama, no tiling")
|
||||
return outpaint(source, (0, 0, max(source.width, source.height)))
|
||||
if overlap == 0:
|
||||
logger.debug("outpainting with 0 margin, using grid tiling")
|
||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||
elif border.left == border.right and border.top == border.bottom:
|
||||
logger.debug(
|
||||
"outpainting with an even border, using spiral tiling with %s overlap",
|
||||
overlap,
|
||||
)
|
||||
output = process_tile_order(
|
||||
stage.tile_order,
|
||||
source,
|
||||
SizeChart.auto,
|
||||
1,
|
||||
[outpaint],
|
||||
overlap=overlap,
|
||||
)
|
||||
else:
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
tile_source,
|
||||
tile_mask,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
num_inference_steps=params.steps,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
generator=rng,
|
||||
latents=latents,
|
||||
callback=callback,
|
||||
)
|
||||
logger.debug("outpainting with an uneven border, using grid tiling")
|
||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
# once part of the image has been drawn, keep it
|
||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
||||
return result.images[0]
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
|
||||
if params.pipeline == "panorama":
|
||||
logger.debug("outpainting with one shot panorama, no tiling")
|
||||
return outpaint(source, (0, 0, max(source.width, source.height)))
|
||||
if overlap == 0:
|
||||
logger.debug("outpainting with 0 margin, using grid tiling")
|
||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||
elif border.left == border.right and border.top == border.bottom:
|
||||
logger.debug(
|
||||
"outpainting with an even border, using spiral tiling with %s overlap",
|
||||
overlap,
|
||||
)
|
||||
output = process_tile_order(
|
||||
stage.tile_order,
|
||||
source,
|
||||
SizeChart.auto,
|
||||
1,
|
||||
[outpaint],
|
||||
overlap=overlap,
|
||||
)
|
||||
else:
|
||||
logger.debug("outpainting with an uneven border, using grid tiling")
|
||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
return outputs
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -77,20 +77,25 @@ class UpscaleRealESRGANStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
) -> List[Image.Image]:
|
||||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||
|
||||
output = np.array(source)
|
||||
upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size)
|
||||
outputs = []
|
||||
for source in sources:
|
||||
output = np.array(source)
|
||||
upsampler = self.load(
|
||||
server, upscale, job.get_device(), tile=stage.tile_size
|
||||
)
|
||||
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
|
||||
output = Image.fromarray(output, "RGB")
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
return output
|
||||
output = Image.fromarray(output, "RGB")
|
||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -18,30 +18,32 @@ class UpscaleSimpleStage(BaseStage):
|
|||
_server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
method: str,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
source = stage_source or source
|
||||
|
||||
) -> List[Image.Image]:
|
||||
if upscale.scale <= 1:
|
||||
logger.debug(
|
||||
"simple upscale stage run with scale of %s, skipping", upscale.scale
|
||||
)
|
||||
return source
|
||||
return sources
|
||||
|
||||
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
||||
outputs = []
|
||||
for source in sources:
|
||||
scaled_size = (source.width * upscale.scale, source.height * upscale.scale)
|
||||
|
||||
if method == "bilinear":
|
||||
logger.debug("using bilinear interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
||||
elif method == "lanczos":
|
||||
logger.debug("using Lanczos interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
logger.warning("unknown upscaling method: %s", method)
|
||||
if method == "bilinear":
|
||||
logger.debug("using bilinear interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
||||
elif method == "lanczos":
|
||||
logger.debug("using Lanczos interpolation for highres")
|
||||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
logger.warning("unknown upscaling method: %s", method)
|
||||
|
||||
return source
|
||||
outputs.append(source)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
@ -22,16 +22,15 @@ class UpscaleStableDiffusionStage(BaseStage):
|
|||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
params = params.with_args(**kwargs)
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
logger.info(
|
||||
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
|
||||
)
|
||||
|
@ -55,14 +54,19 @@ class UpscaleStableDiffusionStage(BaseStage):
|
|||
)
|
||||
pipeline.unet.set_prompts(prompt_embeds)
|
||||
|
||||
return pipeline(
|
||||
params.prompt,
|
||||
source,
|
||||
generator=generator,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
noise_level=upscale.denoise,
|
||||
callback=callback,
|
||||
).images[0]
|
||||
outputs = []
|
||||
for source in sources:
|
||||
result = pipeline(
|
||||
params.prompt,
|
||||
source,
|
||||
generator=generator,
|
||||
guidance_scale=params.cfg,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
eta=params.eta,
|
||||
noise_level=upscale.denoise,
|
||||
callback=callback,
|
||||
)
|
||||
outputs.extend(result.image)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -54,45 +54,48 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source: Image.Image,
|
||||
sources: List[Image.Image],
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
) -> List[Image.Image]:
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
source = stage_source or source
|
||||
|
||||
if upscale.upscale_model is None:
|
||||
logger.warn("no correction model given, skipping")
|
||||
return source
|
||||
return sources
|
||||
|
||||
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
||||
device = job.get_device()
|
||||
swinir = self.load(server, stage, upscale, device)
|
||||
|
||||
# TODO: add support for grayscale (1-channel) images
|
||||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
logger.trace("SwinIR input shape: %s", image.shape)
|
||||
outputs = []
|
||||
for source in sources:
|
||||
# TODO: add support for grayscale (1-channel) images
|
||||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
logger.trace("SwinIR input shape: %s", image.shape)
|
||||
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
scale = upscale.outscale
|
||||
dest = np.zeros(
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
)
|
||||
)
|
||||
)
|
||||
logger.trace("SwinIR output shape: %s", dest.shape)
|
||||
logger.trace("SwinIR output shape: %s", dest.shape)
|
||||
|
||||
dest = swinir(image)
|
||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||
dest = (dest * 255.0).round().astype(np.uint8)
|
||||
dest = swinir(image)
|
||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||
dest = (dest * 255.0).round().astype(np.uint8)
|
||||
|
||||
output = Image.fromarray(dest, "RGB")
|
||||
logger.info("output image size: %s x %s", output.width, output.height)
|
||||
return output
|
||||
output = Image.fromarray(dest, "RGB")
|
||||
logger.info("output image size: %s x %s", output.width, output.height)
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
|
Loading…
Reference in New Issue