1
0
Fork 0

lint(api): make modules for chain pipeline and params

This commit is contained in:
Sean Sube 2023-01-27 22:48:06 -06:00
parent bcaf0f73e6
commit caafc9ebc9
9 changed files with 199 additions and 161 deletions

View File

@ -17,12 +17,18 @@ from .image import (
noise_source_normal,
noise_source_uniform,
)
from .params import (
UpscaleParams,
ImageParams,
Border,
Point,
Size,
)
from .upscale import (
load_resrgan,
run_upscale_correction,
upscale_gfpgan,
correct_gfpgan,
upscale_resrgan,
UpscaleParams,
)
from .utils import (
get_and_clamp_float,
@ -31,9 +37,5 @@ from .utils import (
get_from_map,
get_not_empty,
base_join,
ImageParams,
Border,
Point,
ServerContext,
Size,
)

View File

@ -0,0 +1,6 @@
from .base import (
ChainPipeline,
PipelineStage,
StageCallback,
StageParams,
)

View File

@ -2,31 +2,18 @@ from PIL import Image
from os import path
from typing import Any, List, Optional, Protocol, Tuple
from .image import (
from ..image import (
process_tiles,
)
from .utils import (
from ..params import (
StageParams,
)
from ..utils import (
ImageParams,
ServerContext,
)
class StageParams:
'''
Parameters for a pipeline stage, assuming they can be chained.
'''
def __init__(
self,
name: Optional[str] = None,
tile_size: int = 512,
outscale: int = 1,
) -> None:
self.name = name
self.tile_size = tile_size
self.outscale = outscale
class StageCallback(Protocol):
def __call__(
self,
@ -39,7 +26,7 @@ class StageCallback(Protocol):
pass
PipelineStage = Tuple[StageCallback, StageParams, Optional[Any]]
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainPipeline:
@ -71,17 +58,19 @@ class ChainPipeline:
source.size)
image = source
for stage_fn, stage_params, stage_kwargs in self.stages:
name = stage_params.label or stage_fn.__name__
for stage_pipe, stage_params, stage_kwargs in self.stages:
name = stage_params.label or stage_pipe.__name__
kwargs = stage_kwargs or {}
print('running pipeline stage %s on result image with dimensions %sx%s' %
(name, image.width, image.height))
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
print('source image larger than tile size of %s, tiling stage' % (
stage_params.tile_size))
def stage_tile(tile: Image.Image) -> Image.Image:
tile = stage_fn(ctx, stage_params, params, tile,
**stage_kwargs)
tile = stage_pipe(ctx, stage_params, params, tile,
**kwargs)
tile.save(path.join(ctx.output_path, 'last-tile.png'))
return tile
@ -89,8 +78,8 @@ class ChainPipeline:
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
print('source image within tile size, running stage')
image = stage_fn(ctx, stage_params, params, image,
**stage_kwargs)
image = stage_pipe(ctx, stage_params, params, image,
**kwargs)
print('finished running pipeline stage %s, result size: %sx%s' %
(name, image.width, image.height))

View File

@ -18,6 +18,11 @@ from .chain import (
from .image import (
expand_image,
)
from .params import (
ImageParams,
Border,
Size,
)
from .upscale import (
run_upscale_correction,
UpscaleParams,
@ -25,10 +30,7 @@ from .upscale import (
from .utils import (
is_debug,
base_join,
ImageParams,
Border,
ServerContext,
Size,
)
last_pipeline_instance = None
@ -120,7 +122,8 @@ def run_txt2img_pipeline(
num_inference_steps=params.steps,
)
image = result.images[0]
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
image = run_upscale_correction(
ctx, StageParams(), params, image, upscale=upscale)
dest = base_join(ctx.output_path, output)
image.save(dest)
@ -154,7 +157,8 @@ def run_img2img_pipeline(
strength=strength,
)
image = result.images[0]
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
image = run_upscale_correction(
ctx, StageParams(), params, image, upscale=upscale)
dest = base_join(ctx.output_path, output)
image.save(dest)
@ -219,7 +223,8 @@ def run_inpaint_pipeline(
else:
print('output image size does not match source, skipping post-blend')
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
image = run_upscale_correction(
ctx, StageParams(), params, image, upscale=upscale)
dest = base_join(ctx.output_path, output)
image.save(dest)
@ -238,7 +243,8 @@ def run_upscale_pipeline(
upscale: UpscaleParams,
source_image: Image
):
image = run_upscale_correction(ctx, StageParams(), params, source_image, upscale=upscale)
image = run_upscale_correction(
ctx, StageParams(), params, source_image, upscale=upscale)
dest = base_join(ctx.output_path, output)
image.save(dest)

View File

@ -4,7 +4,7 @@ from typing import Callable, List
import numpy as np
from .utils import (
from .params import (
Border,
Point,
)

112
api/onnx_web/params.py Normal file
View File

@ -0,0 +1,112 @@
from typing import Any, Dict, Literal, Optional, Tuple, Union
Param = Union[str, int, float]
Point = Tuple[int, int]
class Border:
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
self.left = left
self.right = right
self.top = top
self.bottom = bottom
class Size:
def __init__(self, width: int, height: int) -> None:
self.width = width
self.height = height
def add_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
def tojson(self) -> Dict[str, int]:
return {
'height': self.height,
'width': self.width,
}
class ImageParams:
def __init__(
self,
model: str,
provider: str,
scheduler: Any,
prompt: str,
negative_prompt: Optional[str],
cfg: float,
steps: int,
seed: int
) -> None:
self.model = model
self.provider = provider
self.scheduler = scheduler
self.prompt = prompt
self.negative_prompt = negative_prompt
self.cfg = cfg
self.steps = steps
self.seed = seed
def tojson(self) -> Dict[str, Param]:
return {
'model': self.model,
'provider': self.provider,
'scheduler': self.scheduler.__name__,
'seed': self.seed,
'prompt': self.prompt,
'cfg': self.cfg,
'negativePrompt': self.negative_prompt,
'steps': self.steps,
}
class StageParams:
'''
Parameters for a chained pipeline stage
'''
def __init__(
self,
name: Optional[str] = None,
tile_size: int = 512,
outscale: int = 1,
# batch_size: int = 1,
) -> None:
self.name = name
self.tile_size = tile_size
self.outscale = outscale
class UpscaleParams():
def __init__(
self,
upscale_model: str,
provider: str,
correction_model: Optional[str] = None,
denoise: float = 0.5,
faces=True,
face_strength: float = 0.5,
format: Literal['onnx', 'pth'] = 'onnx',
half=False,
outscale: int = 1,
scale: int = 4,
pre_pad: int = 0,
tile_pad: int = 10,
) -> None:
self.upscale_model = upscale_model
self.provider = provider
self.correction_model = correction_model
self.denoise = denoise
self.faces = faces
self.face_strength = face_strength
self.format = format
self.half = half
self.outscale = outscale
self.pre_pad = pre_pad
self.scale = scale
self.tile_pad = tile_pad
def resize(self, size: Size) -> Size:
return Size(size.width * self.outscale, size.height * self.outscale)

View File

@ -19,7 +19,7 @@ from glob import glob
from io import BytesIO
from PIL import Image
from onnxruntime import get_available_providers
from os import makedirs, path, scandir
from os import makedirs, path
from typing import Tuple
from .diffusion import (
@ -41,7 +41,15 @@ from .image import (
noise_source_normal,
noise_source_uniform,
)
from .params import (
Border,
ImageParams,
Size,
)
from .upscale import (
correct_gfpgan,
upscale_resrgan,
upscale_stable_diffusion,
UpscaleParams,
)
from .utils import (
@ -53,10 +61,7 @@ from .utils import (
get_not_empty,
make_output_name,
base_join,
ImageParams,
Border,
ServerContext,
Size,
)
import gc
@ -102,6 +107,11 @@ mask_filters = {
'gaussian-multiply': mask_filter_gaussian_multiply,
'gaussian-screen': mask_filter_gaussian_screen,
}
chain_stages = {
'correction-gfpgan': correct_gfpgan,
'upscaling-resrgan': upscale_resrgan,
'upscaling-stable-diffusion': upscale_stable_diffusion,
}
# Available ORT providers
available_platforms = []
@ -172,7 +182,7 @@ def pipeline_from_request() -> Tuple[ImageParams, Size]:
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))
params = ImageParams(model_path, provider, scheduler, prompt,
negative_prompt, cfg, steps, seed)
negative_prompt, cfg, steps, seed)
size = Size(width, height)
return (params, size)
@ -526,6 +536,8 @@ def upscale():
@app.route('/api/chain', methods=['POST'])
def chain():
print('TODO: run chain pipeline')
# parse body as json, list of stages
# build and run chain pipeline
return jsonify({})

View File

@ -8,7 +8,7 @@ from gfpgan import GFPGANer
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Literal, Optional
from typing import Optional
import numpy as np
import torch
@ -21,44 +21,14 @@ from .onnx import (
ONNXNet,
OnnxStableDiffusionUpscalePipeline,
)
from .utils import (
from .params import (
ImageParams,
ServerContext,
Size,
UpscaleParams,
)
from .utils import (
ServerContext,
)
class UpscaleParams():
def __init__(
self,
upscale_model: str,
provider: str,
correction_model: Optional[str] = None,
denoise: float = 0.5,
faces=True,
face_strength: float = 0.5,
format: Literal['onnx', 'pth'] = 'onnx',
half=False,
outscale: int = 1,
scale: int = 4,
pre_pad: int = 0,
tile_pad: int = 10,
) -> None:
self.upscale_model = upscale_model
self.provider = provider
self.correction_model = correction_model
self.denoise = denoise
self.faces = faces
self.face_strength = face_strength
self.format = format
self.half = half
self.outscale = outscale
self.pre_pad = pre_pad
self.scale = scale
self.tile_pad = tile_pad
def resize(self, size: Size) -> Size:
return Size(size.width * self.outscale, size.height * self.outscale)
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
@ -125,7 +95,7 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
def upscale_resrgan(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
_params: ImageParams,
source_image: Image.Image,
*,
upscale: UpscaleParams,
@ -142,10 +112,10 @@ def upscale_resrgan(
return output
def upscale_gfpgan(
def correct_gfpgan(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
_stage: StageParams,
_params: ImageParams,
image: Image.Image,
*,
upscale: UpscaleParams,
@ -179,7 +149,7 @@ def upscale_gfpgan(
def upscale_stable_diffusion(
ctx: ServerContext,
stage: StageParams,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
@ -191,18 +161,12 @@ def upscale_stable_diffusion(
generator = torch.manual_seed(params.seed)
seed = generator.initial_seed()
def upscale_stage(_ctx: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image) -> Image:
return pipeline(
params.prompt,
image,
generator=torch.manual_seed(seed),
num_inference_steps=params.steps,
).images[0]
chain = ChainPipeline(stages=[
(upscale_stage, stage)
])
return chain(ctx, params, source)
return pipeline(
params.prompt,
source,
generator=torch.manual_seed(seed),
num_inference_steps=params.steps,
).images[0]
def run_upscale_correction(
@ -216,20 +180,21 @@ def run_upscale_correction(
print('running upscale pipeline')
chain = ChainPipeline()
kwargs = {'upscale': upscale}
if upscale.scale > 1:
if 'esrgan' in upscale.upscale_model:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
chain.append((upscale_resrgan, stage, {'upscale': upscale}))
chain.append((upscale_resrgan, stage, kwargs))
elif 'stable-diffusion' in upscale.upscale_model:
mini_tile = min(128, stage.tile_size)
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
chain.append((upscale_stable_diffusion, stage, {'upscale': upscale}))
chain.append((upscale_stable_diffusion, stage, kwargs))
if upscale.faces:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
chain.append((upscale_gfpgan, stage, {'upscale': upscale}))
chain.append((correct_gfpgan, stage, kwargs))
return chain(ctx, params, image)

View File

@ -4,51 +4,11 @@ from struct import pack
from typing import Any, Dict, List, Optional, Tuple, Union
from hashlib import sha256
Param = Union[str, int, float]
Point = Tuple[int, int]
class ImageParams:
def __init__(
self,
model: str,
provider: str,
scheduler: Any,
prompt: str,
negative_prompt: Optional[str],
cfg: float,
steps: int,
seed: int
) -> None:
self.model = model
self.provider = provider
self.scheduler = scheduler
self.prompt = prompt
self.negative_prompt = negative_prompt
self.cfg = cfg
self.steps = steps
self.seed = seed
def tojson(self) -> Dict[str, Param]:
return {
'model': self.model,
'provider': self.provider,
'scheduler': self.scheduler.__name__,
'seed': self.seed,
'prompt': self.prompt,
'cfg': self.cfg,
'negativePrompt': self.negative_prompt,
'steps': self.steps,
}
class Border:
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
self.left = left
self.right = right
self.top = top
self.bottom = bottom
from .params import (
ImageParams,
Param,
Size,
)
class ServerContext:
@ -83,25 +43,11 @@ class ServerContext:
# others
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
block_platforms=environ.get('ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
block_platforms=environ.get(
'ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
)
class Size:
def __init__(self, width: int, height: int) -> None:
self.width = width
self.height = height
def add_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
def tojson(self) -> Dict[str, int]:
return {
'height': self.height,
'width': self.width,
}
def is_debug() -> bool:
return environ.get('DEBUG') is not None