
feat(api): start implementing chain pipelines

Sean Sube 2023-01-27 17:08:36 -06:00
parent 260f7a29f4
commit 71ff3bb1c4
8 changed files with 255 additions and 93 deletions

api/onnx_web/__init__.py
View File

@@ -18,7 +18,7 @@ from .image import (
    noise_source_uniform,
)
from .upscale import (
    make_resrgan,
    load_resrgan,
    run_upscale_correction,
    upscale_gfpgan,
    upscale_resrgan,
@@ -30,8 +30,8 @@ from .utils import (
    get_from_list,
    get_from_map,
    get_not_empty,
    safer_join,
    BaseParams,
    base_join,
    ImageParams,
    Border,
    Point,
    ServerContext,

95
api/onnx_web/chain.py Normal file
View File

@@ -0,0 +1,95 @@
from PIL import Image
from os import path
from typing import Any, List, Optional, Protocol, Tuple

from .image import (
    process_tiles,
)
from .utils import (
    ImageParams,
    ServerContext,
)


class StageParams:
    '''
    Parameters for a pipeline stage, assuming they can be chained.
    '''

    def __init__(
        self,
        tile_size: int = 512,
        outscale: int = 1,
    ) -> None:
        self.tile_size = tile_size
        self.outscale = outscale


class StageCallback(Protocol):
    def __call__(
        self,
        ctx: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        **kwargs: Any
    ) -> Image.Image:
        pass


PipelineStage = Tuple[StageCallback, StageParams, Optional[Any]]


class ChainPipeline:
    '''
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    '''

    def __init__(
        self,
        stages: List[PipelineStage],
    ):
        '''
        Create a new pipeline that will run the given stages.
        '''
        self.stages = stages

    def append(self, stage: PipelineStage):
        '''
        Append an additional stage to this pipeline.
        '''
        self.stages.append(stage)

    def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image) -> Image.Image:
        '''
        TODO: handle List[Image] outputs
        '''
        print('running pipeline on source image with dimensions %sx%s' %
              source.size)
        image = source

        for stage_fn, stage_params, stage_kwargs in self.stages:
            print('running pipeline stage on result image with dimensions %sx%s' %
                  image.size)

            if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
                print('source image larger than tile size, tiling stage',
                      stage_params.tile_size)

                def stage_tile(tile: Image.Image) -> Image.Image:
                    # keep the StageCallback argument order: params before the source tile
                    tile = stage_fn(ctx, stage_params, params, tile,
                                    **stage_kwargs)
                    tile.save(path.join(ctx.output_path, 'last-tile.png'))
                    return tile

                image = process_tiles(
                    image, stage_params.tile_size, stage_params.outscale, [stage_tile])
            else:
                print('source image within tile size, run stage')
                image = stage_fn(ctx, stage_params, params, image,
                                 **stage_kwargs)

            print('finished running pipeline stage, result size: %sx%s' % image.size)

        print('finished running pipeline, result size: %sx%s' % image.size)
        return image
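To illustrate the new API, here is a usage sketch; the blur stage and its kwargs are hypothetical, for illustration only. Given a ServerContext ctx, ImageParams params, and a PIL source image, any callable matching StageCallback can be registered as a (callback, StageParams, kwargs) tuple and run through the pipeline:

from PIL import Image, ImageFilter

def blur_stage(
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    source: Image.Image,
    **kwargs: Any,
) -> Image.Image:
    # toy stage: blur the tile, radius taken from the stage kwargs
    return source.filter(ImageFilter.GaussianBlur(kwargs.get('radius', 2)))

chain = ChainPipeline(stages=[
    (blur_stage, StageParams(tile_size=512), {'radius': 4}),
])
result = chain(ctx, params, source)

Sources larger than tile_size are split through process_tiles and each tile runs through the stage; smaller sources pass through whole.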

api/onnx_web/convert.py
View File

@@ -8,7 +8,7 @@ from pathlib import Path
from shutil import copyfile, rmtree
from sys import exit
from torch.onnx import export
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Optional, Tuple
import torch
import warnings
@@ -18,7 +18,7 @@ warnings.filterwarnings('ignore', '.*The shape inference of prim::Constant type
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
warnings.filterwarnings('ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
Models = Dict[str, List[Tuple[str, str, Union[int, None]]]]
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
# recommended models
base_models: Models = {
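The base_models dict is truncated by this hunk. For the shape of an entry, a hypothetical example (the field meanings (name, source, optional scale) are an assumption from the type alias, and the values are illustrative only):

base_models: Models = {
    'upscaling': [
        # (model name, source, optional scale factor), hypothetical values
        ('upscaling-example-x4', 'example/upscale-model', 4),
    ],
}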

api/onnx_web/diffusion.py
View File

@@ -6,7 +6,7 @@ from diffusers import (
    OnnxStableDiffusionInpaintPipeline,
)
from PIL import Image, ImageChops
from typing import Any, Union
from typing import Any, Optional
import gc
import numpy as np
@@ -21,10 +21,11 @@ from .upscale import (
)
from .utils import (
    is_debug,
    safer_join,
    BaseParams,
    base_join,
    ImageParams,
    Border,
    ServerContext,
    StageParams,
    Size,
)
@@ -45,7 +46,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
    return image_latents


def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None] = None):
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
    global last_pipeline_instance
    global last_pipeline_scheduler
    global last_pipeline_options
@@ -95,7 +96,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu

def run_txt2img_pipeline(
    ctx: ServerContext,
    params: BaseParams,
    params: ImageParams,
    size: Size,
    output: str,
    upscale: UpscaleParams
@@ -117,9 +118,9 @@ def run_txt2img_pipeline(
        num_inference_steps=params.steps,
    )
    image = result.images[0]
    image = run_upscale_correction(ctx, upscale, image)
    image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)

    dest = safer_join(ctx.output_path, output)
    dest = base_join(ctx.output_path, output)
    image.save(dest)

    del image
@@ -130,7 +131,7 @@ def run_txt2img_pipeline(

def run_img2img_pipeline(
    ctx: ServerContext,
    params: BaseParams,
    params: ImageParams,
    output: str,
    upscale: UpscaleParams,
    source_image: Image,
@@ -151,9 +152,9 @@ def run_img2img_pipeline(
        strength=strength,
    )
    image = result.images[0]
    image = run_upscale_correction(ctx, upscale, image)
    image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)

    dest = safer_join(ctx.output_path, output)
    dest = base_join(ctx.output_path, output)
    image.save(dest)

    del image
@@ -164,7 +165,8 @@ def run_img2img_pipeline(

def run_inpaint_pipeline(
    ctx: ServerContext,
    params: BaseParams,
    stage: StageParams,
    params: ImageParams,
    size: Size,
    output: str,
    upscale: UpscaleParams,
@@ -192,9 +194,9 @@ def run_inpaint_pipeline(
        mask_filter=mask_filter)

    if is_debug():
        source_image.save(safer_join(ctx.output_path, 'last-source.png'))
        mask_image.save(safer_join(ctx.output_path, 'last-mask.png'))
        noise_image.save(safer_join(ctx.output_path, 'last-noise.png'))
        source_image.save(base_join(ctx.output_path, 'last-source.png'))
        mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
        noise_image.save(base_join(ctx.output_path, 'last-noise.png'))

    result = pipe(
        params.prompt,
@@ -215,9 +217,9 @@ def run_inpaint_pipeline(
    else:
        print('output image size does not match source, skipping post-blend')

    image = run_upscale_correction(ctx, upscale, image)
    image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)

    dest = safer_join(ctx.output_path, output)
    dest = base_join(ctx.output_path, output)
    image.save(dest)

    del image
@@ -228,15 +230,15 @@ def run_inpaint_pipeline(

def run_upscale_pipeline(
    ctx: ServerContext,
    _params: BaseParams,
    params: ImageParams,
    _size: Size,
    output: str,
    upscale: UpscaleParams,
    source_image: Image
):
    image = run_upscale_correction(ctx, upscale, source_image)
    image = run_upscale_correction(ctx, StageParams(), params, source_image, upscale=upscale)

    dest = safer_join(ctx.output_path, output)
    dest = base_join(ctx.output_path, output)
    image.save(dest)

    del image

api/onnx_web/image.py
View File

@@ -205,7 +205,7 @@ def process_tiles(
            idx = (y * tiles_x) + x
            left = x * tile
            top = y * tile
            print('processing tile %s of %s, %s.%s', idx, total, x, y)
            print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
            tile_image = source.crop((left, top, left + tile, top + tile))

            for filter in filters:
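For context, a minimal sketch of the loop this hunk sits in, reconstructed from the visible lines (the actual process_tiles in api/onnx_web/image.py may differ in details):

from typing import Callable, List
from PIL import Image

def process_tiles(
    source: Image.Image,
    tile: int,
    scale: int,
    filters: List[Callable[[Image.Image], Image.Image]],
) -> Image.Image:
    width, height = source.size
    # the output canvas grows by the stage's scale factor
    image = Image.new('RGB', (width * scale, height * scale))
    tiles_x = width // tile
    tiles_y = height // tile
    total = tiles_x * tiles_y

    for y in range(tiles_y):
        for x in range(tiles_x):
            idx = (y * tiles_x) + x
            left = x * tile
            top = y * tile
            print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
            tile_image = source.crop((left, top, left + tile, top + tile))

            for filter in filters:
                tile_image = filter(tile_image)

            image.paste(tile_image, (left * scale, top * scale))

    return image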

api/onnx_web/serve.py
View File

@@ -52,8 +52,8 @@ from .utils import (
    get_from_map,
    get_not_empty,
    make_output_name,
    safer_join,
    BaseParams,
    base_join,
    ImageParams,
    Border,
    ServerContext,
    Size,
@@ -124,7 +124,7 @@ def url_from_rule(rule) -> str:
    return url_for(rule.endpoint, **options)


def pipeline_from_request() -> Tuple[BaseParams, Size]:
def pipeline_from_request() -> Tuple[ImageParams, Size]:
    user = request.remote_addr

    # pipeline stuff
@@ -171,7 +171,7 @@ def pipeline_from_request() -> Tuple[BaseParams, Size]:
    print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
          (user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))

    params = BaseParams(model_path, provider, scheduler, prompt,
    params = ImageParams(model_path, provider, scheduler, prompt,
                         negative_prompt, cfg, steps, seed)
    size = Size(width, height)
    return (params, size)
@@ -288,7 +288,7 @@ if is_debug():

# TODO: these two use context
def get_model_path(model: str):
    return safer_join(context.model_path, model)
    return base_join(context.model_path, model)


def ready_reply(ready: bool):
@@ -523,6 +523,12 @@ def upscale():
    })


@app.route('/api/chain', methods=['POST'])
def chain():
    print('TODO: run chain pipeline')
    return jsonify({})


@app.route('/api/ready')
def ready():
    output_file = request.args.get('output', None)
@@ -530,7 +536,7 @@ def ready():
    done = executor.futures.done(output_file)
    if done is None:
        file = safer_join(context.output_path, output_file)
        file = base_join(context.output_path, output_file)

        if path.exists(file):
            return ready_reply(True)
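The new /api/chain route is only a stub for now; a hypothetical call against a local dev server (default Flask port assumed) returns an empty object until the pipeline is wired up:

import requests

resp = requests.post('http://127.0.0.1:5000/api/chain')
print(resp.json())  # {} until the chain pipeline is implemented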

api/onnx_web/upscale.py
View File

@@ -8,34 +8,32 @@ from gfpgan import GFPGANer
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Literal, Union
from typing import Literal, Optional
import numpy as np
import torch

from .image import (
    process_tiles
from .chain import (
    ChainPipeline,
    StageParams,
)
from .onnx import (
    ONNXNet,
    OnnxStableDiffusionUpscalePipeline,
)
from .utils import (
    ImageParams,
    ServerContext,
    Size,
)

# TODO: these should all be params or config
pre_pad = 0
tile_pad = 10
class UpscaleParams():
    def __init__(
        self,
        upscale_model: str,
        provider: str,
        correction_model: Union[str, None] = None,
        correction_model: Optional[str] = None,
        denoise: float = 0.5,
        faces=True,
        face_strength: float = 0.5,
@@ -43,6 +41,8 @@ class UpscaleParams():
        half=False,
        outscale: int = 1,
        scale: int = 4,
        pre_pad: int = 0,
        tile_pad: int = 10,
    ) -> None:
        self.upscale_model = upscale_model
        self.provider = provider
@@ -53,13 +53,18 @@ class UpscaleParams():
        self.format = format
        self.half = half
        self.outscale = outscale
        self.pre_pad = pre_pad
        self.scale = scale
        self.tile_pad = tile_pad

    def resize(self, size: Size) -> Size:
        return Size(size.width * self.outscale, size.height * self.outscale)


def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
    '''
    TODO: cache this instance
    '''
    model_file = '%s.%s' % (params.upscale_model, params.format)
    model_path = path.join(ctx.model_path, model_file)
    if not path.isfile(model_path):
@@ -71,7 +76,6 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
    elif params.format == 'pth':
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                        num_block=23, num_grow_ch=32, scale=params.scale)
    else:
        raise Exception('unknown platform %s' % params.format)

    dni_weight = None
@@ -88,88 +92,143 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
        dni_weight=dni_weight,
        model=model,
        tile=tile,
        tile_pad=tile_pad,
        pre_pad=pre_pad,
        tile_pad=params.tile_pad,
        pre_pad=params.pre_pad,
        half=params.half)

    return upsampler
def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Image) -> Image:
    print('upscaling image with Real ESRGAN', params.scale)

def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
    '''
    TODO: cache this instance
    '''
    if upscale.format == 'onnx':
        model_path = path.join(ctx.model_path, upscale.upscale_model)
        # ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
        # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
        # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            vae=AutoencoderKL.from_pretrained(
                model_path, subfolder='vae_encoder'),
            low_res_scheduler=DDPMScheduler.from_pretrained(
                model_path, subfolder='scheduler'),
        )
    else:
        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            'stabilityai/stable-diffusion-x4-upscaler')

    return pipeline


def upscale_resrgan(
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    source_image: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image:
    print('upscaling image with Real ESRGAN', upscale.scale)
    output = np.array(source_image)
    upsampler = make_resrgan(ctx, params, tile=512)
    upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)

    output, _ = upsampler.enhance(output, outscale=params.outscale)
    output, _ = upsampler.enhance(output, outscale=upscale.outscale)

    output = Image.fromarray(output, 'RGB')
    print('final output image size', output.size)
    return output
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
    print('correcting faces with GFPGAN model: %s' % params.correction_model)
    if params.correction_model is None:
def upscale_gfpgan(
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
    upsampler: Optional[RealESRGANer] = None,
) -> Image:
    if upscale.correction_model is None:
        print('no face model given, skipping')
        return image

    if upsampler is None:
        upsampler = make_resrgan(ctx, params)
    print('correcting faces with GFPGAN model: %s' % upscale.correction_model)

    face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))
    if upsampler is None:
        upsampler = load_resrgan(ctx, upscale)

    face_path = path.join(ctx.model_path, '%s.pth' %
                          (upscale.correction_model))

    # TODO: doesn't have a model param, not sure how to pass ONNX model
    face_enhancer = GFPGANer(
        model_path=face_path,
        upscale=params.outscale,
        upscale=upscale.outscale,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=upsampler)

    _, _, output = face_enhancer.enhance(
        image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
        image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)

    return output
def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
def upscale_stable_diffusion(
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    source: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image:
    print('upscaling with Stable Diffusion')
    model_path = '../models/%s' % params.upscale_model
    # ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
    # expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
    # but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
    # pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
    #     model_path,
    #     vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
    #     low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
    # )
    # result = pipeline('', image=image)
    generator = torch.manual_seed(0)

    pipeline = load_stable_diffusion(ctx, upscale)
    generator = torch.manual_seed(params.seed)
    seed = generator.initial_seed()
    pipeline = StableDiffusionUpscalePipeline.from_pretrained('stabilityai/stable-diffusion-x4-upscaler')
    upscale = lambda i: pipeline(
        'an astronaut eating a hamburger',
        image=i,
        generator=torch.manual_seed(initial_seed),
    ).images[0]
    result = process_tiles(image, 128, 4, [upscale])
    return result

    def upscale_stage(_ctx: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image) -> Image:
        return pipeline(
            params.prompt,
            image=image,
            generator=torch.manual_seed(seed),
            num_inference_steps=params.steps,
        ).images[0]

    # stages are (callback, StageParams, kwargs) triples; no extra kwargs here
    chain = ChainPipeline(stages=[
        (upscale_stage, stage, {})
    ])
    return chain(ctx, params, source)
def run_upscale_correction(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
def run_upscale_correction(
    ctx: ServerContext,
    stage: StageParams,
    params: ImageParams,
    image: Image.Image,
    *,
    upscale: UpscaleParams,
) -> Image.Image:
    print('running upscale pipeline')

    if params.scale > 1:
        if 'esrgan' in params.upscale_model:
            image = upscale_resrgan(ctx, params, image)
        elif 'stable-diffusion' in params.upscale_model:
            image = upscale_stable_diffusion(ctx, params, image)
    if upscale.scale > 1:
        if 'esrgan' in upscale.upscale_model:
            stage = StageParams(tile_size=stage.tile_size,
                                outscale=upscale.outscale)
            image = upscale_resrgan(ctx, stage, params, image, upscale=upscale)
        elif 'stable-diffusion' in upscale.upscale_model:
            # the SD upscaler works on small tiles, so cap the tile size
            mini_tile = min(128, stage.tile_size)
            stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
            image = upscale_stable_diffusion(
                ctx, stage, params, image, upscale=upscale)

    if params.faces:
        image = upscale_gfpgan(ctx, params, image)
    if upscale.faces:
        stage = StageParams(tile_size=stage.tile_size,
                            outscale=upscale.outscale)
        image = upscale_gfpgan(ctx, stage, params, image, upscale=upscale)

    return image
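Since the upscale callbacks now match StageCallback, the same flow could be composed as a ChainPipeline; this is a sketch of where the commit is heading, not how run_upscale_correction works today (stage sizes are illustrative):

chain = ChainPipeline(stages=[
    (upscale_resrgan, StageParams(tile_size=512, outscale=upscale.outscale), {'upscale': upscale}),
    (upscale_gfpgan, StageParams(tile_size=512, outscale=upscale.outscale), {'upscale': upscale}),
])
image = chain(ctx, params, source)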

api/onnx_web/utils.py
View File

@@ -1,7 +1,7 @@
from os import environ, path
from time import time
from struct import pack
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from hashlib import sha256
@@ -9,14 +9,14 @@ Param = Union[str, int, float]
Point = Tuple[int, int]


class BaseParams:
class ImageParams:
    def __init__(
        self,
        model: str,
        provider: str,
        scheduler: Any,
        prompt: str,
        negative_prompt: Union[None, str],
        negative_prompt: Optional[str],
        cfg: float,
        steps: int,
        seed: int
@@ -114,7 +114,7 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
    return min(max(int(args.get(key, default_value)), min_value), max_value)


def get_from_list(args: Any, key: str, values: List[Any]) -> Union[Any, None]:
def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
    selected = args.get(key, None)
    if selected in values:
        return selected
@@ -158,9 +158,9 @@ def hash_value(sha, param: Param):

def make_output_name(
    mode: str,
    params: BaseParams,
    params: ImageParams,
    size: Size,
    extras: Union[None, Tuple[Param]] = None
    extras: Optional[Tuple[Param]] = None
) -> str:
    now = int(time())
    sha = sha256()
@@ -184,6 +184,6 @@ def make_output_name(
    return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)


def safer_join(base: str, tail: str) -> str:
    safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
    return path.join(base, safer_path)
def base_join(base: str, tail: str) -> str:
    tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
    return path.join(base, tail_path)
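base_join (renamed from safer_join) still normalizes the tail against the filesystem root before joining, so a hostile tail cannot climb out of the base directory. For example, with POSIX paths:

base_join('/outputs', 'image.png')         # '/outputs/image.png'
base_join('/outputs', '../../etc/passwd')  # '/outputs/etc/passwd'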