1
0
Fork 0

feat(api): make chain stages into classes with max tile size and step count estimate

This commit is contained in:
Sean Sube 2023-07-01 07:10:53 -05:00
parent 5e1b70091c
commit 2913cd0382
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
29 changed files with 1121 additions and 1019 deletions

View File

@ -17,7 +17,7 @@ from .diffusers.run import (
run_upscale_pipeline,
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import stage_upscale_correction
from .chain.upscale import stage_upscale_correction
from .image.utils import (
expand_image,
)

View File

@ -1,44 +1,44 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .blend_linear import blend_linear
from .blend_mask import blend_mask
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk
from .persist_s3 import persist_s3
from .reduce_crop import reduce_crop
from .reduce_thumbnail import reduce_thumbnail
from .source_noise import source_noise
from .source_s3 import source_s3
from .source_txt2img import source_txt2img
from .source_url import source_url
from .upscale_bsrgan import upscale_bsrgan
from .upscale_highres import upscale_highres
from .upscale_outpaint import upscale_outpaint
from .upscale_resrgan import upscale_resrgan
from .upscale_stable_diffusion import upscale_stable_diffusion
from .upscale_swinir import upscale_swinir
from .blend_img2img import BlendImg2ImgStage
from .blend_inpaint import BlendInpaintStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .persist_disk import PersistDiskStage
from .persist_s3 import PersistS3Stage
from .reduce_crop import ReduceCropStage
from .reduce_thumbnail import ReduceThumbnailStage
from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
CHAIN_STAGES = {
"blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint,
"blend-linear": blend_linear,
"blend-mask": blend_mask,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk,
"persist-s3": persist_s3,
"reduce-crop": reduce_crop,
"reduce-thumbnail": reduce_thumbnail,
"source-noise": source_noise,
"source-s3": source_s3,
"source-txt2img": source_txt2img,
"source-url": source_url,
"upscale-bsrgan": upscale_bsrgan,
"upscale-highres": upscale_highres,
"upscale-outpaint": upscale_outpaint,
"upscale-resrgan": upscale_resrgan,
"upscale-stable-diffusion": upscale_stable_diffusion,
"upscale-swinir": upscale_swinir,
"blend-img2img": BlendImg2ImgStage,
"blend-inpaint": BlendInpaintStage,
"blend-linear": BlendLinearStage,
"blend-mask": BlendMaskStage,
"correct-codeformer": CorrectCodeformerStage,
"correct-gfpgan": CorrectGFPGANStage,
"persist-disk": PersistDiskStage,
"persist-s3": PersistS3Stage,
"reduce-crop": ReduceCropStage,
"reduce-thumbnail": ReduceThumbnailStage,
"source-noise": SourceNoiseStage,
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,
"upscale-resrgan": UpscaleRealESRGANStage,
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
"upscale-swinir": UpscaleSwinIRStage,
}

View File

@ -10,6 +10,7 @@ from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .utils import process_tile_order
logger = getLogger(__name__)
@ -35,7 +36,7 @@ class StageCallback(Protocol):
pass
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
@ -131,7 +132,7 @@ class ChainPipeline:
logger.info("running pipeline without source image")
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.name or stage_pipe.__name__
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
@ -158,7 +159,7 @@ class ChainPipeline:
)
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(
tile = stage_pipe.run(
job,
server,
stage_params,
@ -182,7 +183,7 @@ class ChainPipeline:
)
else:
logger.debug("image within tile size, running stage")
image = stage_pipe(
image = stage_pipe.run(
job,
server,
stage_params,

View File

@ -14,77 +14,81 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_img2img(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)
prompt_pairs, loras, inversions = parse_prompt(params)
pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
pipe_params = {}
if pipe_type == "controlnet":
pipe_params["controlnet_conditioning_scale"] = strength
elif pipe_type == "img2img":
pipe_params["strength"] = strength
elif pipe_type == "panorama":
pipe_params["strength"] = strength
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
class BlendImg2ImgStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
strength: float,
callback: Optional[ProgressCallback] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
source = stage_source or source
logger.info(
"blending image using img2img, %s steps: %s", params.steps, params.prompt
)
output = result.images[0]
prompt_pairs, loras, inversions = parse_prompt(params)
logger.info("final output image size: %sx%s", output.width, output.height)
return output
pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
pipe_params = {}
if pipe_type == "controlnet":
pipe_params["controlnet_conditioning_scale"] = strength
elif pipe_type == "img2img":
pipe_params["strength"] = strength
elif pipe_type == "panorama":
pipe_params["strength"] = strength
elif pipe_type == "pix2pix":
pipe_params["image_guidance_scale"] = strength
if params.lpw():
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
**pipe_params,
)
output = result.images[0]
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -18,105 +18,112 @@ from .utils import process_tile_order
logger = getLogger(__name__)
def blend_inpaint(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
expand: Border,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
expand = expand.with_args(**kwargs)
source = source or stage_source
logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)
class BlendInpaintStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
expand: Border,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
expand = expand.with_args(**kwargs)
source = source or stage_source
logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
pipe_type = "lpw" if params.lpw() else "inpaint"
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: add LoRAs and TIs
)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
source, stage_mask, noise, _full_dims = expand_image(
source,
stage_mask,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
latents = get_latents_from_seed(params.seed, size)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
pipe_type = "lpw" if params.lpw() else "inpaint"
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: add LoRAs and TIs
)
return result.images[0]
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
output = process_tile_order(
stage.tile_order, source, SizeChart.auto, 1, [outpaint], overlap=params.overlap
)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
logger.info("final output image size: %s", output.size)
return output
latents = get_latents_from_seed(params.seed, size)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=tile_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=tile_source,
latents=latents,
mask_image=stage_mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
eta=params.eta,
callback=callback,
)
return result.images[0]
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=params.overlap,
)
logger.info("final output image size: %s", output.size)
return output

View File

@ -10,17 +10,19 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_linear(
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
*,
alpha: float,
sources: Optional[List[Image.Image]] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using linear interpolation")
class BlendLinearStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
*,
alpha: float,
sources: Optional[List[Image.Image]] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using linear interpolation")
return Image.blend(sources[1], sources[0], alpha)
return Image.blend(sources[1], sources[0], alpha)

View File

@ -12,26 +12,28 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def blend_mask(
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using mask")
class BlendMaskStage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
_callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using mask")
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L")
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L")
if is_debug():
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)
if is_debug():
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)
return Image.composite(stage_source, source, mult_mask)
return Image.composite(stage_source, source, mult_mask)

View File

@ -9,28 +9,28 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
device = "cpu"
class CorrectCodeformerStage:
def run(
self,
job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from codeformer import CodeFormer
def correct_codeformer(
job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
stage_source: Optional[Image.Image] = None,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from codeformer import CodeFormer
source = stage_source or source
source = stage_source or source
upscale = upscale.with_args(**kwargs)
upscale = upscale.with_args(**kwargs)
device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(source)
device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(source)

View File

@ -13,72 +13,74 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def load_gfpgan(
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from gfpgan import GFPGANer
class CorrectGFPGANStage:
def load(
self,
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
# TODO: rewrite and remove
from gfpgan import GFPGANer
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
cache_key = (face_path,)
cache_pipe = server.cache.get("gfpgan", cache_key)
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
cache_key = (face_path,)
cache_pipe = server.cache.get("gfpgan", cache_key)
if cache_pipe is not None:
logger.info("reusing existing GFPGAN pipeline")
return cache_pipe
if cache_pipe is not None:
logger.info("reusing existing GFPGAN pipeline")
return cache_pipe
logger.debug("loading GFPGAN model from %s", face_path)
logger.debug("loading GFPGAN model from %s", face_path)
# TODO: find a way to pass the ONNX model to underlying architectures
gfpgan = GFPGANer(
arch="clean",
bg_upsampler=None,
channel_multiplier=2,
device=device.torch_str(),
model_path=face_path,
upscale=upscale.face_outscale,
)
# TODO: find a way to pass the ONNX model to underlying architectures
gfpgan = GFPGANer(
arch="clean",
bg_upsampler=None,
channel_multiplier=2,
device=device.torch_str(),
model_path=face_path,
upscale=upscale.face_outscale,
)
server.cache.set("gfpgan", cache_key, gfpgan)
run_gc([device])
server.cache.set("gfpgan", cache_key, gfpgan)
run_gc([device])
return gfpgan
return gfpgan
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
def correct_gfpgan(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.correction_model is None:
logger.warn("no face model given, skipping")
return source
if upscale.correction_model is None:
logger.warn("no face model given, skipping")
return source
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
gfpgan = self.load(server, stage, upscale, device)
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
gfpgan = load_gfpgan(server, stage, upscale, device)
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
)
output = Image.fromarray(output, "RGB")
output = np.array(source)
_, _, output = gfpgan.enhance(
output,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
)
output = Image.fromarray(output, "RGB")
return output
return output

View File

@ -10,19 +10,21 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def persist_disk(
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
output: str,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
class PersistDiskStage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
output: str,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
dest = save_image(server, output, source, params=params)
logger.info("saved image to %s", dest)
return source
dest = save_image(server, output, source, params=params)
logger.info("saved image to %s", dest)
return source

View File

@ -12,33 +12,35 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def persist_s3(
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
output: str,
bucket: str,
endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
class PersistS3Stage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
output: str,
bucket: str,
endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
data = BytesIO()
source.save(data, format=server.image_format)
data.seek(0)
data = BytesIO()
source.save(data, format=server.image_format)
data.seek(0)
try:
s3.upload_fileobj(data, bucket, output)
logger.info("saved image to s3://%s/%s", bucket, output)
except Exception:
logger.exception("error saving image to S3")
try:
s3.upload_fileobj(data, bucket, output)
logger.info("saved image to s3://%s/%s", bucket, output)
except Exception:
logger.exception("error saving image to S3")
return source
return source

View File

@ -10,20 +10,24 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def reduce_crop(
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
origin: Size,
size: Size,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
class ReduceCropStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
origin: Size,
size: Size,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image
image = source.crop((origin.width, origin.height, size.width, size.height))
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
return image

View File

@ -9,21 +9,25 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def reduce_thumbnail(
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
size: Size,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
image = source.copy()
class ReduceThumbnailStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
size: Size,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
image = source.copy()
image = image.thumbnail((size.width, size.height))
image = image.thumbnail((size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image
logger.info(
"created thumbnail with dimensions: %sx%s", image.width, image.height
)
return image

View File

@ -10,25 +10,29 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def source_noise(
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
size: Size,
noise_source: Callable,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info("generating image from noise source")
class SourceNoiseStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
size: Size,
noise_source: Callable,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info("generating image from noise source")
if source is not None:
logger.warn("a source image was passed to a noise stage, but will be discarded")
if source is not None:
logger.warn(
"a source image was passed to a noise stage, but will be discarded"
)
output = noise_source(source, (size.width, size.height), (0, 0))
output = noise_source(source, (size.width, size.height), (0, 0))
logger.info("final output image size: %sx%s", output.width, output.height)
return output
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -12,31 +12,33 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def source_s3(
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
source_key: str,
bucket: str,
endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
class SourceS3Stage:
def run(
self,
_job: WorkerContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
source_key: str,
bucket: str,
endpoint_url: Optional[str] = None,
profile_name: Optional[str] = None,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
session = Session(profile_name=profile_name)
s3 = session.client("s3", endpoint_url=endpoint_url)
try:
logger.info("loading image from s3://%s/%s", bucket, source_key)
data = BytesIO()
s3.download_fileobj(bucket, source_key, data)
try:
logger.info("loading image from s3://%s/%s", bucket, source_key)
data = BytesIO()
s3.download_fileobj(bucket, source_key, data)
data.seek(0)
return Image.open(data)
except Exception:
logger.exception("error loading image from S3")
data.seek(0)
return Image.open(data)
except Exception:
logger.exception("error loading image from S3")

View File

@ -14,74 +14,78 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def source_txt2img(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
_source: Image.Image,
*,
size: Size,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
size = size.with_args(**kwargs)
logger.info(
"generating image using txt2img, %s steps: %s", params.steps, params.prompt
)
if "stage_source" in kwargs:
logger.warn(
"a source image was passed to a txt2img stage, and will be discarded"
class SourceTxt2ImgStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
_source: Image.Image,
*,
size: Size,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
size = size.with_args(**kwargs)
logger.info(
"generating image using txt2img, %s steps: %s", params.steps, params.prompt
)
prompt_pairs, loras, inversions = parse_prompt(params)
if "stage_source" in kwargs:
logger.warn(
"a source image was passed to a txt2img stage, and will be discarded"
)
latents = get_latents_from_seed(params.seed, size)
pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
prompt_pairs, loras, inversions = parse_prompt(params)
if params.lpw():
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
params.prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
pipe.unet.set_prompts(prompt_embeds)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
latents = get_latents_from_seed(params.seed, size)
pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
inversions=inversions,
loras=loras,
)
output = result.images[0]
if params.lpw():
logger.debug("using LPW pipeline for txt2img")
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
params.prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
)
else:
# encode and record alternative prompts outside of LPW
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()
)
pipe.unet.set_prompts(prompt_embeds)
logger.info("final output image size: %sx%s", output.width, output.height)
return output
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
height=size.height,
width=size.width,
generator=rng,
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
callback=callback,
)
output = result.images[0]
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -11,27 +11,29 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def source_url(
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
source_url: str,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info("loading image from URL source")
class SourceURLStage:
def run(
self,
_job: WorkerContext,
_server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
source_url: str,
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info("loading image from URL source")
if source is not None:
logger.warn(
"a source image was passed to a source stage, and will be discarded"
)
if source is not None:
logger.warn(
"a source image was passed to a source stage, and will be discarded"
)
response = requests.get(source_url)
output = Image.open(BytesIO(response.content))
response = requests.get(source_url)
output = Image.open(BytesIO(response.content))
logger.info("final output image size: %sx%s", output.width, output.height)
return output
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -0,0 +1,31 @@
from typing import Optional
from PIL import Image
from onnx_web.params import ImageParams, Size, SizeChart, StageParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
class BaseStage:
max_tile = SizeChart.auto
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*args,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
raise NotImplementedError()
def steps(
self,
_params: ImageParams,
size: Size,
) -> int:
raise NotImplementedError()

View File

@ -1,14 +1,14 @@
from logging import getLogger
from typing import List, Optional, Tuple
from ..chain import ChainPipeline, PipelineStage
from ..chain.correct_codeformer import correct_codeformer
from ..chain.correct_gfpgan import correct_gfpgan
from ..chain.upscale_bsrgan import upscale_bsrgan
from ..chain.upscale_resrgan import upscale_resrgan
from ..chain.upscale_stable_diffusion import upscale_stable_diffusion
from ..chain.upscale_swinir import upscale_swinir
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
from . import ChainPipeline, PipelineStage
from .correct_codeformer import CorrectCodeformerStage
from .correct_gfpgan import CorrectGFPGANStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_resrgan import UpscaleRealESRGANStage
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
from .upscale_swinir import UpscaleSwinIRStage
logger = getLogger(__name__)
@ -72,23 +72,23 @@ def stage_upscale_correction(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts)
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
elif "esrgan" in upscale.upscale_model:
esrgan_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts)
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size)
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts)
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
elif "swinir" in upscale.upscale_model:
swinir_params = StageParams(
tile_size=stage.tile_size,
outscale=upscale.outscale,
)
upscale_stage = (upscale_swinir, swinir_params, upscale_opts)
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
@ -98,9 +98,9 @@ def stage_upscale_correction(
tile_size=stage.tile_size, outscale=upscale.face_outscale
)
if "codeformer" in upscale.correction_model:
correct_stage = (correct_codeformer, face_params, upscale_opts)
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
elif "gfpgan" in upscale.correction_model:
correct_stage = (correct_gfpgan, face_params, upscale_opts)
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
else:
logger.warn("unknown correction model: %s", upscale.correction_model)

View File

@ -6,7 +6,7 @@ import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from ..server import ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
@ -14,105 +14,121 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def load_bsrgan(
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("bsrgan", cache_key)
class UpscaleBSRGANStage:
max_tile = 64
if cache_pipe is not None:
logger.debug("reusing existing BSRGAN pipeline")
return cache_pipe
def load(
self,
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("bsrgan", cache_key)
logger.info("loading BSRGAN model from %s", model_path)
if cache_pipe is not None:
logger.debug("reusing existing BSRGAN pipeline")
return cache_pipe
pipe = OnnxModel(
server,
model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
logger.info("loading BSRGAN model from %s", model_path)
server.cache.set("bsrgan", cache_key, pipe)
run_gc([device])
pipe = OnnxModel(
server,
model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
return pipe
server.cache.set("bsrgan", cache_key, pipe)
run_gc([device])
return pipe
def upscale_bsrgan(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None:
logger.warn("no upscaling model given, skipping")
return source
if upscale.upscale_model is None:
logger.warn("no upscaling model given, skipping")
return source
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
bsrgan = load_bsrgan(server, stage, upscale, device)
logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
device = job.get_device()
bsrgan = self.load(server, stage, upscale, device)
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape)
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
)
logger.trace("BSRGAN output shape: %s", dest.shape)
for x in range(tile_x):
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.debug(
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
scale = upscale.outscale
dest = np.zeros(
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
)
)
logger.trace("BSRGAN output shape: %s", dest.shape)
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = bsrgan(image[:, :, ix1:ix2, iy1:iy2])
for x in range(tile_x):
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.debug(
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
)
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
return output
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = bsrgan(image[:, :, ix1:ix2, iy1:iy2])
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
return output
def steps(
self,
_params: ImageParams,
size: Size,
) -> int:
return size.width // self.max_tile * size.height // self.max_tile

View File

@ -3,70 +3,71 @@ from typing import Any, Optional
from PIL import Image
from ..chain.base import ChainPipeline
from ..chain.blend_img2img import blend_img2img
from ..diffusers.upscale import stage_upscale_correction
from ..chain import BlendImg2ImgStage, ChainPipeline
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
from ..worker import WorkerContext
from ..worker.context import ProgressCallback
from .upscale import stage_upscale_correction
logger = getLogger(__name__)
def upscale_highres(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
pipeline: Optional[Any] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
class UpscaleHighresStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
pipeline: Optional[Any] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
if highres.scale <= 1:
return source
if highres.scale <= 1:
return source
chain = ChainPipeline()
scaled_size = (source.width * highres.scale, source.height * highres.scale)
chain = ChainPipeline()
scaled_size = (source.width * highres.scale, source.height * highres.scale)
# TODO: upscaling within the same stage prevents tiling from happening and causes OOM
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
# TODO: upscaling within the same stage prevents tiling from happening and causes OOM
if highres.method == "bilinear":
logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
elif highres.method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
else:
logger.debug("using upscaling pipeline for highres")
stage_upscale_correction(
StageParams(),
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
)
chain.stage(
BlendImg2ImgStage(),
StageParams(),
params,
upscale=upscale.with_args(
faces=False,
scale=highres.scale,
outscale=highres.scale,
),
chain=chain,
overlap=params.overlap,
strength=highres.strength,
)
chain.stage(
blend_img2img,
StageParams(),
overlap=params.overlap,
strength=highres.strength,
)
return chain(
job,
server,
params,
source,
callback=callback,
)
return chain(
job,
server,
params,
source,
callback=callback,
)

View File

@ -18,138 +18,140 @@ from .utils import complete_tile, process_tile_grid, process_tile_order
logger = getLogger(__name__)
def upscale_outpaint(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
border: Border,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
)
class UpscaleOutpaintStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
border: Border,
stage_source: Optional[Image.Image] = None,
stage_mask: Optional[Image.Image] = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
)
margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source.width, margin_y / source.height)
margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom))
overlap = min(margin_x / source.width, margin_y / source.height)
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
if stage_mask is None:
# if no mask was provided, keep the full source image
stage_mask = Image.new("RGB", source.size, "black")
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
# masks start as 512x512, resize to cover the source, then trim the extra
mask_max = max(source.width, source.height)
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
source, stage_mask, noise, full_size = expand_image(
source,
stage_mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
source, stage_mask, noise, full_size = expand_image(
source,
stage_mask,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter,
)
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
draw_mask = ImageDraw.Draw(stage_mask)
if is_debug():
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
tile_mask = complete_tile(tile_mask, tile)
draw_mask = ImageDraw.Draw(stage_mask)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
save_image(server, "last-source.png", source)
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-noise.png", noise)
latents = get_tile_latents(full_latents, dims, size)
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: load LoRAs and TIs
)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
tile_source,
tile_mask,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*tile_source.size)
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
tile_mask = complete_tile(tile_mask, tile)
if is_debug():
save_image(server, "tile-source.png", tile_source)
save_image(server, "tile-mask.png", tile_mask)
latents = get_tile_latents(full_latents, dims, size)
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
pipe = load_pipeline(
server,
params,
pipe_type,
job.get_device(),
# TODO: load LoRAs and TIs
)
if params.lpw():
logger.debug("using LPW pipeline for inpaint")
rng = torch.manual_seed(params.seed)
result = pipe.inpaint(
tile_source,
tile_mask,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
latents=latents,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
callback=callback,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
tile_source,
tile_mask,
height=size.height,
width=size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
generator=rng,
latents=latents,
callback=callback,
)
# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
tile_source,
tile_mask,
height=size.height,
width=size.width,
num_inference_steps=params.steps,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
generator=rng,
latents=latents,
callback=callback,
)
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0]
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
elif border.left == border.right and border.top == border.bottom:
logger.debug(
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
output = process_tile_order(
stage.tile_order,
source,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
logger.info("final output image size: %sx%s", output.width, output.height)
return output
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -16,79 +16,80 @@ logger = getLogger(__name__)
TAG_X4_V3 = "real-esrgan-x4-v3"
def load_resrgan(
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
# must be within load function for patches to take effect
# TODO: rewrite and remove
from realesrgan import RealESRGANer
class UpscaleRealESRGANStage:
def load(
self, server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
):
# must be within load function for patches to take effect
# TODO: rewrite and remove
from realesrgan import RealESRGANer
model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(server.model_path, model_file)
model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(server.model_path, model_file)
cache_key = (model_path, params.format)
cache_pipe = server.cache.get("resrgan", cache_key)
if cache_pipe is not None:
logger.info("reusing existing Real ESRGAN pipeline")
return cache_pipe
cache_key = (model_path, params.format)
cache_pipe = server.cache.get("resrgan", cache_key)
if cache_pipe is not None:
logger.info("reusing existing Real ESRGAN pipeline")
return cache_pipe
if not path.isfile(model_path):
raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)
if not path.isfile(model_path):
raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)
# TODO: swap for regular RRDBNet after rewriting wrapper
model = OnnxRRDBNet(
server,
model_file,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
# TODO: swap for regular RRDBNet after rewriting wrapper
model = OnnxRRDBNet(
server,
model_file,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
dni_weight = None
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise]
dni_weight = None
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise]
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
# TODO: shouldn't need the PTH file
model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model))
upsampler = RealESRGANer(
scale=params.scale,
model_path=model_path_pth,
dni_weight=dni_weight,
model=model,
tile=tile,
tile_pad=params.tile_pad,
pre_pad=params.pre_pad,
half=False, # TODO: use server optimizations
)
# TODO: shouldn't need the PTH file
model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model))
upsampler = RealESRGANer(
scale=params.scale,
model_path=model_path_pth,
dni_weight=dni_weight,
model=model,
tile=tile,
tile_pad=params.tile_pad,
pre_pad=params.pre_pad,
half=False, # TODO: use server optimizations
)
server.cache.set("resrgan", cache_key, upsampler)
run_gc([device])
server.cache.set("resrgan", cache_key, upsampler)
run_gc([device])
return upsampler
return upsampler
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
def upscale_resrgan(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
source = stage_source or source
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
output = np.array(source)
upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size)
output = np.array(source)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
return output
output = Image.fromarray(output, "RGB")
logger.info("final output image size: %sx%s", output.width, output.height)
return output

View File

@ -14,52 +14,54 @@ from ..worker import ProgressCallback, WorkerContext
logger = getLogger(__name__)
def upscale_stable_diffusion(
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs)
source = stage_source or source
logger.info(
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
)
class UpscaleStableDiffusionStage:
def run(
self,
job: WorkerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,
) -> Image.Image:
params = params.with_args(**kwargs)
upscale = upscale.with_args(**kwargs)
source = stage_source or source
logger.info(
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
)
prompt_pairs, _loras, _inversions = parse_prompt(params)
prompt_pairs, _loras, _inversions = parse_prompt(params)
pipeline = load_pipeline(
server,
params,
"upscale",
job.get_device(),
model=path.join(server.model_path, upscale.upscale_model),
)
generator = torch.manual_seed(params.seed)
pipeline = load_pipeline(
server,
params,
"upscale",
job.get_device(),
model=path.join(server.model_path, upscale.upscale_model),
)
generator = torch.manual_seed(params.seed)
prompt_embeds = encode_prompt(
pipeline,
prompt_pairs,
num_images_per_prompt=params.batch,
do_classifier_free_guidance=params.do_cfg(),
)
pipeline.unet.set_prompts(prompt_embeds)
prompt_embeds = encode_prompt(
pipeline,
prompt_pairs,
num_images_per_prompt=params.batch,
do_classifier_free_guidance=params.do_cfg(),
)
pipeline.unet.set_prompts(prompt_embeds)
return pipeline(
params.prompt,
source,
generator=generator,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
eta=params.eta,
noise_level=upscale.denoise,
callback=callback,
).images[0]
return pipeline(
params.prompt,
source,
generator=generator,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
eta=params.eta,
noise_level=upscale.denoise,
callback=callback,
).images[0]

View File

@ -14,107 +14,116 @@ from ..worker import WorkerContext
logger = getLogger(__name__)
def load_swinir(
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("swinir", cache_key)
class UpscaleSwinIRStage:
max_tile = 64
if cache_pipe is not None:
logger.info("reusing existing SwinIR pipeline")
return cache_pipe
def load(
self,
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
cache_key = (model_path,)
cache_pipe = server.cache.get("swinir", cache_key)
logger.debug("loading SwinIR model from %s", model_path)
if cache_pipe is not None:
logger.info("reusing existing SwinIR pipeline")
return cache_pipe
pipe = OnnxModel(
server,
model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
logger.debug("loading SwinIR model from %s", model_path)
server.cache.set("swinir", cache_key, pipe)
run_gc([device])
pipe = OnnxModel(
server,
model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
return pipe
server.cache.set("swinir", cache_key, pipe)
run_gc([device])
return pipe
def upscale_swinir(
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
def run(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> Image.Image:
upscale = upscale.with_args(**kwargs)
source = stage_source or source
if upscale.upscale_model is None:
logger.warn("no correction model given, skipping")
return source
if upscale.upscale_model is None:
logger.warn("no correction model given, skipping")
return source
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
device = job.get_device()
swinir = load_swinir(server, stage, upscale, device)
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
device = job.get_device()
swinir = self.load(server, stage, upscale, device)
# TODO: add support for other sizes
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
# TODO: add support for other sizes
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
# TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("SwinIR input shape: %s", image.shape)
# TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0)
logger.info("SwinIR input shape: %s", image.shape)
scale = upscale.outscale
dest = np.zeros(
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
)
logger.info("SwinIR output shape: %s", dest.shape)
for x in range(tile_x):
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.info(
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
scale = upscale.outscale
dest = np.zeros(
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
)
)
logger.info("SwinIR output shape: %s", dest.shape)
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
for x in range(tile_x):
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.info(
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
)
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
return output
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.info("output image size: %s x %s", output.width, output.height)
return output

View File

@ -4,13 +4,14 @@ from typing import Any, List, Optional
from PIL import Image
from ..chain import (
blend_img2img,
blend_mask,
source_txt2img,
upscale_highres,
upscale_outpaint,
BlendImg2ImgStage,
BlendMaskStage,
ChainPipeline,
SourceTxt2ImgStage,
UpscaleHighresStage,
UpscaleOutpaintStage,
)
from ..chain.base import ChainPipeline
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..output import save_image
from ..params import (
Border,
@ -24,7 +25,6 @@ from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc, show_system_toast
from ..worker import WorkerContext
from .upscale import split_upscale, stage_upscale_correction
from .utils import parse_prompt
logger = getLogger(__name__)
@ -43,7 +43,7 @@ def run_txt2img_pipeline(
chain = ChainPipeline()
stage = StageParams()
chain.stage(
source_txt2img,
SourceTxt2ImgStage(),
stage,
size=size,
)
@ -61,7 +61,7 @@ def run_txt2img_pipeline(
# apply highres
for _i in range(highres.iterations):
chain.stage(
upscale_highres,
UpscaleHighresStage(),
StageParams(
outscale=highres.scale,
),
@ -125,7 +125,7 @@ def run_img2img_pipeline(
chain = ChainPipeline()
stage = StageParams()
chain.stage(
blend_img2img,
BlendImg2ImgStage(),
stage,
strength=strength,
)
@ -144,7 +144,7 @@ def run_img2img_pipeline(
if params.loopback > 0:
for _i in range(params.loopback):
chain.stage(
blend_img2img,
BlendImg2ImgStage(),
stage,
strength=strength,
)
@ -153,7 +153,7 @@ def run_img2img_pipeline(
if highres.iterations > 0:
for _i in range(highres.iterations):
chain.stage(
upscale_highres,
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
@ -223,7 +223,7 @@ def run_inpaint_pipeline(
chain = ChainPipeline()
stage = StageParams(tile_order=tile_order)
chain.stage(
upscale_outpaint,
UpscaleOutpaintStage(),
stage,
border=border,
stage_mask=mask,
@ -234,7 +234,7 @@ def run_inpaint_pipeline(
# apply highres
chain.stage(
upscale_highres,
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
@ -300,7 +300,7 @@ def run_upscale_pipeline(
# apply highres
chain.stage(
upscale_highres,
UpscaleHighresStage(),
stage,
highres=highres,
upscale=upscale,
@ -353,7 +353,7 @@ def run_blend_pipeline(
# set up the chain pipeline and base stage
chain = ChainPipeline()
stage = StageParams()
chain.stage(blend_mask, stage, stage_source=sources[1], stage_mask=mask)
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
# apply upscaling and correction
stage_upscale_correction(

View File

@ -1,6 +1,4 @@
from typing import Tuple, Union
from PIL import Image, ImageChops, ImageOps
from PIL import Image, ImageChops
from ..params import Border, Size
from .mask_filter import mask_filter_none

View File

@ -356,12 +356,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
callback = CHAIN_STAGES[stage_data.get("type")]
stage_class = CHAIN_STAGES[stage_data.get("type")]
kwargs = stage_data.get("params", {})
logger.info("request stage: %s, %s", callback.__name__, kwargs)
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
stage = StageParams(
stage_data.get("name", callback.__name__),
stage_data.get("name", stage_class.__name__),
tile_size=get_size(kwargs.get("tile_size")),
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
)
@ -399,7 +399,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
kwargs["stage_mask"] = mask
pipeline.append((callback, stage, kwargs))
pipeline.append((stage_class(), stage, kwargs))
logger.info("running chain pipeline with %s stages", len(pipeline.stages))