feat(api): make chain stages into classes with max tile size and step count estimate
This commit is contained in:
parent
5e1b70091c
commit
2913cd0382
|
@ -17,7 +17,7 @@ from .diffusers.run import (
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from .diffusers.stub_scheduler import StubScheduler
|
from .diffusers.stub_scheduler import StubScheduler
|
||||||
from .diffusers.upscale import stage_upscale_correction
|
from .chain.upscale import stage_upscale_correction
|
||||||
from .image.utils import (
|
from .image.utils import (
|
||||||
expand_image,
|
expand_image,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,44 +1,44 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||||
from .blend_img2img import blend_img2img
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
from .blend_inpaint import blend_inpaint
|
from .blend_inpaint import BlendInpaintStage
|
||||||
from .blend_linear import blend_linear
|
from .blend_linear import BlendLinearStage
|
||||||
from .blend_mask import blend_mask
|
from .blend_mask import BlendMaskStage
|
||||||
from .correct_codeformer import correct_codeformer
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
from .correct_gfpgan import correct_gfpgan
|
from .correct_gfpgan import CorrectGFPGANStage
|
||||||
from .persist_disk import persist_disk
|
from .persist_disk import PersistDiskStage
|
||||||
from .persist_s3 import persist_s3
|
from .persist_s3 import PersistS3Stage
|
||||||
from .reduce_crop import reduce_crop
|
from .reduce_crop import ReduceCropStage
|
||||||
from .reduce_thumbnail import reduce_thumbnail
|
from .reduce_thumbnail import ReduceThumbnailStage
|
||||||
from .source_noise import source_noise
|
from .source_noise import SourceNoiseStage
|
||||||
from .source_s3 import source_s3
|
from .source_s3 import SourceS3Stage
|
||||||
from .source_txt2img import source_txt2img
|
from .source_txt2img import SourceTxt2ImgStage
|
||||||
from .source_url import source_url
|
from .source_url import SourceURLStage
|
||||||
from .upscale_bsrgan import upscale_bsrgan
|
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||||
from .upscale_highres import upscale_highres
|
from .upscale_highres import UpscaleHighresStage
|
||||||
from .upscale_outpaint import upscale_outpaint
|
from .upscale_outpaint import UpscaleOutpaintStage
|
||||||
from .upscale_resrgan import upscale_resrgan
|
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||||
from .upscale_stable_diffusion import upscale_stable_diffusion
|
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||||
from .upscale_swinir import upscale_swinir
|
from .upscale_swinir import UpscaleSwinIRStage
|
||||||
|
|
||||||
CHAIN_STAGES = {
|
CHAIN_STAGES = {
|
||||||
"blend-img2img": blend_img2img,
|
"blend-img2img": BlendImg2ImgStage,
|
||||||
"blend-inpaint": blend_inpaint,
|
"blend-inpaint": BlendInpaintStage,
|
||||||
"blend-linear": blend_linear,
|
"blend-linear": BlendLinearStage,
|
||||||
"blend-mask": blend_mask,
|
"blend-mask": BlendMaskStage,
|
||||||
"correct-codeformer": correct_codeformer,
|
"correct-codeformer": CorrectCodeformerStage,
|
||||||
"correct-gfpgan": correct_gfpgan,
|
"correct-gfpgan": CorrectGFPGANStage,
|
||||||
"persist-disk": persist_disk,
|
"persist-disk": PersistDiskStage,
|
||||||
"persist-s3": persist_s3,
|
"persist-s3": PersistS3Stage,
|
||||||
"reduce-crop": reduce_crop,
|
"reduce-crop": ReduceCropStage,
|
||||||
"reduce-thumbnail": reduce_thumbnail,
|
"reduce-thumbnail": ReduceThumbnailStage,
|
||||||
"source-noise": source_noise,
|
"source-noise": SourceNoiseStage,
|
||||||
"source-s3": source_s3,
|
"source-s3": SourceS3Stage,
|
||||||
"source-txt2img": source_txt2img,
|
"source-txt2img": SourceTxt2ImgStage,
|
||||||
"source-url": source_url,
|
"source-url": SourceURLStage,
|
||||||
"upscale-bsrgan": upscale_bsrgan,
|
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||||
"upscale-highres": upscale_highres,
|
"upscale-highres": UpscaleHighresStage,
|
||||||
"upscale-outpaint": upscale_outpaint,
|
"upscale-outpaint": UpscaleOutpaintStage,
|
||||||
"upscale-resrgan": upscale_resrgan,
|
"upscale-resrgan": UpscaleRealESRGANStage,
|
||||||
"upscale-stable-diffusion": upscale_stable_diffusion,
|
"upscale-stable-diffusion": UpscaleStableDiffusionStage,
|
||||||
"upscale-swinir": upscale_swinir,
|
"upscale-swinir": UpscaleSwinIRStage,
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ from ..params import ImageParams, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import is_debug
|
from ..utils import is_debug
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .stage import BaseStage
|
||||||
from .utils import process_tile_order
|
from .utils import process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -35,7 +36,7 @@ class StageCallback(Protocol):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||||
|
|
||||||
|
|
||||||
class ChainProgress:
|
class ChainProgress:
|
||||||
|
@ -131,7 +132,7 @@ class ChainPipeline:
|
||||||
logger.info("running pipeline without source image")
|
logger.info("running pipeline without source image")
|
||||||
|
|
||||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||||
name = stage_params.name or stage_pipe.__name__
|
name = stage_params.name or stage_pipe.__class__.__name__
|
||||||
kwargs = stage_kwargs or {}
|
kwargs = stage_kwargs or {}
|
||||||
kwargs = {**pipeline_kwargs, **kwargs}
|
kwargs = {**pipeline_kwargs, **kwargs}
|
||||||
|
|
||||||
|
@ -158,7 +159,7 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||||
tile = stage_pipe(
|
tile = stage_pipe.run(
|
||||||
job,
|
job,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
|
@ -182,7 +183,7 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("image within tile size, running stage")
|
logger.debug("image within tile size, running stage")
|
||||||
image = stage_pipe(
|
image = stage_pipe.run(
|
||||||
job,
|
job,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
|
|
|
@ -14,77 +14,81 @@ from ..worker import ProgressCallback, WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def blend_img2img(
|
class BlendImg2ImgStage:
|
||||||
job: WorkerContext,
|
def run(
|
||||||
server: ServerContext,
|
self,
|
||||||
_stage: StageParams,
|
job: WorkerContext,
|
||||||
params: ImageParams,
|
server: ServerContext,
|
||||||
source: Image.Image,
|
_stage: StageParams,
|
||||||
*,
|
params: ImageParams,
|
||||||
strength: float,
|
source: Image.Image,
|
||||||
callback: Optional[ProgressCallback] = None,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
strength: float,
|
||||||
**kwargs,
|
callback: Optional[ProgressCallback] = None,
|
||||||
) -> Image.Image:
|
stage_source: Optional[Image.Image] = None,
|
||||||
params = params.with_args(**kwargs)
|
**kwargs,
|
||||||
source = stage_source or source
|
) -> Image.Image:
|
||||||
logger.info(
|
params = params.with_args(**kwargs)
|
||||||
"blending image using img2img, %s steps: %s", params.steps, params.prompt
|
source = stage_source or source
|
||||||
)
|
logger.info(
|
||||||
|
"blending image using img2img, %s steps: %s", params.steps, params.prompt
|
||||||
prompt_pairs, loras, inversions = parse_prompt(params)
|
|
||||||
|
|
||||||
pipe_type = params.get_valid_pipeline("img2img")
|
|
||||||
pipe = load_pipeline(
|
|
||||||
server,
|
|
||||||
params,
|
|
||||||
pipe_type,
|
|
||||||
job.get_device(),
|
|
||||||
inversions=inversions,
|
|
||||||
loras=loras,
|
|
||||||
)
|
|
||||||
|
|
||||||
pipe_params = {}
|
|
||||||
if pipe_type == "controlnet":
|
|
||||||
pipe_params["controlnet_conditioning_scale"] = strength
|
|
||||||
elif pipe_type == "img2img":
|
|
||||||
pipe_params["strength"] = strength
|
|
||||||
elif pipe_type == "panorama":
|
|
||||||
pipe_params["strength"] = strength
|
|
||||||
elif pipe_type == "pix2pix":
|
|
||||||
pipe_params["image_guidance_scale"] = strength
|
|
||||||
|
|
||||||
if params.lpw():
|
|
||||||
logger.debug("using LPW pipeline for img2img")
|
|
||||||
rng = torch.manual_seed(params.seed)
|
|
||||||
result = pipe.img2img(
|
|
||||||
params.prompt,
|
|
||||||
generator=rng,
|
|
||||||
guidance_scale=params.cfg,
|
|
||||||
image=source,
|
|
||||||
negative_prompt=params.negative_prompt,
|
|
||||||
num_inference_steps=params.steps,
|
|
||||||
callback=callback,
|
|
||||||
**pipe_params,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# encode and record alternative prompts outside of LPW
|
|
||||||
prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
|
|
||||||
pipe.unet.set_prompts(prompt_embeds)
|
|
||||||
|
|
||||||
rng = np.random.RandomState(params.seed)
|
|
||||||
result = pipe(
|
|
||||||
params.prompt,
|
|
||||||
generator=rng,
|
|
||||||
guidance_scale=params.cfg,
|
|
||||||
image=source,
|
|
||||||
negative_prompt=params.negative_prompt,
|
|
||||||
num_inference_steps=params.steps,
|
|
||||||
callback=callback,
|
|
||||||
**pipe_params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
output = result.images[0]
|
prompt_pairs, loras, inversions = parse_prompt(params)
|
||||||
|
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
pipe_type = params.get_valid_pipeline("img2img")
|
||||||
return output
|
pipe = load_pipeline(
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
pipe_type,
|
||||||
|
job.get_device(),
|
||||||
|
inversions=inversions,
|
||||||
|
loras=loras,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipe_params = {}
|
||||||
|
if pipe_type == "controlnet":
|
||||||
|
pipe_params["controlnet_conditioning_scale"] = strength
|
||||||
|
elif pipe_type == "img2img":
|
||||||
|
pipe_params["strength"] = strength
|
||||||
|
elif pipe_type == "panorama":
|
||||||
|
pipe_params["strength"] = strength
|
||||||
|
elif pipe_type == "pix2pix":
|
||||||
|
pipe_params["image_guidance_scale"] = strength
|
||||||
|
|
||||||
|
if params.lpw():
|
||||||
|
logger.debug("using LPW pipeline for img2img")
|
||||||
|
rng = torch.manual_seed(params.seed)
|
||||||
|
result = pipe.img2img(
|
||||||
|
params.prompt,
|
||||||
|
generator=rng,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
image=source,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
callback=callback,
|
||||||
|
**pipe_params,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# encode and record alternative prompts outside of LPW
|
||||||
|
prompt_embeds = encode_prompt(
|
||||||
|
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||||
|
)
|
||||||
|
pipe.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
|
rng = np.random.RandomState(params.seed)
|
||||||
|
result = pipe(
|
||||||
|
params.prompt,
|
||||||
|
generator=rng,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
image=source,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
callback=callback,
|
||||||
|
**pipe_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = result.images[0]
|
||||||
|
|
||||||
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
|
return output
|
||||||
|
|
|
@ -18,105 +18,112 @@ from .utils import process_tile_order
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def blend_inpaint(
|
class BlendInpaintStage:
|
||||||
job: WorkerContext,
|
def run(
|
||||||
server: ServerContext,
|
self,
|
||||||
stage: StageParams,
|
job: WorkerContext,
|
||||||
params: ImageParams,
|
server: ServerContext,
|
||||||
source: Image.Image,
|
stage: StageParams,
|
||||||
*,
|
params: ImageParams,
|
||||||
expand: Border,
|
source: Image.Image,
|
||||||
stage_source: Optional[Image.Image] = None,
|
*,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
expand: Border,
|
||||||
fill_color: str = "white",
|
stage_source: Optional[Image.Image] = None,
|
||||||
mask_filter: Callable = mask_filter_none,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
noise_source: Callable = noise_source_histogram,
|
fill_color: str = "white",
|
||||||
callback: Optional[ProgressCallback] = None,
|
mask_filter: Callable = mask_filter_none,
|
||||||
**kwargs,
|
noise_source: Callable = noise_source_histogram,
|
||||||
) -> Image.Image:
|
callback: Optional[ProgressCallback] = None,
|
||||||
params = params.with_args(**kwargs)
|
**kwargs,
|
||||||
expand = expand.with_args(**kwargs)
|
) -> Image.Image:
|
||||||
source = source or stage_source
|
params = params.with_args(**kwargs)
|
||||||
logger.info(
|
expand = expand.with_args(**kwargs)
|
||||||
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
|
source = source or stage_source
|
||||||
)
|
logger.info(
|
||||||
|
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
|
||||||
|
)
|
||||||
|
|
||||||
if stage_mask is None:
|
if stage_mask is None:
|
||||||
# if no mask was provided, keep the full source image
|
# if no mask was provided, keep the full source image
|
||||||
stage_mask = Image.new("RGB", source.size, "black")
|
stage_mask = Image.new("RGB", source.size, "black")
|
||||||
|
|
||||||
source, stage_mask, noise, _full_dims = expand_image(
|
source, stage_mask, noise, _full_dims = expand_image(
|
||||||
source,
|
source,
|
||||||
stage_mask,
|
stage_mask,
|
||||||
expand,
|
expand,
|
||||||
fill=fill_color,
|
fill=fill_color,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
mask_filter=mask_filter,
|
mask_filter=mask_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-source.png", source)
|
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
|
||||||
save_image(server, "last-noise.png", noise)
|
|
||||||
|
|
||||||
pipe_type = "lpw" if params.lpw() else "inpaint"
|
|
||||||
pipe = load_pipeline(
|
|
||||||
server,
|
|
||||||
params,
|
|
||||||
pipe_type,
|
|
||||||
job.get_device(),
|
|
||||||
# TODO: add LoRAs and TIs
|
|
||||||
)
|
|
||||||
|
|
||||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
|
||||||
left, top, tile = dims
|
|
||||||
size = Size(*tile_source.size)
|
|
||||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "tile-source.png", tile_source)
|
save_image(server, "last-source.png", source)
|
||||||
save_image(server, "tile-mask.png", tile_mask)
|
save_image(server, "last-mask.png", stage_mask)
|
||||||
|
save_image(server, "last-noise.png", noise)
|
||||||
|
|
||||||
latents = get_latents_from_seed(params.seed, size)
|
pipe_type = "lpw" if params.lpw() else "inpaint"
|
||||||
if params.lpw():
|
pipe = load_pipeline(
|
||||||
logger.debug("using LPW pipeline for inpaint")
|
server,
|
||||||
rng = torch.manual_seed(params.seed)
|
params,
|
||||||
result = pipe.inpaint(
|
pipe_type,
|
||||||
params.prompt,
|
job.get_device(),
|
||||||
generator=rng,
|
# TODO: add LoRAs and TIs
|
||||||
guidance_scale=params.cfg,
|
)
|
||||||
height=size.height,
|
|
||||||
image=tile_source,
|
|
||||||
latents=latents,
|
|
||||||
mask_image=tile_mask,
|
|
||||||
negative_prompt=params.negative_prompt,
|
|
||||||
num_inference_steps=params.steps,
|
|
||||||
width=size.width,
|
|
||||||
eta=params.eta,
|
|
||||||
callback=callback,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rng = np.random.RandomState(params.seed)
|
|
||||||
result = pipe(
|
|
||||||
params.prompt,
|
|
||||||
generator=rng,
|
|
||||||
guidance_scale=params.cfg,
|
|
||||||
height=size.height,
|
|
||||||
image=tile_source,
|
|
||||||
latents=latents,
|
|
||||||
mask_image=stage_mask,
|
|
||||||
negative_prompt=params.negative_prompt,
|
|
||||||
num_inference_steps=params.steps,
|
|
||||||
width=size.width,
|
|
||||||
eta=params.eta,
|
|
||||||
callback=callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result.images[0]
|
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||||
|
left, top, tile = dims
|
||||||
|
size = Size(*tile_source.size)
|
||||||
|
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
||||||
|
|
||||||
output = process_tile_order(
|
if is_debug():
|
||||||
stage.tile_order, source, SizeChart.auto, 1, [outpaint], overlap=params.overlap
|
save_image(server, "tile-source.png", tile_source)
|
||||||
)
|
save_image(server, "tile-mask.png", tile_mask)
|
||||||
|
|
||||||
logger.info("final output image size: %s", output.size)
|
latents = get_latents_from_seed(params.seed, size)
|
||||||
return output
|
if params.lpw():
|
||||||
|
logger.debug("using LPW pipeline for inpaint")
|
||||||
|
rng = torch.manual_seed(params.seed)
|
||||||
|
result = pipe.inpaint(
|
||||||
|
params.prompt,
|
||||||
|
generator=rng,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
height=size.height,
|
||||||
|
image=tile_source,
|
||||||
|
latents=latents,
|
||||||
|
mask_image=tile_mask,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
width=size.width,
|
||||||
|
eta=params.eta,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rng = np.random.RandomState(params.seed)
|
||||||
|
result = pipe(
|
||||||
|
params.prompt,
|
||||||
|
generator=rng,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
height=size.height,
|
||||||
|
image=tile_source,
|
||||||
|
latents=latents,
|
||||||
|
mask_image=stage_mask,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
width=size.width,
|
||||||
|
eta=params.eta,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.images[0]
|
||||||
|
|
||||||
|
output = process_tile_order(
|
||||||
|
stage.tile_order,
|
||||||
|
source,
|
||||||
|
SizeChart.auto,
|
||||||
|
1,
|
||||||
|
[outpaint],
|
||||||
|
overlap=params.overlap,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("final output image size: %s", output.size)
|
||||||
|
return output
|
||||||
|
|
|
@ -10,17 +10,19 @@ from ..worker import ProgressCallback, WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def blend_linear(
|
class BlendLinearStage:
|
||||||
_job: WorkerContext,
|
def run(
|
||||||
_server: ServerContext,
|
self,
|
||||||
_stage: StageParams,
|
_job: WorkerContext,
|
||||||
_params: ImageParams,
|
_server: ServerContext,
|
||||||
*,
|
_stage: StageParams,
|
||||||
alpha: float,
|
_params: ImageParams,
|
||||||
sources: Optional[List[Image.Image]] = None,
|
*,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
alpha: float,
|
||||||
**kwargs,
|
sources: Optional[List[Image.Image]] = None,
|
||||||
) -> Image.Image:
|
_callback: Optional[ProgressCallback] = None,
|
||||||
logger.info("blending image using linear interpolation")
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
logger.info("blending image using linear interpolation")
|
||||||
|
|
||||||
return Image.blend(sources[1], sources[0], alpha)
|
return Image.blend(sources[1], sources[0], alpha)
|
||||||
|
|
|
@ -12,26 +12,28 @@ from ..worker import ProgressCallback, WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def blend_mask(
|
class BlendMaskStage:
|
||||||
_job: WorkerContext,
|
def run(
|
||||||
server: ServerContext,
|
self,
|
||||||
_stage: StageParams,
|
_job: WorkerContext,
|
||||||
_params: ImageParams,
|
server: ServerContext,
|
||||||
source: Image.Image,
|
_stage: StageParams,
|
||||||
*,
|
_params: ImageParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
source: Image.Image,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
*,
|
||||||
_callback: Optional[ProgressCallback] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
) -> Image.Image:
|
_callback: Optional[ProgressCallback] = None,
|
||||||
logger.info("blending image using mask")
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
logger.info("blending image using mask")
|
||||||
|
|
||||||
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
|
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
|
||||||
mult_mask.alpha_composite(stage_mask)
|
mult_mask.alpha_composite(stage_mask)
|
||||||
mult_mask = mult_mask.convert("L")
|
mult_mask = mult_mask.convert("L")
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
save_image(server, "last-mask.png", stage_mask)
|
||||||
save_image(server, "last-mult-mask.png", mult_mask)
|
save_image(server, "last-mult-mask.png", mult_mask)
|
||||||
|
|
||||||
return Image.composite(stage_source, source, mult_mask)
|
return Image.composite(stage_source, source, mult_mask)
|
||||||
|
|
|
@ -9,28 +9,28 @@ from ..worker import WorkerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
device = "cpu"
|
|
||||||
|
|
||||||
|
class CorrectCodeformerStage:
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
job: WorkerContext,
|
||||||
|
_server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
source: Image.Image,
|
||||||
|
*,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
# must be within the load function for patch to take effect
|
||||||
|
# TODO: rewrite and remove
|
||||||
|
from codeformer import CodeFormer
|
||||||
|
|
||||||
def correct_codeformer(
|
source = stage_source or source
|
||||||
job: WorkerContext,
|
|
||||||
_server: ServerContext,
|
|
||||||
_stage: StageParams,
|
|
||||||
_params: ImageParams,
|
|
||||||
source: Image.Image,
|
|
||||||
*,
|
|
||||||
stage_source: Optional[Image.Image] = None,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
**kwargs,
|
|
||||||
) -> Image.Image:
|
|
||||||
# must be within the load function for patch to take effect
|
|
||||||
# TODO: rewrite and remove
|
|
||||||
from codeformer import CodeFormer
|
|
||||||
|
|
||||||
source = stage_source or source
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
upscale = upscale.with_args(**kwargs)
|
device = job.get_device()
|
||||||
|
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||||
device = job.get_device()
|
return pipe(source)
|
||||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
|
||||||
return pipe(source)
|
|
||||||
|
|
|
@ -13,72 +13,74 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_gfpgan(
|
class CorrectGFPGANStage:
|
||||||
server: ServerContext,
|
def load(
|
||||||
_stage: StageParams,
|
self,
|
||||||
upscale: UpscaleParams,
|
server: ServerContext,
|
||||||
device: DeviceParams,
|
_stage: StageParams,
|
||||||
):
|
upscale: UpscaleParams,
|
||||||
# must be within the load function for patch to take effect
|
device: DeviceParams,
|
||||||
# TODO: rewrite and remove
|
):
|
||||||
from gfpgan import GFPGANer
|
# must be within the load function for patch to take effect
|
||||||
|
# TODO: rewrite and remove
|
||||||
|
from gfpgan import GFPGANer
|
||||||
|
|
||||||
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
|
face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model))
|
||||||
cache_key = (face_path,)
|
cache_key = (face_path,)
|
||||||
cache_pipe = server.cache.get("gfpgan", cache_key)
|
cache_pipe = server.cache.get("gfpgan", cache_key)
|
||||||
|
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.info("reusing existing GFPGAN pipeline")
|
logger.info("reusing existing GFPGAN pipeline")
|
||||||
return cache_pipe
|
return cache_pipe
|
||||||
|
|
||||||
logger.debug("loading GFPGAN model from %s", face_path)
|
logger.debug("loading GFPGAN model from %s", face_path)
|
||||||
|
|
||||||
# TODO: find a way to pass the ONNX model to underlying architectures
|
# TODO: find a way to pass the ONNX model to underlying architectures
|
||||||
gfpgan = GFPGANer(
|
gfpgan = GFPGANer(
|
||||||
arch="clean",
|
arch="clean",
|
||||||
bg_upsampler=None,
|
bg_upsampler=None,
|
||||||
channel_multiplier=2,
|
channel_multiplier=2,
|
||||||
device=device.torch_str(),
|
device=device.torch_str(),
|
||||||
model_path=face_path,
|
model_path=face_path,
|
||||||
upscale=upscale.face_outscale,
|
upscale=upscale.face_outscale,
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("gfpgan", cache_key, gfpgan)
|
server.cache.set("gfpgan", cache_key, gfpgan)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return gfpgan
|
return gfpgan
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
job: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
source: Image.Image,
|
||||||
|
*,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
source = stage_source or source
|
||||||
|
|
||||||
def correct_gfpgan(
|
if upscale.correction_model is None:
|
||||||
job: WorkerContext,
|
logger.warn("no face model given, skipping")
|
||||||
server: ServerContext,
|
return source
|
||||||
stage: StageParams,
|
|
||||||
_params: ImageParams,
|
|
||||||
source: Image.Image,
|
|
||||||
*,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
stage_source: Optional[Image.Image] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Image.Image:
|
|
||||||
upscale = upscale.with_args(**kwargs)
|
|
||||||
source = stage_source or source
|
|
||||||
|
|
||||||
if upscale.correction_model is None:
|
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||||
logger.warn("no face model given, skipping")
|
device = job.get_device()
|
||||||
return source
|
gfpgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
output = np.array(source)
|
||||||
device = job.get_device()
|
_, _, output = gfpgan.enhance(
|
||||||
gfpgan = load_gfpgan(server, stage, upscale, device)
|
output,
|
||||||
|
has_aligned=False,
|
||||||
|
only_center_face=False,
|
||||||
|
paste_back=True,
|
||||||
|
weight=upscale.face_strength,
|
||||||
|
)
|
||||||
|
output = Image.fromarray(output, "RGB")
|
||||||
|
|
||||||
output = np.array(source)
|
return output
|
||||||
_, _, output = gfpgan.enhance(
|
|
||||||
output,
|
|
||||||
has_aligned=False,
|
|
||||||
only_center_face=False,
|
|
||||||
paste_back=True,
|
|
||||||
weight=upscale.face_strength,
|
|
||||||
)
|
|
||||||
output = Image.fromarray(output, "RGB")
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
|
@ -10,19 +10,21 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def persist_disk(
|
class PersistDiskStage:
|
||||||
_job: WorkerContext,
|
def run(
|
||||||
server: ServerContext,
|
self,
|
||||||
_stage: StageParams,
|
_job: WorkerContext,
|
||||||
params: ImageParams,
|
server: ServerContext,
|
||||||
source: Image.Image,
|
_stage: StageParams,
|
||||||
*,
|
params: ImageParams,
|
||||||
output: str,
|
source: Image.Image,
|
||||||
stage_source: Image.Image,
|
*,
|
||||||
**kwargs,
|
output: str,
|
||||||
) -> Image.Image:
|
stage_source: Image.Image,
|
||||||
source = stage_source or source
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
source = stage_source or source
|
||||||
|
|
||||||
dest = save_image(server, output, source, params=params)
|
dest = save_image(server, output, source, params=params)
|
||||||
logger.info("saved image to %s", dest)
|
logger.info("saved image to %s", dest)
|
||||||
return source
|
return source
|
||||||
|
|
|
@ -12,33 +12,35 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def persist_s3(
|
class PersistS3Stage:
|
||||||
_job: WorkerContext,
|
def run(
|
||||||
server: ServerContext,
|
self,
|
||||||
_stage: StageParams,
|
_job: WorkerContext,
|
||||||
_params: ImageParams,
|
server: ServerContext,
|
||||||
source: Image.Image,
|
_stage: StageParams,
|
||||||
*,
|
_params: ImageParams,
|
||||||
output: str,
|
source: Image.Image,
|
||||||
bucket: str,
|
*,
|
||||||
endpoint_url: Optional[str] = None,
|
output: str,
|
||||||
profile_name: Optional[str] = None,
|
bucket: str,
|
||||||
stage_source: Optional[Image.Image] = None,
|
endpoint_url: Optional[str] = None,
|
||||||
**kwargs,
|
profile_name: Optional[str] = None,
|
||||||
) -> Image.Image:
|
stage_source: Optional[Image.Image] = None,
|
||||||
source = stage_source or source
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
source = stage_source or source
|
||||||
|
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client("s3", endpoint_url=endpoint_url)
|
s3 = session.client("s3", endpoint_url=endpoint_url)
|
||||||
|
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
source.save(data, format=server.image_format)
|
source.save(data, format=server.image_format)
|
||||||
data.seek(0)
|
data.seek(0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
s3.upload_fileobj(data, bucket, output)
|
s3.upload_fileobj(data, bucket, output)
|
||||||
logger.info("saved image to s3://%s/%s", bucket, output)
|
logger.info("saved image to s3://%s/%s", bucket, output)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error saving image to S3")
|
logger.exception("error saving image to S3")
|
||||||
|
|
||||||
return source
|
return source
|
||||||
|
|
|
@ -10,20 +10,24 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def reduce_crop(
|
class ReduceCropStage:
|
||||||
_job: WorkerContext,
|
def run(
|
||||||
_server: ServerContext,
|
self,
|
||||||
_stage: StageParams,
|
_job: WorkerContext,
|
||||||
_params: ImageParams,
|
_server: ServerContext,
|
||||||
source: Image.Image,
|
_stage: StageParams,
|
||||||
*,
|
_params: ImageParams,
|
||||||
origin: Size,
|
source: Image.Image,
|
||||||
size: Size,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
origin: Size,
|
||||||
**kwargs,
|
size: Size,
|
||||||
) -> Image.Image:
|
stage_source: Optional[Image.Image] = None,
|
||||||
source = stage_source or source
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
source = stage_source or source
|
||||||
|
|
||||||
image = source.crop((origin.width, origin.height, size.width, size.height))
|
image = source.crop((origin.width, origin.height, size.width, size.height))
|
||||||
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
logger.info(
|
||||||
return image
|
"created thumbnail with dimensions: %sx%s", image.width, image.height
|
||||||
|
)
|
||||||
|
return image
|
||||||
|
|
|
@ -9,21 +9,25 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ReduceThumbnailStage:
    """Chain stage that shrinks a copy of the incoming image to a thumbnail."""

    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        size: Size,
        stage_source: Image.Image,
        **kwargs,
    ) -> Image.Image:
        """Return a thumbnail of the stage input bounded by ``size``.

        ``stage_source`` takes precedence over ``source`` when it is truthy.
        The input image itself is left untouched; the resize is applied to a
        copy.
        """
        source = stage_source or source
        image = source.copy()

        # Image.thumbnail() resizes in place and returns None, so its result
        # must not be assigned back to `image`
        image.thumbnail((size.width, size.height))

        logger.info(
            "created thumbnail with dimensions: %sx%s", image.width, image.height
        )
        return image
|
||||||
|
|
|
@ -10,25 +10,29 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SourceNoiseStage:
    """Chain stage that produces an image by calling a noise source."""

    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        size: Size,
        noise_source: Callable,
        stage_source: Image.Image,
        **kwargs,
    ) -> Image.Image:
        """Generate an image of ``size`` from ``noise_source``.

        ``noise_source`` is called with the current source image, a
        ``(width, height)`` tuple and the origin ``(0, 0)``, and its result is
        returned as the stage output.
        """
        source = stage_source or source
        logger.info("generating image from noise source")

        if source is not None:
            # Logger.warn is a deprecated alias; use Logger.warning
            logger.warning(
                "a source image was passed to a noise stage, but will be discarded"
            )

        output = noise_source(source, (size.width, size.height), (0, 0))

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output
|
||||||
|
|
|
@ -12,31 +12,33 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SourceS3Stage:
    """Chain stage that loads its image from an object in an S3 bucket."""

    def run(
        self,
        _job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        source_key: str,
        bucket: str,
        endpoint_url: Optional[str] = None,
        profile_name: Optional[str] = None,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        """Download ``source_key`` from ``bucket`` and open it as an image.

        Any exception raised while downloading or opening the object is
        logged and swallowed, in which case the stage returns ``None``.
        """
        source = stage_source or source

        client = Session(profile_name=profile_name).client(
            "s3", endpoint_url=endpoint_url
        )

        try:
            logger.info("loading image from s3://%s/%s", bucket, source_key)
            buffer = BytesIO()
            client.download_fileobj(bucket, source_key, buffer)

            # rewind so the image decoder reads from the start of the download
            buffer.seek(0)
            return Image.open(buffer)
        except Exception:
            logger.exception("error loading image from S3")

        # best-effort stage: a failed download yields no image
        return None
|
||||||
|
|
|
@ -14,74 +14,78 @@ from ..worker import ProgressCallback, WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SourceTxt2ImgStage:
    """Chain stage that generates an image from a text prompt."""

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        _source: Image.Image,
        *,
        size: Size,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        """Run the txt2img pipeline and return the first generated image.

        Extra keyword arguments are applied to both ``params`` and ``size``
        through their ``with_args`` methods before the pipeline is loaded.
        ``callback`` is forwarded to the pipeline call.
        """
        params = params.with_args(**kwargs)
        size = size.with_args(**kwargs)
        logger.info(
            "generating image using txt2img, %s steps: %s", params.steps, params.prompt
        )

        if "stage_source" in kwargs:
            # Logger.warn is a deprecated alias; use Logger.warning
            logger.warning(
                "a source image was passed to a txt2img stage, and will be discarded"
            )

        prompt_pairs, loras, inversions = parse_prompt(params)

        latents = get_latents_from_seed(params.seed, size)
        pipe_type = params.get_valid_pipeline("txt2img")
        pipe = load_pipeline(
            server,
            params,
            pipe_type,
            job.get_device(),
            inversions=inversions,
            loras=loras,
        )

        # keyword arguments shared by both pipeline flavours; only the random
        # generator and the entry point differ between the two branches
        pipe_args = dict(
            height=size.height,
            width=size.width,
            guidance_scale=params.cfg,
            latents=latents,
            negative_prompt=params.negative_prompt,
            num_inference_steps=params.steps,
            callback=callback,
        )

        if params.lpw():
            logger.debug("using LPW pipeline for txt2img")
            rng = torch.manual_seed(params.seed)
            result = pipe.text2img(params.prompt, generator=rng, **pipe_args)
        else:
            # encode and record alternative prompts outside of LPW
            prompt_embeds = encode_prompt(
                pipe, prompt_pairs, params.batch, params.do_cfg()
            )
            pipe.unet.set_prompts(prompt_embeds)

            rng = np.random.RandomState(params.seed)
            result = pipe(params.prompt, generator=rng, **pipe_args)

        output = result.images[0]

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output
|
||||||
|
|
|
@ -11,27 +11,29 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SourceURLStage:
    """Chain stage that downloads its image from a URL."""

    # seconds to wait for the remote server before giving up; without a
    # timeout a stalled server would block the worker indefinitely
    request_timeout = 30

    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        source_url: str,
        stage_source: Image.Image,
        **kwargs,
    ) -> Image.Image:
        """Fetch ``source_url`` and return the response body opened as an image.

        Any image passed in as ``stage_source`` or ``source`` is not used for
        the output; a warning is logged when one is present.
        """
        source = stage_source or source
        logger.info("loading image from URL source")

        if source is not None:
            # Logger.warn is a deprecated alias; use Logger.warning
            logger.warning(
                "a source image was passed to a source stage, and will be discarded"
            )

        response = requests.get(source_url, timeout=self.request_timeout)
        output = Image.open(BytesIO(response.content))

        logger.info("final output image size: %sx%s", output.width, output.height)
        return output
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.params import ImageParams, Size, SizeChart, StageParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStage:
    """Interface shared by chain pipeline stages.

    Subclasses are expected to override both :meth:`run` and :meth:`steps`;
    the versions defined here only raise ``NotImplementedError``.
    """

    # largest tile size the stage accepts; defaults to SizeChart.auto
    max_tile = SizeChart.auto

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *args,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        """Process ``source`` and return the resulting image."""
        raise NotImplementedError()

    def steps(
        self,
        _params: ImageParams,
        size: Size,
    ) -> int:
        """Return the estimated number of steps for an image of ``size``."""
        raise NotImplementedError()
|
|
@ -1,14 +1,14 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from ..chain import ChainPipeline, PipelineStage
|
|
||||||
from ..chain.correct_codeformer import correct_codeformer
|
|
||||||
from ..chain.correct_gfpgan import correct_gfpgan
|
|
||||||
from ..chain.upscale_bsrgan import upscale_bsrgan
|
|
||||||
from ..chain.upscale_resrgan import upscale_resrgan
|
|
||||||
from ..chain.upscale_stable_diffusion import upscale_stable_diffusion
|
|
||||||
from ..chain.upscale_swinir import upscale_swinir
|
|
||||||
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
|
from . import ChainPipeline, PipelineStage
|
||||||
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
|
from .correct_gfpgan import CorrectGFPGANStage
|
||||||
|
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||||
|
from .upscale_resrgan import UpscaleRealESRGANStage
|
||||||
|
from .upscale_stable_diffusion import UpscaleStableDiffusionStage
|
||||||
|
from .upscale_swinir import UpscaleSwinIRStage
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -72,23 +72,23 @@ def stage_upscale_correction(
|
||||||
tile_size=stage.tile_size,
|
tile_size=stage.tile_size,
|
||||||
outscale=upscale.outscale,
|
outscale=upscale.outscale,
|
||||||
)
|
)
|
||||||
upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts)
|
upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts)
|
||||||
elif "esrgan" in upscale.upscale_model:
|
elif "esrgan" in upscale.upscale_model:
|
||||||
esrgan_params = StageParams(
|
esrgan_params = StageParams(
|
||||||
tile_size=stage.tile_size,
|
tile_size=stage.tile_size,
|
||||||
outscale=upscale.outscale,
|
outscale=upscale.outscale,
|
||||||
)
|
)
|
||||||
upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts)
|
upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts)
|
||||||
elif "stable-diffusion" in upscale.upscale_model:
|
elif "stable-diffusion" in upscale.upscale_model:
|
||||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||||
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||||
upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts)
|
upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts)
|
||||||
elif "swinir" in upscale.upscale_model:
|
elif "swinir" in upscale.upscale_model:
|
||||||
swinir_params = StageParams(
|
swinir_params = StageParams(
|
||||||
tile_size=stage.tile_size,
|
tile_size=stage.tile_size,
|
||||||
outscale=upscale.outscale,
|
outscale=upscale.outscale,
|
||||||
)
|
)
|
||||||
upscale_stage = (upscale_swinir, swinir_params, upscale_opts)
|
upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts)
|
||||||
else:
|
else:
|
||||||
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||||
|
|
||||||
|
@ -98,9 +98,9 @@ def stage_upscale_correction(
|
||||||
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
||||||
)
|
)
|
||||||
if "codeformer" in upscale.correction_model:
|
if "codeformer" in upscale.correction_model:
|
||||||
correct_stage = (correct_codeformer, face_params, upscale_opts)
|
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
|
||||||
elif "gfpgan" in upscale.correction_model:
|
elif "gfpgan" in upscale.correction_model:
|
||||||
correct_stage = (correct_gfpgan, face_params, upscale_opts)
|
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
|
||||||
else:
|
else:
|
||||||
logger.warn("unknown correction model: %s", upscale.correction_model)
|
logger.warn("unknown correction model: %s", upscale.correction_model)
|
||||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..models.onnx import OnnxModel
|
from ..models.onnx import OnnxModel
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
|
@ -14,105 +14,121 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpscaleBSRGANStage:
    """Chain stage that upscales an image tile by tile with a BSRGAN model."""

    # edge length, in pixels, of the square tiles fed to the model
    max_tile = 64

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        """Return the ONNX model for ``upscale.upscale_model``.

        The model is looked up in ``server.cache`` under the ``"bsrgan"`` tag,
        keyed by its path; on a miss it is created, stored in the cache and
        returned.
        """
        # must be within the load function for patch to take effect
        model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
        cache_key = (model_path,)
        cache_pipe = server.cache.get("bsrgan", cache_key)

        if cache_pipe is not None:
            logger.debug("reusing existing BSRGAN pipeline")
            return cache_pipe

        logger.info("loading BSRGAN model from %s", model_path)

        pipe = OnnxModel(
            server,
            model_path,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        server.cache.set("bsrgan", cache_key, pipe)
        run_gc([device])

        return pipe

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        """Upscale the stage input by ``upscale.outscale`` and return it.

        The input is returned unchanged when no upscaling model is
        configured. Only whole ``max_tile`` x ``max_tile`` tiles are
        processed; any remainder on the right and bottom edges is left as
        zeros in the output.
        """
        upscale = upscale.with_args(**kwargs)
        source = stage_source or source

        if upscale.upscale_model is None:
            logger.warning("no upscaling model given, skipping")
            return source

        logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model)
        device = job.get_device()
        bsrgan = self.load(server, stage, upscale, device)

        tile_size = (self.max_tile, self.max_tile)
        tile_x = source.width // tile_size[0]
        tile_y = source.height // tile_size[1]

        # HWC RGB in [0, 255] -> NCHW BGR float32 in [0, 1]
        image = np.array(source) / 255.0
        image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
        image = np.expand_dims(image, axis=0)
        logger.trace("BSRGAN input shape: %s", image.shape)

        scale = upscale.outscale
        dest = np.zeros(
            (
                image.shape[0],
                image.shape[1],
                image.shape[2] * scale,
                image.shape[3] * scale,
            )
        )
        logger.trace("BSRGAN output shape: %s", dest.shape)

        for x in range(tile_x):
            for y in range(tile_y):
                xt = x * tile_size[0]
                yt = y * tile_size[1]

                ix1 = xt
                ix2 = xt + tile_size[0]
                iy1 = yt
                iy2 = yt + tile_size[1]
                logger.debug(
                    "running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
                    ix1,
                    ix2,
                    iy1,
                    iy2,
                    ix1 * scale,
                    ix2 * scale,
                    iy1 * scale,
                    iy2 * scale,
                )

                # the array is NCHW: axis 2 is rows (y) and axis 3 is columns
                # (x), so the y range must index axis 2 and the x range axis 3;
                # indexing them the other way round only works for square input
                dest[
                    :,
                    :,
                    iy1 * scale : iy2 * scale,
                    ix1 * scale : ix2 * scale,
                ] = bsrgan(image[:, :, iy1:iy2, ix1:ix2])

        # NCHW BGR float in [0, 1] -> HWC RGB uint8
        dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
        dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
        dest = (dest * 255.0).round().astype(np.uint8)

        output = Image.fromarray(dest, "RGB")
        logger.debug("output image size: %s x %s", output.width, output.height)
        return output

    def steps(
        self,
        _params: ImageParams,
        size: Size,
    ) -> int:
        """Return the number of whole tiles that ``run`` processes for ``size``."""
        # parenthesised so each dimension is floored before multiplying;
        # `w // t * h // t` evaluates as `((w // t) * h) // t`
        return (size.width // self.max_tile) * (size.height // self.max_tile)
|
||||||
|
|
|
@ -3,70 +3,71 @@ from typing import Any, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..chain.base import ChainPipeline
|
from ..chain import BlendImg2ImgStage, ChainPipeline
|
||||||
from ..chain.blend_img2img import blend_img2img
|
|
||||||
from ..diffusers.upscale import stage_upscale_correction
|
|
||||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from ..worker.context import ProgressCallback
|
from ..worker.context import ProgressCallback
|
||||||
|
from .upscale import stage_upscale_correction
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpscaleHighresStage:
    """Chain stage that enlarges an image and refines it with img2img."""

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        *,
        highres: HighresParams,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        pipeline: Optional[Any] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> Image.Image:
        """Scale the stage input by ``highres.scale`` and blend the result.

        When ``highres.scale`` is 1 or less the input is returned as is.
        Otherwise the image is enlarged with the interpolation named by
        ``highres.method`` (or, for any other method, with upscaling stages
        added to a nested chain), an img2img blend stage is appended, and the
        nested chain is run on the image.
        """
        source = stage_source or source

        if highres.scale <= 1:
            return source

        chain = ChainPipeline()
        method = highres.method
        target_size = (source.width * highres.scale, source.height * highres.scale)

        # TODO: upscaling within the same stage prevents tiling from happening and causes OOM
        if method == "bilinear":
            logger.debug("using bilinear interpolation for highres")
            source = source.resize(target_size, resample=Image.Resampling.BILINEAR)
        elif method == "lanczos":
            logger.debug("using Lanczos interpolation for highres")
            source = source.resize(target_size, resample=Image.Resampling.LANCZOS)
        else:
            logger.debug("using upscaling pipeline for highres")
            highres_upscale = upscale.with_args(
                faces=False,
                scale=highres.scale,
                outscale=highres.scale,
            )
            stage_upscale_correction(
                StageParams(),
                params,
                upscale=highres_upscale,
                chain=chain,
            )

        # the blend stage always runs last, whichever resize path was taken
        chain.stage(
            BlendImg2ImgStage(),
            StageParams(),
            overlap=params.overlap,
            strength=highres.strength,
        )

        return chain(job, server, params, source, callback=callback)
|
|
||||||
|
|
|
@ -18,138 +18,140 @@ from .utils import complete_tile, process_tile_grid, process_tile_order
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def upscale_outpaint(
|
class UpscaleOutpaintStage:
|
||||||
job: WorkerContext,
|
def run(
|
||||||
server: ServerContext,
|
self,
|
||||||
stage: StageParams,
|
job: WorkerContext,
|
||||||
params: ImageParams,
|
server: ServerContext,
|
||||||
source: Image.Image,
|
stage: StageParams,
|
||||||
*,
|
params: ImageParams,
|
||||||
border: Border,
|
source: Image.Image,
|
||||||
stage_source: Optional[Image.Image] = None,
|
*,
|
||||||
stage_mask: Optional[Image.Image] = None,
|
border: Border,
|
||||||
fill_color: str = "white",
|
stage_source: Optional[Image.Image] = None,
|
||||||
mask_filter: Callable = mask_filter_none,
|
stage_mask: Optional[Image.Image] = None,
|
||||||
noise_source: Callable = noise_source_histogram,
|
fill_color: str = "white",
|
||||||
callback: Optional[ProgressCallback] = None,
|
mask_filter: Callable = mask_filter_none,
|
||||||
**kwargs,
|
noise_source: Callable = noise_source_histogram,
|
||||||
) -> Image.Image:
|
callback: Optional[ProgressCallback] = None,
|
||||||
source = stage_source or source
|
**kwargs,
|
||||||
logger.info(
|
) -> Image.Image:
|
||||||
"upscaling %s x %s image by expanding borders: %s",
|
source = stage_source or source
|
||||||
source.width,
|
logger.info(
|
||||||
source.height,
|
"upscaling %s x %s image by expanding borders: %s",
|
||||||
border,
|
source.width,
|
||||||
)
|
source.height,
|
||||||
|
border,
|
||||||
|
)
|
||||||
|
|
||||||
margin_x = float(max(border.left, border.right))
|
margin_x = float(max(border.left, border.right))
|
||||||
margin_y = float(max(border.top, border.bottom))
|
margin_y = float(max(border.top, border.bottom))
|
||||||
overlap = min(margin_x / source.width, margin_y / source.height)
|
overlap = min(margin_x / source.width, margin_y / source.height)
|
||||||
|
|
||||||
if stage_mask is None:
|
if stage_mask is None:
|
||||||
# if no mask was provided, keep the full source image
|
# if no mask was provided, keep the full source image
|
||||||
stage_mask = Image.new("RGB", source.size, "black")
|
stage_mask = Image.new("RGB", source.size, "black")
|
||||||
|
|
||||||
# masks start as 512x512, resize to cover the source, then trim the extra
|
# masks start as 512x512, resize to cover the source, then trim the extra
|
||||||
mask_max = max(source.width, source.height)
|
mask_max = max(source.width, source.height)
|
||||||
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
|
stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max))
|
||||||
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
|
stage_mask = stage_mask.crop((0, 0, source.width, source.height))
|
||||||
|
|
||||||
source, stage_mask, noise, full_size = expand_image(
|
source, stage_mask, noise, full_size = expand_image(
|
||||||
source,
|
source,
|
||||||
stage_mask,
|
stage_mask,
|
||||||
border,
|
border,
|
||||||
fill=fill_color,
|
fill=fill_color,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
mask_filter=mask_filter,
|
mask_filter=mask_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
|
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
|
||||||
|
|
||||||
draw_mask = ImageDraw.Draw(stage_mask)
|
draw_mask = ImageDraw.Draw(stage_mask)
|
||||||
|
|
||||||
if is_debug():
|
|
||||||
save_image(server, "last-source.png", source)
|
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
|
||||||
save_image(server, "last-noise.png", noise)
|
|
||||||
|
|
||||||
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
|
||||||
left, top, tile = dims
|
|
||||||
size = Size(*tile_source.size)
|
|
||||||
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
|
||||||
tile_mask = complete_tile(tile_mask, tile)
|
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "tile-source.png", tile_source)
|
save_image(server, "last-source.png", source)
|
||||||
save_image(server, "tile-mask.png", tile_mask)
|
save_image(server, "last-mask.png", stage_mask)
|
||||||
|
save_image(server, "last-noise.png", noise)
|
||||||
|
|
||||||
latents = get_tile_latents(full_latents, dims, size)
|
def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]):
|
||||||
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
|
left, top, tile = dims
|
||||||
pipe = load_pipeline(
|
size = Size(*tile_source.size)
|
||||||
server,
|
tile_mask = stage_mask.crop((left, top, left + tile, top + tile))
|
||||||
params,
|
tile_mask = complete_tile(tile_mask, tile)
|
||||||
pipe_type,
|
|
||||||
job.get_device(),
|
if is_debug():
|
||||||
# TODO: load LoRAs and TIs
|
save_image(server, "tile-source.png", tile_source)
|
||||||
)
|
save_image(server, "tile-mask.png", tile_mask)
|
||||||
if params.lpw():
|
|
||||||
logger.debug("using LPW pipeline for inpaint")
|
latents = get_tile_latents(full_latents, dims, size)
|
||||||
rng = torch.manual_seed(params.seed)
|
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
|
||||||
result = pipe.inpaint(
|
pipe = load_pipeline(
|
||||||
tile_source,
|
server,
|
||||||
tile_mask,
|
params,
|
||||||
params.prompt,
|
pipe_type,
|
||||||
generator=rng,
|
job.get_device(),
|
||||||
guidance_scale=params.cfg,
|
# TODO: load LoRAs and TIs
|
||||||
height=size.height,
|
)
|
||||||
latents=latents,
|
if params.lpw():
|
||||||
negative_prompt=params.negative_prompt,
|
logger.debug("using LPW pipeline for inpaint")
|
||||||
num_inference_steps=params.steps,
|
rng = torch.manual_seed(params.seed)
|
||||||
width=size.width,
|
result = pipe.inpaint(
|
||||||
callback=callback,
|
tile_source,
|
||||||
|
tile_mask,
|
||||||
|
params.prompt,
|
||||||
|
generator=rng,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
height=size.height,
|
||||||
|
latents=latents,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
width=size.width,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rng = np.random.RandomState(params.seed)
|
||||||
|
result = pipe(
|
||||||
|
params.prompt,
|
||||||
|
tile_source,
|
||||||
|
tile_mask,
|
||||||
|
height=size.height,
|
||||||
|
width=size.width,
|
||||||
|
num_inference_steps=params.steps,
|
||||||
|
guidance_scale=params.cfg,
|
||||||
|
negative_prompt=params.negative_prompt,
|
||||||
|
generator=rng,
|
||||||
|
latents=latents,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# once part of the image has been drawn, keep it
|
||||||
|
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
||||||
|
return result.images[0]
|
||||||
|
|
||||||
|
if params.pipeline == "panorama":
|
||||||
|
logger.debug("outpainting with one shot panorama, no tiling")
|
||||||
|
return outpaint(source, (0, 0, max(source.width, source.height)))
|
||||||
|
if overlap == 0:
|
||||||
|
logger.debug("outpainting with 0 margin, using grid tiling")
|
||||||
|
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||||
|
elif border.left == border.right and border.top == border.bottom:
|
||||||
|
logger.debug(
|
||||||
|
"outpainting with an even border, using spiral tiling with %s overlap",
|
||||||
|
overlap,
|
||||||
|
)
|
||||||
|
output = process_tile_order(
|
||||||
|
stage.tile_order,
|
||||||
|
source,
|
||||||
|
SizeChart.auto,
|
||||||
|
1,
|
||||||
|
[outpaint],
|
||||||
|
overlap=overlap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rng = np.random.RandomState(params.seed)
|
logger.debug("outpainting with an uneven border, using grid tiling")
|
||||||
result = pipe(
|
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
||||||
params.prompt,
|
|
||||||
tile_source,
|
|
||||||
tile_mask,
|
|
||||||
height=size.height,
|
|
||||||
width=size.width,
|
|
||||||
num_inference_steps=params.steps,
|
|
||||||
guidance_scale=params.cfg,
|
|
||||||
negative_prompt=params.negative_prompt,
|
|
||||||
generator=rng,
|
|
||||||
latents=latents,
|
|
||||||
callback=callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
# once part of the image has been drawn, keep it
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
|
return output
|
||||||
return result.images[0]
|
|
||||||
|
|
||||||
if params.pipeline == "panorama":
|
|
||||||
logger.debug("outpainting with one shot panorama, no tiling")
|
|
||||||
return outpaint(source, (0, 0, max(source.width, source.height)))
|
|
||||||
if overlap == 0:
|
|
||||||
logger.debug("outpainting with 0 margin, using grid tiling")
|
|
||||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
|
||||||
elif border.left == border.right and border.top == border.bottom:
|
|
||||||
logger.debug(
|
|
||||||
"outpainting with an even border, using spiral tiling with %s overlap",
|
|
||||||
overlap,
|
|
||||||
)
|
|
||||||
output = process_tile_order(
|
|
||||||
stage.tile_order,
|
|
||||||
source,
|
|
||||||
SizeChart.auto,
|
|
||||||
1,
|
|
||||||
[outpaint],
|
|
||||||
overlap=overlap,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug("outpainting with an uneven border, using grid tiling")
|
|
||||||
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
|
|
||||||
|
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
|
||||||
return output
|
|
||||||
|
|
|
@ -16,79 +16,80 @@ logger = getLogger(__name__)
|
||||||
TAG_X4_V3 = "real-esrgan-x4-v3"
|
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||||
|
|
||||||
|
|
||||||
def load_resrgan(
|
class UpscaleRealESRGANStage:
|
||||||
server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
def load(
|
||||||
):
|
self, server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||||
# must be within load function for patches to take effect
|
):
|
||||||
# TODO: rewrite and remove
|
# must be within load function for patches to take effect
|
||||||
from realesrgan import RealESRGANer
|
# TODO: rewrite and remove
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
model_file = "%s.%s" % (params.upscale_model, params.format)
|
model_file = "%s.%s" % (params.upscale_model, params.format)
|
||||||
model_path = path.join(server.model_path, model_file)
|
model_path = path.join(server.model_path, model_file)
|
||||||
|
|
||||||
cache_key = (model_path, params.format)
|
cache_key = (model_path, params.format)
|
||||||
cache_pipe = server.cache.get("resrgan", cache_key)
|
cache_pipe = server.cache.get("resrgan", cache_key)
|
||||||
if cache_pipe is not None:
|
if cache_pipe is not None:
|
||||||
logger.info("reusing existing Real ESRGAN pipeline")
|
logger.info("reusing existing Real ESRGAN pipeline")
|
||||||
return cache_pipe
|
return cache_pipe
|
||||||
|
|
||||||
if not path.isfile(model_path):
|
if not path.isfile(model_path):
|
||||||
raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)
|
raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path)
|
||||||
|
|
||||||
# TODO: swap for regular RRDBNet after rewriting wrapper
|
# TODO: swap for regular RRDBNet after rewriting wrapper
|
||||||
model = OnnxRRDBNet(
|
model = OnnxRRDBNet(
|
||||||
server,
|
server,
|
||||||
model_file,
|
model_file,
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
|
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
|
||||||
wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
|
wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
|
||||||
model_path = [model_path, wdn_model_path]
|
model_path = [model_path, wdn_model_path]
|
||||||
dni_weight = [params.denoise, 1 - params.denoise]
|
dni_weight = [params.denoise, 1 - params.denoise]
|
||||||
|
|
||||||
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
|
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
|
||||||
|
|
||||||
# TODO: shouldn't need the PTH file
|
# TODO: shouldn't need the PTH file
|
||||||
model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model))
|
model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model))
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
scale=params.scale,
|
scale=params.scale,
|
||||||
model_path=model_path_pth,
|
model_path=model_path_pth,
|
||||||
dni_weight=dni_weight,
|
dni_weight=dni_weight,
|
||||||
model=model,
|
model=model,
|
||||||
tile=tile,
|
tile=tile,
|
||||||
tile_pad=params.tile_pad,
|
tile_pad=params.tile_pad,
|
||||||
pre_pad=params.pre_pad,
|
pre_pad=params.pre_pad,
|
||||||
half=False, # TODO: use server optimizations
|
half=False, # TODO: use server optimizations
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set("resrgan", cache_key, upsampler)
|
server.cache.set("resrgan", cache_key, upsampler)
|
||||||
run_gc([device])
|
run_gc([device])
|
||||||
|
|
||||||
return upsampler
|
return upsampler
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
job: WorkerContext,
|
||||||
|
server: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
|
_params: ImageParams,
|
||||||
|
source: Image.Image,
|
||||||
|
*,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
source = stage_source or source
|
||||||
|
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
||||||
|
|
||||||
def upscale_resrgan(
|
output = np.array(source)
|
||||||
job: WorkerContext,
|
upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size)
|
||||||
server: ServerContext,
|
|
||||||
stage: StageParams,
|
|
||||||
_params: ImageParams,
|
|
||||||
source: Image.Image,
|
|
||||||
*,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
stage_source: Optional[Image.Image] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Image.Image:
|
|
||||||
source = stage_source or source
|
|
||||||
logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale)
|
|
||||||
|
|
||||||
output = np.array(source)
|
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||||
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
|
|
||||||
|
|
||||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
output = Image.fromarray(output, "RGB")
|
||||||
|
logger.info("final output image size: %sx%s", output.width, output.height)
|
||||||
output = Image.fromarray(output, "RGB")
|
return output
|
||||||
logger.info("final output image size: %sx%s", output.width, output.height)
|
|
||||||
return output
|
|
||||||
|
|
|
@ -14,52 +14,54 @@ from ..worker import ProgressCallback, WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def upscale_stable_diffusion(
|
class UpscaleStableDiffusionStage:
|
||||||
job: WorkerContext,
|
def run(
|
||||||
server: ServerContext,
|
self,
|
||||||
_stage: StageParams,
|
job: WorkerContext,
|
||||||
params: ImageParams,
|
server: ServerContext,
|
||||||
source: Image.Image,
|
_stage: StageParams,
|
||||||
*,
|
params: ImageParams,
|
||||||
upscale: UpscaleParams,
|
source: Image.Image,
|
||||||
stage_source: Optional[Image.Image] = None,
|
*,
|
||||||
callback: Optional[ProgressCallback] = None,
|
upscale: UpscaleParams,
|
||||||
**kwargs,
|
stage_source: Optional[Image.Image] = None,
|
||||||
) -> Image.Image:
|
callback: Optional[ProgressCallback] = None,
|
||||||
params = params.with_args(**kwargs)
|
**kwargs,
|
||||||
upscale = upscale.with_args(**kwargs)
|
) -> Image.Image:
|
||||||
source = stage_source or source
|
params = params.with_args(**kwargs)
|
||||||
logger.info(
|
upscale = upscale.with_args(**kwargs)
|
||||||
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
|
source = stage_source or source
|
||||||
)
|
logger.info(
|
||||||
|
"upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt
|
||||||
|
)
|
||||||
|
|
||||||
prompt_pairs, _loras, _inversions = parse_prompt(params)
|
prompt_pairs, _loras, _inversions = parse_prompt(params)
|
||||||
|
|
||||||
pipeline = load_pipeline(
|
pipeline = load_pipeline(
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
"upscale",
|
"upscale",
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
model=path.join(server.model_path, upscale.upscale_model),
|
model=path.join(server.model_path, upscale.upscale_model),
|
||||||
)
|
)
|
||||||
generator = torch.manual_seed(params.seed)
|
generator = torch.manual_seed(params.seed)
|
||||||
|
|
||||||
prompt_embeds = encode_prompt(
|
prompt_embeds = encode_prompt(
|
||||||
pipeline,
|
pipeline,
|
||||||
prompt_pairs,
|
prompt_pairs,
|
||||||
num_images_per_prompt=params.batch,
|
num_images_per_prompt=params.batch,
|
||||||
do_classifier_free_guidance=params.do_cfg(),
|
do_classifier_free_guidance=params.do_cfg(),
|
||||||
)
|
)
|
||||||
pipeline.unet.set_prompts(prompt_embeds)
|
pipeline.unet.set_prompts(prompt_embeds)
|
||||||
|
|
||||||
return pipeline(
|
return pipeline(
|
||||||
params.prompt,
|
params.prompt,
|
||||||
source,
|
source,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
guidance_scale=params.cfg,
|
guidance_scale=params.cfg,
|
||||||
negative_prompt=params.negative_prompt,
|
negative_prompt=params.negative_prompt,
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
eta=params.eta,
|
eta=params.eta,
|
||||||
noise_level=upscale.denoise,
|
noise_level=upscale.denoise,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
|
@ -14,107 +14,116 @@ from ..worker import WorkerContext
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_swinir(
|
class UpscaleSwinIRStage:
|
||||||
server: ServerContext,
|
max_tile = 64
|
||||||
_stage: StageParams,
|
|
||||||
upscale: UpscaleParams,
|
|
||||||
device: DeviceParams,
|
|
||||||
):
|
|
||||||
# must be within the load function for patch to take effect
|
|
||||||
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
|
||||||
cache_key = (model_path,)
|
|
||||||
cache_pipe = server.cache.get("swinir", cache_key)
|
|
||||||
|
|
||||||
if cache_pipe is not None:
|
def load(
|
||||||
logger.info("reusing existing SwinIR pipeline")
|
self,
|
||||||
return cache_pipe
|
server: ServerContext,
|
||||||
|
_stage: StageParams,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
device: DeviceParams,
|
||||||
|
):
|
||||||
|
# must be within the load function for patch to take effect
|
||||||
|
model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))
|
||||||
|
cache_key = (model_path,)
|
||||||
|
cache_pipe = server.cache.get("swinir", cache_key)
|
||||||
|
|
||||||
logger.debug("loading SwinIR model from %s", model_path)
|
if cache_pipe is not None:
|
||||||
|
logger.info("reusing existing SwinIR pipeline")
|
||||||
|
return cache_pipe
|
||||||
|
|
||||||
pipe = OnnxModel(
|
logger.debug("loading SwinIR model from %s", model_path)
|
||||||
server,
|
|
||||||
model_path,
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
|
|
||||||
server.cache.set("swinir", cache_key, pipe)
|
pipe = OnnxModel(
|
||||||
run_gc([device])
|
server,
|
||||||
|
model_path,
|
||||||
|
provider=device.ort_provider(),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
)
|
||||||
|
|
||||||
return pipe
|
server.cache.set("swinir", cache_key, pipe)
|
||||||
|
run_gc([device])
|
||||||
|
|
||||||
|
return pipe
|
||||||
|
|
||||||
def upscale_swinir(
|
def run(
|
||||||
job: WorkerContext,
|
self,
|
||||||
server: ServerContext,
|
job: WorkerContext,
|
||||||
stage: StageParams,
|
server: ServerContext,
|
||||||
_params: ImageParams,
|
stage: StageParams,
|
||||||
source: Image.Image,
|
_params: ImageParams,
|
||||||
*,
|
source: Image.Image,
|
||||||
upscale: UpscaleParams,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
upscale: UpscaleParams,
|
||||||
**kwargs,
|
stage_source: Optional[Image.Image] = None,
|
||||||
) -> Image.Image:
|
**kwargs,
|
||||||
upscale = upscale.with_args(**kwargs)
|
) -> Image.Image:
|
||||||
source = stage_source or source
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
source = stage_source or source
|
||||||
|
|
||||||
if upscale.upscale_model is None:
|
if upscale.upscale_model is None:
|
||||||
logger.warn("no correction model given, skipping")
|
logger.warn("no correction model given, skipping")
|
||||||
return source
|
return source
|
||||||
|
|
||||||
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model)
|
||||||
device = job.get_device()
|
device = job.get_device()
|
||||||
swinir = load_swinir(server, stage, upscale, device)
|
swinir = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
# TODO: add support for other sizes
|
# TODO: add support for other sizes
|
||||||
tile_size = (64, 64)
|
tile_size = (64, 64)
|
||||||
tile_x = source.width // tile_size[0]
|
tile_x = source.width // tile_size[0]
|
||||||
tile_y = source.height // tile_size[1]
|
tile_y = source.height // tile_size[1]
|
||||||
|
|
||||||
# TODO: add support for grayscale (1-channel) images
|
# TODO: add support for grayscale (1-channel) images
|
||||||
image = np.array(source) / 255.0
|
image = np.array(source) / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
logger.info("SwinIR input shape: %s", image.shape)
|
logger.info("SwinIR input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
dest = np.zeros(
|
dest = np.zeros(
|
||||||
(image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale)
|
(
|
||||||
)
|
image.shape[0],
|
||||||
logger.info("SwinIR output shape: %s", dest.shape)
|
image.shape[1],
|
||||||
|
image.shape[2] * scale,
|
||||||
for x in range(tile_x):
|
image.shape[3] * scale,
|
||||||
for y in range(tile_y):
|
|
||||||
xt = x * tile_size[0]
|
|
||||||
yt = y * tile_size[1]
|
|
||||||
|
|
||||||
ix1 = xt
|
|
||||||
ix2 = xt + tile_size[0]
|
|
||||||
iy1 = yt
|
|
||||||
iy2 = yt + tile_size[1]
|
|
||||||
logger.info(
|
|
||||||
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
|
||||||
ix1,
|
|
||||||
ix2,
|
|
||||||
iy1,
|
|
||||||
iy2,
|
|
||||||
ix1 * scale,
|
|
||||||
ix2 * scale,
|
|
||||||
iy1 * scale,
|
|
||||||
iy2 * scale,
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
logger.info("SwinIR output shape: %s", dest.shape)
|
||||||
|
|
||||||
dest[
|
for x in range(tile_x):
|
||||||
:,
|
for y in range(tile_y):
|
||||||
:,
|
xt = x * tile_size[0]
|
||||||
ix1 * scale : ix2 * scale,
|
yt = y * tile_size[1]
|
||||||
iy1 * scale : iy2 * scale,
|
|
||||||
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
|
|
||||||
|
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
ix1 = xt
|
||||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
ix2 = xt + tile_size[0]
|
||||||
dest = (dest * 255.0).round().astype(np.uint8)
|
iy1 = yt
|
||||||
|
iy2 = yt + tile_size[1]
|
||||||
|
logger.info(
|
||||||
|
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
||||||
|
ix1,
|
||||||
|
ix2,
|
||||||
|
iy1,
|
||||||
|
iy2,
|
||||||
|
ix1 * scale,
|
||||||
|
ix2 * scale,
|
||||||
|
iy1 * scale,
|
||||||
|
iy2 * scale,
|
||||||
|
)
|
||||||
|
|
||||||
output = Image.fromarray(dest, "RGB")
|
dest[
|
||||||
logger.info("output image size: %s x %s", output.width, output.height)
|
:,
|
||||||
return output
|
:,
|
||||||
|
ix1 * scale : ix2 * scale,
|
||||||
|
iy1 * scale : iy2 * scale,
|
||||||
|
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
|
||||||
|
|
||||||
|
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||||
|
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||||
|
dest = (dest * 255.0).round().astype(np.uint8)
|
||||||
|
|
||||||
|
output = Image.fromarray(dest, "RGB")
|
||||||
|
logger.info("output image size: %s x %s", output.width, output.height)
|
||||||
|
return output
|
||||||
|
|
|
@ -4,13 +4,14 @@ from typing import Any, List, Optional
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..chain import (
|
from ..chain import (
|
||||||
blend_img2img,
|
BlendImg2ImgStage,
|
||||||
blend_mask,
|
BlendMaskStage,
|
||||||
source_txt2img,
|
ChainPipeline,
|
||||||
upscale_highres,
|
SourceTxt2ImgStage,
|
||||||
upscale_outpaint,
|
UpscaleHighresStage,
|
||||||
|
UpscaleOutpaintStage,
|
||||||
)
|
)
|
||||||
from ..chain.base import ChainPipeline
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
|
@ -24,7 +25,6 @@ from ..server import ServerContext
|
||||||
from ..server.load import get_source_filters
|
from ..server.load import get_source_filters
|
||||||
from ..utils import run_gc, show_system_toast
|
from ..utils import run_gc, show_system_toast
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .upscale import split_upscale, stage_upscale_correction
|
|
||||||
from .utils import parse_prompt
|
from .utils import parse_prompt
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -43,7 +43,7 @@ def run_txt2img_pipeline(
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
chain.stage(
|
chain.stage(
|
||||||
source_txt2img,
|
SourceTxt2ImgStage(),
|
||||||
stage,
|
stage,
|
||||||
size=size,
|
size=size,
|
||||||
)
|
)
|
||||||
|
@ -61,7 +61,7 @@ def run_txt2img_pipeline(
|
||||||
# apply highres
|
# apply highres
|
||||||
for _i in range(highres.iterations):
|
for _i in range(highres.iterations):
|
||||||
chain.stage(
|
chain.stage(
|
||||||
upscale_highres,
|
UpscaleHighresStage(),
|
||||||
StageParams(
|
StageParams(
|
||||||
outscale=highres.scale,
|
outscale=highres.scale,
|
||||||
),
|
),
|
||||||
|
@ -125,7 +125,7 @@ def run_img2img_pipeline(
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
chain.stage(
|
chain.stage(
|
||||||
blend_img2img,
|
BlendImg2ImgStage(),
|
||||||
stage,
|
stage,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
)
|
)
|
||||||
|
@ -144,7 +144,7 @@ def run_img2img_pipeline(
|
||||||
if params.loopback > 0:
|
if params.loopback > 0:
|
||||||
for _i in range(params.loopback):
|
for _i in range(params.loopback):
|
||||||
chain.stage(
|
chain.stage(
|
||||||
blend_img2img,
|
BlendImg2ImgStage(),
|
||||||
stage,
|
stage,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
)
|
)
|
||||||
|
@ -153,7 +153,7 @@ def run_img2img_pipeline(
|
||||||
if highres.iterations > 0:
|
if highres.iterations > 0:
|
||||||
for _i in range(highres.iterations):
|
for _i in range(highres.iterations):
|
||||||
chain.stage(
|
chain.stage(
|
||||||
upscale_highres,
|
UpscaleHighresStage(),
|
||||||
stage,
|
stage,
|
||||||
highres=highres,
|
highres=highres,
|
||||||
upscale=upscale,
|
upscale=upscale,
|
||||||
|
@ -223,7 +223,7 @@ def run_inpaint_pipeline(
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_order=tile_order)
|
stage = StageParams(tile_order=tile_order)
|
||||||
chain.stage(
|
chain.stage(
|
||||||
upscale_outpaint,
|
UpscaleOutpaintStage(),
|
||||||
stage,
|
stage,
|
||||||
border=border,
|
border=border,
|
||||||
stage_mask=mask,
|
stage_mask=mask,
|
||||||
|
@ -234,7 +234,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
chain.stage(
|
chain.stage(
|
||||||
upscale_highres,
|
UpscaleHighresStage(),
|
||||||
stage,
|
stage,
|
||||||
highres=highres,
|
highres=highres,
|
||||||
upscale=upscale,
|
upscale=upscale,
|
||||||
|
@ -300,7 +300,7 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
chain.stage(
|
chain.stage(
|
||||||
upscale_highres,
|
UpscaleHighresStage(),
|
||||||
stage,
|
stage,
|
||||||
highres=highres,
|
highres=highres,
|
||||||
upscale=upscale,
|
upscale=upscale,
|
||||||
|
@ -353,7 +353,7 @@ def run_blend_pipeline(
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
chain.stage(blend_mask, stage, stage_source=sources[1], stage_mask=mask)
|
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
|
||||||
|
|
||||||
# apply upscaling and correction
|
# apply upscaling and correction
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
from typing import Tuple, Union
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
from PIL import Image, ImageChops, ImageOps
|
|
||||||
|
|
||||||
from ..params import Border, Size
|
from ..params import Border, Size
|
||||||
from .mask_filter import mask_filter_none
|
from .mask_filter import mask_filter_none
|
||||||
|
|
|
@ -356,12 +356,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
pipeline = ChainPipeline()
|
pipeline = ChainPipeline()
|
||||||
for stage_data in data.get("stages", []):
|
for stage_data in data.get("stages", []):
|
||||||
callback = CHAIN_STAGES[stage_data.get("type")]
|
stage_class = CHAIN_STAGES[stage_data.get("type")]
|
||||||
kwargs = stage_data.get("params", {})
|
kwargs = stage_data.get("params", {})
|
||||||
logger.info("request stage: %s, %s", callback.__name__, kwargs)
|
logger.info("request stage: %s, %s", stage_class.__name__, kwargs)
|
||||||
|
|
||||||
stage = StageParams(
|
stage = StageParams(
|
||||||
stage_data.get("name", callback.__name__),
|
stage_data.get("name", stage_class.__name__),
|
||||||
tile_size=get_size(kwargs.get("tile_size")),
|
tile_size=get_size(kwargs.get("tile_size")),
|
||||||
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
|
||||||
)
|
)
|
||||||
|
@ -399,7 +399,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
|
||||||
kwargs["stage_mask"] = mask
|
kwargs["stage_mask"] = mask
|
||||||
|
|
||||||
pipeline.append((callback, stage, kwargs))
|
pipeline.append((stage_class(), stage, kwargs))
|
||||||
|
|
||||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue