1
0
Fork 0

feat(api): start implementing chain pipelines

This commit is contained in:
Sean Sube 2023-01-27 17:08:36 -06:00
parent 260f7a29f4
commit 71ff3bb1c4
8 changed files with 255 additions and 93 deletions

View File

@ -18,7 +18,7 @@ from .image import (
noise_source_uniform, noise_source_uniform,
) )
from .upscale import ( from .upscale import (
make_resrgan, load_resrgan,
run_upscale_correction, run_upscale_correction,
upscale_gfpgan, upscale_gfpgan,
upscale_resrgan, upscale_resrgan,
@ -30,8 +30,8 @@ from .utils import (
get_from_list, get_from_list,
get_from_map, get_from_map,
get_not_empty, get_not_empty,
safer_join, base_join,
BaseParams, ImageParams,
Border, Border,
Point, Point,
ServerContext, ServerContext,

95
api/onnx_web/chain.py Normal file
View File

@ -0,0 +1,95 @@
from PIL import Image
from os import path
from typing import Any, List, Optional, Protocol, Tuple
from .image import (
process_tiles,
)
from .utils import (
ImageParams,
ServerContext,
)
class StageParams:
'''
Parameters for a pipeline stage, assuming they can be chained.
'''
def __init__(
self,
tile_size: int = 512,
outscale: int = 1,
) -> None:
self.tile_size = tile_size
self.outscale = outscale
class StageCallback(Protocol):
def __call__(
self,
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
**kwargs: Any
) -> Image.Image:
pass
PipelineStage = Tuple[StageCallback, StageParams, Optional[Any]]
class ChainPipeline:
'''
Run many stages in series, passing the image results from each to the next, and processing
tiles as needed.
'''
def __init__(
self,
stages: List[PipelineStage],
):
'''
Create a new pipeline that will run the given stages.
'''
self.stages = stages
def append(self, stage: PipelineStage):
'''
Append an additional stage to this pipeline.
'''
self.stages.append(stage)
def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image) -> Image.Image:
'''
TODO: handle List[Image] outputs
'''
print('running pipeline on source image with dimensions %sx%s' %
source.size)
image = source
for stage_fn, stage_params, stage_kwargs in self.stages:
print('running pipeline stage on result image with dimensions %sx%s' %
image.size)
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
print('source image larger than tile size, tiling stage',
stage_params.tile_size)
def stage_tile(tile: Image.Image) -> Image.Image:
tile = stage_fn(ctx, stage_params, tile,
params, **stage_kwargs)
tile.save(path.join(ctx.output_path, 'last-tile.png'))
return tile
image = process_tiles(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
print('source image within tile size, run stage')
image = stage_fn(ctx, stage_params, image,
params, **stage_kwargs)
print('finished running pipeline stage, result size: %sx%s' % image.size)
print('finished running pipeline, result size: %sx%s' % image.size)
return image

View File

@ -8,7 +8,7 @@ from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from sys import exit from sys import exit
from torch.onnx import export from torch.onnx import export
from typing import Dict, List, Tuple, Union from typing import Dict, List, Optional, Tuple
import torch import torch
import warnings import warnings
@ -18,7 +18,7 @@ warnings.filterwarnings('ignore', '.*The shape inference of prim::Constant type
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*') warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
warnings.filterwarnings('ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*') warnings.filterwarnings('ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
Models = Dict[str, List[Tuple[str, str, Union[int, None]]]] Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
# recommended models # recommended models
base_models: Models = { base_models: Models = {

View File

@ -6,7 +6,7 @@ from diffusers import (
OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipeline,
) )
from PIL import Image, ImageChops from PIL import Image, ImageChops
from typing import Any, Union from typing import Any, Optional
import gc import gc
import numpy as np import numpy as np
@ -21,10 +21,11 @@ from .upscale import (
) )
from .utils import ( from .utils import (
is_debug, is_debug,
safer_join, base_join,
BaseParams, ImageParams,
Border, Border,
ServerContext, ServerContext,
StageParams,
Size, Size,
) )
@ -45,7 +46,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
return image_latents return image_latents
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None] = None): def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
@ -95,7 +96,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
def run_txt2img_pipeline( def run_txt2img_pipeline(
ctx: ServerContext, ctx: ServerContext,
params: BaseParams, params: ImageParams,
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams upscale: UpscaleParams
@ -117,9 +118,9 @@ def run_txt2img_pipeline(
num_inference_steps=params.steps, num_inference_steps=params.steps,
) )
image = result.images[0] image = result.images[0]
image = run_upscale_correction(ctx, upscale, image) image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
dest = safer_join(ctx.output_path, output) dest = base_join(ctx.output_path, output)
image.save(dest) image.save(dest)
del image del image
@ -130,7 +131,7 @@ def run_txt2img_pipeline(
def run_img2img_pipeline( def run_img2img_pipeline(
ctx: ServerContext, ctx: ServerContext,
params: BaseParams, params: ImageParams,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image, source_image: Image,
@ -151,9 +152,9 @@ def run_img2img_pipeline(
strength=strength, strength=strength,
) )
image = result.images[0] image = result.images[0]
image = run_upscale_correction(ctx, upscale, image) image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
dest = safer_join(ctx.output_path, output) dest = base_join(ctx.output_path, output)
image.save(dest) image.save(dest)
del image del image
@ -164,7 +165,8 @@ def run_img2img_pipeline(
def run_inpaint_pipeline( def run_inpaint_pipeline(
ctx: ServerContext, ctx: ServerContext,
params: BaseParams, stage: StageParams,
params: ImageParams,
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
@ -192,9 +194,9 @@ def run_inpaint_pipeline(
mask_filter=mask_filter) mask_filter=mask_filter)
if is_debug(): if is_debug():
source_image.save(safer_join(ctx.output_path, 'last-source.png')) source_image.save(base_join(ctx.output_path, 'last-source.png'))
mask_image.save(safer_join(ctx.output_path, 'last-mask.png')) mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
noise_image.save(safer_join(ctx.output_path, 'last-noise.png')) noise_image.save(base_join(ctx.output_path, 'last-noise.png'))
result = pipe( result = pipe(
params.prompt, params.prompt,
@ -215,9 +217,9 @@ def run_inpaint_pipeline(
else: else:
print('output image size does not match source, skipping post-blend') print('output image size does not match source, skipping post-blend')
image = run_upscale_correction(ctx, upscale, image) image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
dest = safer_join(ctx.output_path, output) dest = base_join(ctx.output_path, output)
image.save(dest) image.save(dest)
del image del image
@ -228,15 +230,15 @@ def run_inpaint_pipeline(
def run_upscale_pipeline( def run_upscale_pipeline(
ctx: ServerContext, ctx: ServerContext,
_params: BaseParams, params: ImageParams,
_size: Size, _size: Size,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image source_image: Image
): ):
image = run_upscale_correction(ctx, upscale, source_image) image = run_upscale_correction(ctx, StageParams(), params, source_image, upscale=upscale)
dest = safer_join(ctx.output_path, output) dest = base_join(ctx.output_path, output)
image.save(dest) image.save(dest)
del image del image

View File

@ -205,7 +205,7 @@ def process_tiles(
idx = (y * tiles_x) + x idx = (y * tiles_x) + x
left = x * tile left = x * tile
top = y * tile top = y * tile
print('processing tile %s of %s, %s.%s', idx, total, x, y) print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
tile_image = source.crop((left, top, left + tile, top + tile)) tile_image = source.crop((left, top, left + tile, top + tile))
for filter in filters: for filter in filters:

View File

@ -52,8 +52,8 @@ from .utils import (
get_from_map, get_from_map,
get_not_empty, get_not_empty,
make_output_name, make_output_name,
safer_join, base_join,
BaseParams, ImageParams,
Border, Border,
ServerContext, ServerContext,
Size, Size,
@ -124,7 +124,7 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options) return url_for(rule.endpoint, **options)
def pipeline_from_request() -> Tuple[BaseParams, Size]: def pipeline_from_request() -> Tuple[ImageParams, Size]:
user = request.remote_addr user = request.remote_addr
# pipeline stuff # pipeline stuff
@ -171,7 +171,7 @@ def pipeline_from_request() -> Tuple[BaseParams, Size]:
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" % print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt)) (user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))
params = BaseParams(model_path, provider, scheduler, prompt, params = ImageParams(model_path, provider, scheduler, prompt,
negative_prompt, cfg, steps, seed) negative_prompt, cfg, steps, seed)
size = Size(width, height) size = Size(width, height)
return (params, size) return (params, size)
@ -288,7 +288,7 @@ if is_debug():
# TODO: these two use context # TODO: these two use context
def get_model_path(model: str): def get_model_path(model: str):
return safer_join(context.model_path, model) return base_join(context.model_path, model)
def ready_reply(ready: bool): def ready_reply(ready: bool):
@ -523,6 +523,12 @@ def upscale():
}) })
@app.route('/api/chain', methods=['POST'])
def chain():
print('TODO: run chain pipeline')
return jsonify({})
@app.route('/api/ready') @app.route('/api/ready')
def ready(): def ready():
output_file = request.args.get('output', None) output_file = request.args.get('output', None)
@ -530,7 +536,7 @@ def ready():
done = executor.futures.done(output_file) done = executor.futures.done(output_file)
if done is None: if done is None:
file = safer_join(context.output_path, output_file) file = base_join(context.output_path, output_file)
if path.exists(file): if path.exists(file):
return ready_reply(True) return ready_reply(True)

View File

@ -8,34 +8,32 @@ from gfpgan import GFPGANer
from os import path from os import path
from PIL import Image from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from typing import Literal, Union from typing import Literal, Optional
import numpy as np import numpy as np
import torch import torch
from .image import ( from .chain import (
process_tiles ChainPipeline,
StageParams,
) )
from .onnx import ( from .onnx import (
ONNXNet, ONNXNet,
OnnxStableDiffusionUpscalePipeline, OnnxStableDiffusionUpscalePipeline,
) )
from .utils import ( from .utils import (
ImageParams,
ServerContext, ServerContext,
Size, Size,
) )
# TODO: these should all be params or config
pre_pad = 0
tile_pad = 10
class UpscaleParams(): class UpscaleParams():
def __init__( def __init__(
self, self,
upscale_model: str, upscale_model: str,
provider: str, provider: str,
correction_model: Union[str, None] = None, correction_model: Optional[str] = None,
denoise: float = 0.5, denoise: float = 0.5,
faces=True, faces=True,
face_strength: float = 0.5, face_strength: float = 0.5,
@ -43,6 +41,8 @@ class UpscaleParams():
half=False, half=False,
outscale: int = 1, outscale: int = 1,
scale: int = 4, scale: int = 4,
pre_pad: int = 0,
tile_pad: int = 10,
) -> None: ) -> None:
self.upscale_model = upscale_model self.upscale_model = upscale_model
self.provider = provider self.provider = provider
@ -53,13 +53,18 @@ class UpscaleParams():
self.format = format self.format = format
self.half = half self.half = half
self.outscale = outscale self.outscale = outscale
self.pre_pad = pre_pad
self.scale = scale self.scale = scale
self.tile_pad = tile_pad
def resize(self, size: Size) -> Size: def resize(self, size: Size) -> Size:
return Size(size.width * self.outscale, size.height * self.outscale) return Size(size.width * self.outscale, size.height * self.outscale)
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
'''
TODO: cache this instance
'''
model_file = '%s.%s' % (params.upscale_model, params.format) model_file = '%s.%s' % (params.upscale_model, params.format)
model_path = path.join(ctx.model_path, model_file) model_path = path.join(ctx.model_path, model_file)
if not path.isfile(model_path): if not path.isfile(model_path):
@ -71,7 +76,6 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
elif params.format == 'pth': elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale) num_block=23, num_grow_ch=32, scale=params.scale)
else:
raise Exception('unknown platform %s' % params.format) raise Exception('unknown platform %s' % params.format)
dni_weight = None dni_weight = None
@ -88,88 +92,143 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
dni_weight=dni_weight, dni_weight=dni_weight,
model=model, model=model,
tile=tile, tile=tile,
tile_pad=tile_pad, tile_pad=params.tile_pad,
pre_pad=pre_pad, pre_pad=params.pre_pad,
half=params.half) half=params.half)
return upsampler return upsampler
def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Image) -> Image: def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
print('upscaling image with Real ESRGAN', params.scale) '''
TODO: cache this instance
'''
if upscale.format == 'onnx':
model_path = path.join(ctx.model_path, upscale.upscale_model)
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
model_path,
vae=AutoencoderKL.from_pretrained(
model_path, subfolder='vae_encoder'),
low_res_scheduler=DDPMScheduler.from_pretrained(
model_path, subfolder='scheduler'),
)
else:
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
'stabilityai/stable-diffusion-x4-upscaler')
return pipeline
def upscale_resrgan(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
upscale: UpscaleParams,
) -> Image:
print('upscaling image with Real ESRGAN', upscale.scale)
output = np.array(source_image) output = np.array(source_image)
upsampler = make_resrgan(ctx, params, tile=512) upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=params.outscale) output, _ = upsampler.enhance(output, outscale=upscale.outscale)
output = Image.fromarray(output, 'RGB') output = Image.fromarray(output, 'RGB')
print('final output image size', output.size) print('final output image size', output.size)
return output return output
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image: def upscale_gfpgan(
print('correcting faces with GFPGAN model: %s' % params.correction_model) ctx: ServerContext,
stage: StageParams,
if params.correction_model is None: params: ImageParams,
image: Image.Image,
*,
upscale: UpscaleParams,
upsampler: Optional[RealESRGANer] = None,
) -> Image:
if upscale.correction_model is None:
print('no face model given, skipping') print('no face model given, skipping')
return image return image
if upsampler is None: print('correcting faces with GFPGAN model: %s' % upscale.correction_model)
upsampler = make_resrgan(ctx, params)
face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model)) if upsampler is None:
upsampler = load_resrgan(ctx, upscale)
face_path = path.join(ctx.model_path, '%s.pth' %
(upscale.correction_model))
# TODO: doesn't have a model param, not sure how to pass ONNX model # TODO: doesn't have a model param, not sure how to pass ONNX model
face_enhancer = GFPGANer( face_enhancer = GFPGANer(
model_path=face_path, model_path=face_path,
upscale=params.outscale, upscale=upscale.outscale,
arch='clean', arch='clean',
channel_multiplier=2, channel_multiplier=2,
bg_upsampler=upsampler) bg_upsampler=upsampler)
_, _, output = face_enhancer.enhance( _, _, output = face_enhancer.enhance(
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength) image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
return output return output
def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image: def upscale_stable_diffusion(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
upscale: UpscaleParams,
) -> Image:
print('upscaling with Stable Diffusion') print('upscaling with Stable Diffusion')
model_path = '../models/%s' % params.upscale_model
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
# pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
# model_path,
# vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
# low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
# )
# result = pipeline('', image=image)
generator = torch.manual_seed(0) pipeline = load_stable_diffusion(ctx, upscale)
generator = torch.manual_seed(params.seed)
seed = generator.initial_seed() seed = generator.initial_seed()
pipeline = StableDiffusionUpscalePipeline.from_pretrained('stabilityai/stable-diffusion-x4-upscaler') def upscale_stage(_ctx: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image) -> Image:
upscale = lambda i: pipeline( return pipeline(
'an astronaut eating a hamburger', params.prompt,
image=i, image=image,
generator=torch.manual_seed(initial_seed), generator=torch.manual_seed(seed),
).images[0] num_inference_steps=params.steps,
result = process_tiles(image, 128, 4, [upscale]) ).images[0]
return result
chain = ChainPipeline(stages=[
(upscale_stage, stage)
])
return chain(ctx, params, source)
def run_upscale_correction(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image: def run_upscale_correction(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
image: Image.Image,
*,
upscale: UpscaleParams,
) -> Image.Image:
print('running upscale pipeline') print('running upscale pipeline')
if params.scale > 1: if upscale.scale > 1:
if 'esrgan' in params.upscale_model: if 'esrgan' in upscale.upscale_model:
image = upscale_resrgan(ctx, params, image) stage = StageParams(tile_size=stage.tile_size,
elif 'stable-diffusion' in params.upscale_model: outscale=upscale.outscale)
image = upscale_stable_diffusion(ctx, params, image) image = upscale_resrgan(ctx, stage, params, image, upscale=upscale)
elif 'stable-diffusion' in upscale.upscale_model:
mini_tile = max(128, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
image = upscale_stable_diffusion(
ctx, stage, params, image, upscale=upscale)
if params.faces: if upscale.faces:
image = upscale_gfpgan(ctx, params, image) stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
image = upscale_gfpgan(ctx, stage, params, image, upscale=upscale)
return image return image

View File

@ -1,7 +1,7 @@
from os import environ, path from os import environ, path
from time import time from time import time
from struct import pack from struct import pack
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from hashlib import sha256 from hashlib import sha256
@ -9,14 +9,14 @@ Param = Union[str, int, float]
Point = Tuple[int, int] Point = Tuple[int, int]
class BaseParams: class ImageParams:
def __init__( def __init__(
self, self,
model: str, model: str,
provider: str, provider: str,
scheduler: Any, scheduler: Any,
prompt: str, prompt: str,
negative_prompt: Union[None, str], negative_prompt: Optional[str],
cfg: float, cfg: float,
steps: int, steps: int,
seed: int seed: int
@ -114,7 +114,7 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
return min(max(int(args.get(key, default_value)), min_value), max_value) return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_list(args: Any, key: str, values: List[Any]) -> Union[Any, None]: def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
selected = args.get(key, None) selected = args.get(key, None)
if selected in values: if selected in values:
return selected return selected
@ -158,9 +158,9 @@ def hash_value(sha, param: Param):
def make_output_name( def make_output_name(
mode: str, mode: str,
params: BaseParams, params: ImageParams,
size: Size, size: Size,
extras: Union[None, Tuple[Param]] = None extras: Optional[Tuple[Param]] = None
) -> str: ) -> str:
now = int(time()) now = int(time())
sha = sha256() sha = sha256()
@ -184,6 +184,6 @@ def make_output_name(
return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now) return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
def safer_join(base: str, tail: str) -> str: def base_join(base: str, tail: str) -> str:
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/') tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
return path.join(base, safer_path) return path.join(base, tail_path)