feat(api): add img2img and inpaint chain stages
This commit is contained in:
parent
4188b019a1
commit
dcbd059082
|
@ -4,15 +4,24 @@ from .base import (
|
|||
StageCallback,
|
||||
StageParams,
|
||||
)
|
||||
from .blend_img2img import (
|
||||
blend_img2img,
|
||||
)
|
||||
from .blend_inpaint import (
|
||||
blend_inpaint,
|
||||
)
|
||||
from .correct_gfpgan import (
|
||||
correct_gfpgan,
|
||||
)
|
||||
from .generate_txt2img import (
|
||||
generate_txt2img,
|
||||
)
|
||||
from .persist_disk import (
|
||||
persist_disk,
|
||||
)
|
||||
from .persist_s3 import (
|
||||
persist_s3,
|
||||
)
|
||||
from .source_txt2img import (
|
||||
source_txt2img,
|
||||
)
|
||||
from .upscale_outpaint import (
|
||||
upscale_outpaint,
|
||||
)
|
||||
|
|
|
@ -7,6 +7,7 @@ from ..params import (
|
|||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
is_debug,
|
||||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
|
@ -68,10 +69,13 @@ class ChainPipeline:
|
|||
print('source image larger than tile size of %s, tiling stage' % (
|
||||
stage_params.tile_size))
|
||||
|
||||
def stage_tile(tile: Image.Image) -> Image.Image:
|
||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||
tile = stage_pipe(ctx, stage_params, params, tile,
|
||||
**kwargs)
|
||||
|
||||
if is_debug():
|
||||
tile.save(path.join(ctx.output_path, 'last-tile.png'))
|
||||
|
||||
return tile
|
||||
|
||||
image = process_tiles(
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
from diffusers import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from ..diffusion import (
|
||||
load_pipeline,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def blend_img2img(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
strength: float,
|
||||
) -> Image.Image:
|
||||
print('generating image using img2img', params.prompt)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source_image,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
)
|
||||
output = result.images[0]
|
||||
|
||||
print('final output image size', output.size)
|
||||
return output
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
from diffusers import (
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
)
|
||||
from PIL import Image
|
||||
from typing import Callable, Tuple
|
||||
|
||||
from ..diffusion import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
)
|
||||
from ..image import (
|
||||
expand_image,
|
||||
mask_filter_none,
|
||||
noise_source_histogram,
|
||||
)
|
||||
from ..params import (
|
||||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
base_join,
|
||||
is_debug,
|
||||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tiles,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def blend_inpaint(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
expand: Border,
|
||||
mask_image: Image.Image = None,
|
||||
fill_color: str = 'white',
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
) -> Image.Image:
|
||||
print('upscaling image by expanding borders', expand)
|
||||
|
||||
if mask_image is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
mask_image = Image.new('RGB', source_image.size, 'black')
|
||||
|
||||
source_image, mask_image, noise_image, _full_dims = expand_image(
|
||||
source_image,
|
||||
mask_image,
|
||||
expand,
|
||||
fill=fill_color,
|
||||
noise_source=noise_source,
|
||||
mask_filter=mask_filter)
|
||||
|
||||
if is_debug():
|
||||
source_image.save(base_join(ctx.output_path, 'last-source.png'))
|
||||
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
|
||||
noise_image.save(base_join(ctx.output_path, 'last-noise.png'))
|
||||
|
||||
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
|
||||
left, top, tile = dims
|
||||
size = Size(*image.size)
|
||||
mask = mask_image.crop((left, top, left + tile, top + tile))
|
||||
|
||||
if is_debug():
|
||||
image.save(base_join(ctx.output_path, 'tile-source.png'))
|
||||
mask.save(base_join(ctx.output_path, 'tile-mask.png'))
|
||||
|
||||
# TODO: must use inpainting model here
|
||||
model = '../models/stable-diffusion-onnx-v1-inpainting'
|
||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
||||
model, params.provider, params.scheduler)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
height=size.height,
|
||||
image=image,
|
||||
latents=latents,
|
||||
mask_image=mask,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
width=size.width,
|
||||
)
|
||||
return result.images[0]
|
||||
|
||||
output = process_tiles(source_image, 512, 1, [outpaint])
|
||||
|
||||
print('final output image size', output.size)
|
||||
return output
|
|
@ -57,7 +57,7 @@ def correct_gfpgan(
|
|||
*,
|
||||
upscale: UpscaleParams,
|
||||
upsampler: Optional[RealESRGANer] = None,
|
||||
) -> Image:
|
||||
) -> Image.Image:
|
||||
if upscale.correction_model is None:
|
||||
print('no face model given, skipping')
|
||||
return image
|
||||
|
|
|
@ -13,8 +13,8 @@ from ..utils import (
|
|||
|
||||
def persist_disk(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
output: str,
|
||||
|
|
|
@ -2,43 +2,31 @@ from diffusers import (
|
|||
OnnxStableDiffusionPipeline,
|
||||
)
|
||||
from PIL import Image
|
||||
from typing import Callable
|
||||
|
||||
from ..diffusion import (
|
||||
get_latents_from_seed,
|
||||
load_pipeline,
|
||||
)
|
||||
from ..image import (
|
||||
expand_image,
|
||||
mask_filter_none,
|
||||
noise_source_histogram,
|
||||
)
|
||||
from ..params import (
|
||||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
base_join,
|
||||
is_debug,
|
||||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tiles,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_txt2img(
|
||||
def source_txt2img(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
size: Size,
|
||||
) -> Image:
|
||||
) -> Image.Image:
|
||||
print('generating image using txt2img', params.prompt)
|
||||
|
||||
if source_image is not None:
|
|
@ -75,7 +75,7 @@ def upscale_resrgan(
|
|||
source_image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image:
|
||||
) -> Image.Image:
|
||||
print('upscaling image with Real ESRGAN', upscale.scale)
|
||||
|
||||
output = np.array(source_image)
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
from PIL import Image
|
||||
from typing import Callable, List
|
||||
from typing import List, Protocol, Tuple
|
||||
|
||||
|
||||
class TileCallback(Protocol):
|
||||
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
|
||||
pass
|
||||
|
||||
|
||||
def process_tiles(
|
||||
source: Image.Image,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[Callable],
|
||||
) -> Image:
|
||||
filters: List[TileCallback],
|
||||
) -> Image.Image:
|
||||
width, height = source.size
|
||||
image = Image.new('RGB', (width * scale, height * scale))
|
||||
|
||||
|
|
|
@ -104,7 +104,7 @@ def run_txt2img_pipeline(
|
|||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams
|
||||
):
|
||||
) -> None:
|
||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
|
@ -139,9 +139,9 @@ def run_img2img_pipeline(
|
|||
params: ImageParams,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image,
|
||||
source_image: Image.Image,
|
||||
strength: float,
|
||||
):
|
||||
) -> None:
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
|
@ -176,14 +176,14 @@ def run_inpaint_pipeline(
|
|||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image,
|
||||
mask_image: Image,
|
||||
source_image: Image.Image,
|
||||
mask_image: Image.Image,
|
||||
expand: Border,
|
||||
noise_source: Any,
|
||||
mask_filter: Any,
|
||||
strength: float,
|
||||
fill_color: str,
|
||||
):
|
||||
) -> None:
|
||||
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
|
@ -241,8 +241,8 @@ def run_upscale_pipeline(
|
|||
_size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image
|
||||
):
|
||||
source_image: Image.Image,
|
||||
) -> None:
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, source_image, upscale=upscale)
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from numpy import random
|
||||
from PIL import Image, ImageChops, ImageFilter
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -14,7 +13,7 @@ def get_pixel_index(x: int, y: int, width: int) -> int:
|
|||
return (y * width) + x
|
||||
|
||||
|
||||
def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image:
|
||||
def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new('RGB', (width, height), fill)
|
||||
|
@ -23,7 +22,7 @@ def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white'
|
|||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image:
|
||||
def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
||||
'''
|
||||
Gaussian blur with multiply, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -36,7 +35,7 @@ def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point,
|
|||
return noise
|
||||
|
||||
|
||||
def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image:
|
||||
def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
||||
'''
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -49,7 +48,7 @@ def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, r
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image:
|
||||
def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
||||
'''
|
||||
Identity transform, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -61,7 +60,7 @@ def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image:
|
||||
def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
|
||||
'''
|
||||
Fill the whole canvas, no source or noise.
|
||||
'''
|
||||
|
@ -72,7 +71,7 @@ def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image:
|
||||
def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
|
||||
'''
|
||||
Gaussian blur, source image centered on white canvas.
|
||||
'''
|
||||
|
@ -85,7 +84,7 @@ def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, round
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_uniform(source_image: Image, dims: Point, origin: Point, **kw) -> Image:
|
||||
def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
||||
|
@ -107,7 +106,7 @@ def noise_source_uniform(source_image: Image, dims: Point, origin: Point, **kw)
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_normal(source_image: Image, dims: Point, origin: Point, **kw) -> Image:
|
||||
def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
||||
|
@ -129,7 +128,7 @@ def noise_source_normal(source_image: Image, dims: Point, origin: Point, **kw) -
|
|||
return noise
|
||||
|
||||
|
||||
def noise_source_histogram(source_image: Image, dims: Point, origin: Point, **kw) -> Image:
|
||||
def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
|
||||
r, g, b = source_image.split()
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
@ -161,8 +160,8 @@ def noise_source_histogram(source_image: Image, dims: Point, origin: Point, **kw
|
|||
|
||||
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
|
||||
def expand_image(
|
||||
source_image: Image,
|
||||
mask_image: Image,
|
||||
source_image: Image.Image,
|
||||
mask_image: Image.Image,
|
||||
expand: Border,
|
||||
fill='white',
|
||||
noise_source=noise_source_histogram,
|
||||
|
|
|
@ -24,8 +24,9 @@ from typing import Tuple
|
|||
|
||||
from .chain import (
|
||||
correct_gfpgan,
|
||||
generate_txt2img,
|
||||
source_txt2img,
|
||||
persist_disk,
|
||||
persist_s3,
|
||||
upscale_outpaint,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
|
@ -546,7 +547,7 @@ def chain():
|
|||
|
||||
# parse body as json, list of stages
|
||||
example = ChainPipeline(stages=[
|
||||
(generate_txt2img, StageParams(), {
|
||||
(source_txt2img, StageParams(), {
|
||||
'size': size,
|
||||
}),
|
||||
(upscale_outpaint, StageParams(), {
|
||||
|
@ -561,6 +562,10 @@ def chain():
|
|||
(persist_disk, StageParams(tile_size=8192), {
|
||||
'output': output,
|
||||
}),
|
||||
(persist_s3, StageParams(tile_size=8192), {
|
||||
'bucket': 'storage-stable-diffusion',
|
||||
'output': output,
|
||||
}),
|
||||
])
|
||||
|
||||
# build and run chain pipeline
|
||||
|
|
|
@ -24,6 +24,10 @@ def run_upscale_correction(
|
|||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image.Image:
|
||||
'''
|
||||
This is a convenience method for a chain pipeline that will run upscaling and
|
||||
correction, based on the `upscale` params.
|
||||
'''
|
||||
print('running upscale pipeline')
|
||||
|
||||
chain = ChainPipeline()
|
||||
|
|
Loading…
Reference in New Issue