
feat(api): add img2img and inpaint chain stages

Sean Sube 2023-01-28 12:42:02 -06:00
parent 4188b019a1
commit dcbd059082
13 changed files with 209 additions and 48 deletions

View File

@@ -4,15 +4,24 @@ from .base import (
StageCallback,
StageParams,
)
from .blend_img2img import (
blend_img2img,
)
from .blend_inpaint import (
blend_inpaint,
)
from .correct_gfpgan import (
correct_gfpgan,
)
from .generate_txt2img import (
generate_txt2img,
)
from .persist_disk import (
persist_disk,
)
from .persist_s3 import (
persist_s3,
)
from .source_txt2img import (
source_txt2img,
)
from .upscale_outpaint import (
upscale_outpaint,
)

View File

@@ -7,6 +7,7 @@ from ..params import (
StageParams,
)
from ..utils import (
is_debug,
ServerContext,
)
from .utils import (
@@ -68,10 +69,13 @@ class ChainPipeline:
print('source image larger than tile size of %s, tiling stage' % (
stage_params.tile_size))
def stage_tile(tile: Image.Image) -> Image.Image:
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(ctx, stage_params, params, tile,
**kwargs)
if is_debug():
tile.save(path.join(ctx.output_path, 'last-tile.png'))
return tile
image = process_tiles(

View File

@@ -0,0 +1,48 @@
from diffusers import (
OnnxStableDiffusionImg2ImgPipeline,
)
from PIL import Image
from ..diffusion import (
load_pipeline,
)
from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
ServerContext,
)
import numpy as np
def blend_img2img(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
strength: float,
) -> Image.Image:
print('generating image using img2img', params.prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source_image,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
)
output = result.images[0]
print('final output image size', output.size)
return output

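For reference, a minimal sketch of calling the new stage directly. The onnx_web import paths, the input file names, and the pre-configured ctx (ServerContext) and params (ImageParams) are assumptions for illustration, not part of this commit:

from PIL import Image
from onnx_web.chain import blend_img2img     # assumed import path
from onnx_web.params import StageParams      # assumed import path

# ctx (ServerContext) and params (ImageParams) are assumed to be set up elsewhere
source = Image.open('input.png')

# strength is the usual diffusers img2img denoising strength in [0.0, 1.0];
# higher values let the result drift further from the source image
result = blend_img2img(ctx, StageParams(), params, source, strength=0.5)
result.save('out-img2img.png')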
View File

@@ -0,0 +1,99 @@
from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from PIL import Image
from typing import Callable, Tuple
from ..diffusion import (
get_latents_from_seed,
load_pipeline,
)
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import (
Border,
ImageParams,
Size,
StageParams,
)
from ..utils import (
base_join,
is_debug,
ServerContext,
)
from .utils import (
process_tiles,
)
import numpy as np
def blend_inpaint(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
expand: Border,
mask_image: Image.Image = None,
fill_color: str = 'white',
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
) -> Image.Image:
print('upscaling image by expanding borders', expand)
if mask_image is None:
# if no mask was provided, keep the full source image
mask_image = Image.new('RGB', source_image.size, 'black')
source_image, mask_image, noise_image, _full_dims = expand_image(
source_image,
mask_image,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter)
if is_debug():
source_image.save(base_join(ctx.output_path, 'last-source.png'))
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
noise_image.save(base_join(ctx.output_path, 'last-noise.png'))
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*image.size)
mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug():
image.save(base_join(ctx.output_path, 'tile-source.png'))
mask.save(base_join(ctx.output_path, 'tile-mask.png'))
# TODO: must use inpainting model here
model = '../models/stable-diffusion-onnx-v1-inpainting'
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
model, params.provider, params.scheduler)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=image,
latents=latents,
mask_image=mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
)
return result.images[0]
output = process_tiles(source_image, 512, 1, [outpaint])
print('final output image size', output.size)
return output

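A similar hedged sketch for blend_inpaint. The Border(left, right, top, bottom) signature and import paths are assumptions; when no mask is passed, the stage builds an all-black mask and keeps the whole source, as the code above shows:

from PIL import Image
from onnx_web.chain import blend_inpaint            # assumed import path
from onnx_web.params import Border, StageParams     # assumed import paths

source = Image.open('input.png')
mask = Image.open('mask.png')   # white regions are repainted, black regions kept

# Border(left, right, top, bottom) is an assumed constructor; zero expansion here
result = blend_inpaint(ctx, StageParams(), params, source,
                       expand=Border(0, 0, 0, 0),
                       mask_image=mask)
result.save('out-inpaint.png')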
View File

@@ -57,7 +57,7 @@ def correct_gfpgan(
*,
upscale: UpscaleParams,
upsampler: Optional[RealESRGANer] = None,
) -> Image:
) -> Image.Image:
if upscale.correction_model is None:
print('no face model given, skipping')
return image

View File

@@ -13,8 +13,8 @@ from ..utils import (
def persist_disk(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
*,
output: str,

View File

@@ -2,43 +2,31 @@ from diffusers import (
OnnxStableDiffusionPipeline,
)
from PIL import Image
from typing import Callable
from ..diffusion import (
get_latents_from_seed,
load_pipeline,
)
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import (
Border,
ImageParams,
Size,
StageParams,
)
from ..utils import (
base_join,
is_debug,
ServerContext,
)
from .utils import (
process_tiles,
)
import numpy as np
def generate_txt2img(
def source_txt2img(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
size: Size,
) -> Image:
) -> Image.Image:
print('generating image using txt2img', params.prompt)
if source_image is not None:

View File

@@ -75,7 +75,7 @@ def upscale_resrgan(
source_image: Image.Image,
*,
upscale: UpscaleParams,
) -> Image:
) -> Image.Image:
print('upscaling image with Real ESRGAN', upscale.scale)
output = np.array(source_image)

View File

@@ -1,13 +1,18 @@
from PIL import Image
from typing import Callable, List
from typing import List, Protocol, Tuple
class TileCallback(Protocol):
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
pass
def process_tiles(
source: Image.Image,
tile: int,
scale: int,
filters: List[Callable],
) -> Image:
filters: List[TileCallback],
) -> Image.Image:
width, height = source.size
image = Image.new('RGB', (width * scale, height * scale))

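The new TileCallback protocol replaces the bare Callable: each filter now receives the tile image plus its (left, top, tile) dims, which is why stage_tile and outpaint above gained a dims argument. A small sketch of a conforming callback, using a PIL sharpen filter as a stand-in for a real stage and an assumed onnx_web import path:

from PIL import Image, ImageFilter
from typing import Tuple
from onnx_web.chain.utils import process_tiles      # assumed import path

def sharpen_tile(image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
    # dims carries the tile origin and size: (left, top, tile)
    left, top, tile = dims
    print('sharpening tile at %s,%s' % (left, top))
    return image.filter(ImageFilter.SHARPEN)

source = Image.open('input.png')
# walk the source in 512px tiles at scale 1 and paste each filtered tile back
output = process_tiles(source, 512, 1, [sharpen_tile])
output.save('out-tiles.png')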
View File

@@ -104,7 +104,7 @@ def run_txt2img_pipeline(
size: Size,
output: str,
upscale: UpscaleParams
):
) -> None:
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler)
@@ -139,9 +139,9 @@ def run_img2img_pipeline(
params: ImageParams,
output: str,
upscale: UpscaleParams,
source_image: Image,
source_image: Image.Image,
strength: float,
):
) -> None:
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
@@ -176,14 +176,14 @@ def run_inpaint_pipeline(
size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image,
mask_image: Image,
source_image: Image.Image,
mask_image: Image.Image,
expand: Border,
noise_source: Any,
mask_filter: Any,
strength: float,
fill_color: str,
):
) -> None:
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler)
@@ -241,8 +241,8 @@ def run_upscale_pipeline(
_size: Size,
output: str,
upscale: UpscaleParams,
source_image: Image
):
source_image: Image.Image,
) -> None:
image = run_upscale_correction(
ctx, StageParams(), params, source_image, upscale=upscale)

View File

@@ -1,6 +1,5 @@
from numpy import random
from PIL import Image, ImageChops, ImageFilter
from typing import Callable, List
import numpy as np
@@ -14,7 +13,7 @@ def get_pixel_index(x: int, y: int, width: int) -> int:
return (y * width) + x
def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image:
def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
width, height = dims
noise = Image.new('RGB', (width, height), fill)
@@ -23,7 +22,7 @@ def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white'
return noise
def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image:
def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
'''
Gaussian blur with multiply, source image centered on white canvas.
'''
@@ -36,7 +35,7 @@ def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point,
return noise
def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image:
def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
'''
Gaussian blur, source image centered on white canvas.
'''
@@ -49,7 +48,7 @@ def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, r
return noise
def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image:
def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
'''
Identity transform, source image centered on white canvas.
'''
@@ -61,7 +60,7 @@ def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill
return noise
def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image:
def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
'''
Fill the whole canvas, no source or noise.
'''
@@ -72,7 +71,7 @@ def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill
return noise
def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image:
def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
'''
Gaussian blur, source image centered on white canvas.
'''
@@ -85,7 +84,7 @@ def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, round
return noise
def noise_source_uniform(source_image: Image, dims: Point, origin: Point, **kw) -> Image:
def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
width, height = dims
size = width * height
@@ -107,7 +106,7 @@ def noise_source_uniform(source_image: Image, dims: Point, origin: Point, **kw)
return noise
def noise_source_normal(source_image: Image, dims: Point, origin: Point, **kw) -> Image:
def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
width, height = dims
size = width * height
@@ -129,7 +128,7 @@ def noise_source_normal(source_image: Image, dims: Point, origin: Point, **kw) -
return noise
def noise_source_histogram(source_image: Image, dims: Point, origin: Point, **kw) -> Image:
def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
r, g, b = source_image.split()
width, height = dims
size = width * height
@@ -161,8 +160,8 @@ def noise_source_histogram(source_image: Image, dims: Point, origin: Point, **kw
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
def expand_image(
source_image: Image,
mask_image: Image,
source_image: Image.Image,
mask_image: Image.Image,
expand: Border,
fill='white',
noise_source=noise_source_histogram,

View File

@@ -24,8 +24,9 @@ from typing import Tuple
from .chain import (
correct_gfpgan,
generate_txt2img,
source_txt2img,
persist_disk,
persist_s3,
upscale_outpaint,
upscale_resrgan,
upscale_stable_diffusion,
@@ -546,7 +547,7 @@ def chain():
# parse body as json, list of stages
example = ChainPipeline(stages=[
(generate_txt2img, StageParams(), {
(source_txt2img, StageParams(), {
'size': size,
}),
(upscale_outpaint, StageParams(), {
@@ -561,6 +562,10 @@ def chain():
(persist_disk, StageParams(tile_size=8192), {
'output': output,
}),
(persist_s3, StageParams(tile_size=8192), {
'bucket': 'storage-stable-diffusion',
'output': output,
}),
])
# build and run chain pipeline

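The new blend stages plug into the same (callback, StageParams, kwargs) tuple format used by the example above. A hedged sketch of a pipeline that generates with txt2img, refines with blend_img2img, and persists to disk; the strength value is a placeholder, and size and output are assumed to be defined as in the surrounding chain() code:

example = ChainPipeline(stages=[
    (source_txt2img, StageParams(), {
        'size': size,
    }),
    (blend_img2img, StageParams(), {
        'strength': 0.5,
    }),
    (persist_disk, StageParams(tile_size=8192), {
        'output': output,
    }),
])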
View File

@@ -24,6 +24,10 @@ def run_upscale_correction(
*,
upscale: UpscaleParams,
) -> Image.Image:
'''
This is a convenience method for a chain pipeline that will run upscaling and
correction, based on the `upscale` params.
'''
print('running upscale pipeline')
chain = ChainPipeline()
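
As the docstring says, this is a convenience wrapper that builds a small chain for upscaling and correction. The call pattern matches run_upscale_pipeline earlier in this commit; a minimal sketch, where ctx, params, source_image, and upscale stand in for already-configured values:

image = run_upscale_correction(
    ctx, StageParams(), params, source_image, upscale=upscale)
image.save('out-upscaled.png')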