lint(api): make modules for chain pipeline and params
This commit is contained in:
parent
bcaf0f73e6
commit
caafc9ebc9
|
@ -17,12 +17,18 @@ from .image import (
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from .params import (
|
||||
UpscaleParams,
|
||||
ImageParams,
|
||||
Border,
|
||||
Point,
|
||||
Size,
|
||||
)
|
||||
from .upscale import (
|
||||
load_resrgan,
|
||||
run_upscale_correction,
|
||||
upscale_gfpgan,
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
get_and_clamp_float,
|
||||
|
@ -31,9 +37,5 @@ from .utils import (
|
|||
get_from_map,
|
||||
get_not_empty,
|
||||
base_join,
|
||||
ImageParams,
|
||||
Border,
|
||||
Point,
|
||||
ServerContext,
|
||||
Size,
|
||||
)
|
|
@ -0,0 +1,6 @@
|
|||
from .base import (
|
||||
ChainPipeline,
|
||||
PipelineStage,
|
||||
StageCallback,
|
||||
StageParams,
|
||||
)
|
|
@ -2,31 +2,18 @@ from PIL import Image
|
|||
from os import path
|
||||
from typing import Any, List, Optional, Protocol, Tuple
|
||||
|
||||
from .image import (
|
||||
from ..image import (
|
||||
process_tiles,
|
||||
)
|
||||
from .utils import (
|
||||
from ..params import (
|
||||
StageParams,
|
||||
)
|
||||
from ..utils import (
|
||||
ImageParams,
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
|
||||
class StageParams:
|
||||
'''
|
||||
Parameters for a pipeline stage, assuming they can be chained.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
tile_size: int = 512,
|
||||
outscale: int = 1,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.tile_size = tile_size
|
||||
self.outscale = outscale
|
||||
|
||||
|
||||
class StageCallback(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -39,7 +26,7 @@ class StageCallback(Protocol):
|
|||
pass
|
||||
|
||||
|
||||
PipelineStage = Tuple[StageCallback, StageParams, Optional[Any]]
|
||||
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
|
||||
|
||||
|
||||
class ChainPipeline:
|
||||
|
@ -71,17 +58,19 @@ class ChainPipeline:
|
|||
source.size)
|
||||
image = source
|
||||
|
||||
for stage_fn, stage_params, stage_kwargs in self.stages:
|
||||
name = stage_params.label or stage_fn.__name__
|
||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||
name = stage_params.label or stage_pipe.__name__
|
||||
kwargs = stage_kwargs or {}
|
||||
print('running pipeline stage %s on result image with dimensions %sx%s' %
|
||||
(name, image.width, image.height))
|
||||
|
||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
||||
print('source image larger than tile size of %s, tiling stage' % (
|
||||
stage_params.tile_size))
|
||||
|
||||
def stage_tile(tile: Image.Image) -> Image.Image:
|
||||
tile = stage_fn(ctx, stage_params, params, tile,
|
||||
**stage_kwargs)
|
||||
tile = stage_pipe(ctx, stage_params, params, tile,
|
||||
**kwargs)
|
||||
tile.save(path.join(ctx.output_path, 'last-tile.png'))
|
||||
return tile
|
||||
|
||||
|
@ -89,8 +78,8 @@ class ChainPipeline:
|
|||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||
else:
|
||||
print('source image within tile size, running stage')
|
||||
image = stage_fn(ctx, stage_params, params, image,
|
||||
**stage_kwargs)
|
||||
image = stage_pipe(ctx, stage_params, params, image,
|
||||
**kwargs)
|
||||
|
||||
print('finished running pipeline stage %s, result size: %sx%s' %
|
||||
(name, image.width, image.height))
|
|
@ -18,6 +18,11 @@ from .chain import (
|
|||
from .image import (
|
||||
expand_image,
|
||||
)
|
||||
from .params import (
|
||||
ImageParams,
|
||||
Border,
|
||||
Size,
|
||||
)
|
||||
from .upscale import (
|
||||
run_upscale_correction,
|
||||
UpscaleParams,
|
||||
|
@ -25,10 +30,7 @@ from .upscale import (
|
|||
from .utils import (
|
||||
is_debug,
|
||||
base_join,
|
||||
ImageParams,
|
||||
Border,
|
||||
ServerContext,
|
||||
Size,
|
||||
)
|
||||
|
||||
last_pipeline_instance = None
|
||||
|
@ -120,7 +122,8 @@ def run_txt2img_pipeline(
|
|||
num_inference_steps=params.steps,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -154,7 +157,8 @@ def run_img2img_pipeline(
|
|||
strength=strength,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -219,7 +223,8 @@ def run_inpaint_pipeline(
|
|||
else:
|
||||
print('output image size does not match source, skipping post-blend')
|
||||
|
||||
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -238,7 +243,8 @@ def run_upscale_pipeline(
|
|||
upscale: UpscaleParams,
|
||||
source_image: Image
|
||||
):
|
||||
image = run_upscale_correction(ctx, StageParams(), params, source_image, upscale=upscale)
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, source_image, upscale=upscale)
|
||||
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Callable, List
|
|||
|
||||
import numpy as np
|
||||
|
||||
from .utils import (
|
||||
from .params import (
|
||||
Border,
|
||||
Point,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
||||
|
||||
|
||||
Param = Union[str, int, float]
|
||||
Point = Tuple[int, int]
|
||||
|
||||
|
||||
class Border:
|
||||
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
|
||||
self.left = left
|
||||
self.right = right
|
||||
self.top = top
|
||||
self.bottom = bottom
|
||||
|
||||
|
||||
class Size:
|
||||
def __init__(self, width: int, height: int) -> None:
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
def add_border(self, border: Border):
|
||||
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
|
||||
|
||||
def tojson(self) -> Dict[str, int]:
|
||||
return {
|
||||
'height': self.height,
|
||||
'width': self.width,
|
||||
}
|
||||
|
||||
|
||||
class ImageParams:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
scheduler: Any,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str],
|
||||
cfg: float,
|
||||
steps: int,
|
||||
seed: int
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
self.scheduler = scheduler
|
||||
self.prompt = prompt
|
||||
self.negative_prompt = negative_prompt
|
||||
self.cfg = cfg
|
||||
self.steps = steps
|
||||
self.seed = seed
|
||||
|
||||
def tojson(self) -> Dict[str, Param]:
|
||||
return {
|
||||
'model': self.model,
|
||||
'provider': self.provider,
|
||||
'scheduler': self.scheduler.__name__,
|
||||
'seed': self.seed,
|
||||
'prompt': self.prompt,
|
||||
'cfg': self.cfg,
|
||||
'negativePrompt': self.negative_prompt,
|
||||
'steps': self.steps,
|
||||
}
|
||||
|
||||
|
||||
class StageParams:
|
||||
'''
|
||||
Parameters for a chained pipeline stage
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
tile_size: int = 512,
|
||||
outscale: int = 1,
|
||||
# batch_size: int = 1,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.tile_size = tile_size
|
||||
self.outscale = outscale
|
||||
|
||||
|
||||
class UpscaleParams():
|
||||
def __init__(
|
||||
self,
|
||||
upscale_model: str,
|
||||
provider: str,
|
||||
correction_model: Optional[str] = None,
|
||||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_strength: float = 0.5,
|
||||
format: Literal['onnx', 'pth'] = 'onnx',
|
||||
half=False,
|
||||
outscale: int = 1,
|
||||
scale: int = 4,
|
||||
pre_pad: int = 0,
|
||||
tile_pad: int = 10,
|
||||
) -> None:
|
||||
self.upscale_model = upscale_model
|
||||
self.provider = provider
|
||||
self.correction_model = correction_model
|
||||
self.denoise = denoise
|
||||
self.faces = faces
|
||||
self.face_strength = face_strength
|
||||
self.format = format
|
||||
self.half = half
|
||||
self.outscale = outscale
|
||||
self.pre_pad = pre_pad
|
||||
self.scale = scale
|
||||
self.tile_pad = tile_pad
|
||||
|
||||
def resize(self, size: Size) -> Size:
|
||||
return Size(size.width * self.outscale, size.height * self.outscale)
|
|
@ -19,7 +19,7 @@ from glob import glob
|
|||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from onnxruntime import get_available_providers
|
||||
from os import makedirs, path, scandir
|
||||
from os import makedirs, path
|
||||
from typing import Tuple
|
||||
|
||||
from .diffusion import (
|
||||
|
@ -41,7 +41,15 @@ from .image import (
|
|||
noise_source_normal,
|
||||
noise_source_uniform,
|
||||
)
|
||||
from .params import (
|
||||
Border,
|
||||
ImageParams,
|
||||
Size,
|
||||
)
|
||||
from .upscale import (
|
||||
correct_gfpgan,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
|
@ -53,10 +61,7 @@ from .utils import (
|
|||
get_not_empty,
|
||||
make_output_name,
|
||||
base_join,
|
||||
ImageParams,
|
||||
Border,
|
||||
ServerContext,
|
||||
Size,
|
||||
)
|
||||
|
||||
import gc
|
||||
|
@ -102,6 +107,11 @@ mask_filters = {
|
|||
'gaussian-multiply': mask_filter_gaussian_multiply,
|
||||
'gaussian-screen': mask_filter_gaussian_screen,
|
||||
}
|
||||
chain_stages = {
|
||||
'correction-gfpgan': correct_gfpgan,
|
||||
'upscaling-resrgan': upscale_resrgan,
|
||||
'upscaling-stable-diffusion': upscale_stable_diffusion,
|
||||
}
|
||||
|
||||
# Available ORT providers
|
||||
available_platforms = []
|
||||
|
@ -172,7 +182,7 @@ def pipeline_from_request() -> Tuple[ImageParams, Size]:
|
|||
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))
|
||||
|
||||
params = ImageParams(model_path, provider, scheduler, prompt,
|
||||
negative_prompt, cfg, steps, seed)
|
||||
negative_prompt, cfg, steps, seed)
|
||||
size = Size(width, height)
|
||||
return (params, size)
|
||||
|
||||
|
@ -526,6 +536,8 @@ def upscale():
|
|||
@app.route('/api/chain', methods=['POST'])
|
||||
def chain():
|
||||
print('TODO: run chain pipeline')
|
||||
# parse body as json, list of stages
|
||||
# build and run chain pipeline
|
||||
return jsonify({})
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from gfpgan import GFPGANer
|
|||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -21,44 +21,14 @@ from .onnx import (
|
|||
ONNXNet,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .utils import (
|
||||
from .params import (
|
||||
ImageParams,
|
||||
ServerContext,
|
||||
Size,
|
||||
UpscaleParams,
|
||||
)
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
|
||||
class UpscaleParams():
|
||||
def __init__(
|
||||
self,
|
||||
upscale_model: str,
|
||||
provider: str,
|
||||
correction_model: Optional[str] = None,
|
||||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_strength: float = 0.5,
|
||||
format: Literal['onnx', 'pth'] = 'onnx',
|
||||
half=False,
|
||||
outscale: int = 1,
|
||||
scale: int = 4,
|
||||
pre_pad: int = 0,
|
||||
tile_pad: int = 10,
|
||||
) -> None:
|
||||
self.upscale_model = upscale_model
|
||||
self.provider = provider
|
||||
self.correction_model = correction_model
|
||||
self.denoise = denoise
|
||||
self.faces = faces
|
||||
self.face_strength = face_strength
|
||||
self.format = format
|
||||
self.half = half
|
||||
self.outscale = outscale
|
||||
self.pre_pad = pre_pad
|
||||
self.scale = scale
|
||||
self.tile_pad = tile_pad
|
||||
|
||||
def resize(self, size: Size) -> Size:
|
||||
return Size(size.width * self.outscale, size.height * self.outscale)
|
||||
|
||||
|
||||
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||
|
@ -125,7 +95,7 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
|
|||
def upscale_resrgan(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
|
@ -142,10 +112,10 @@ def upscale_resrgan(
|
|||
return output
|
||||
|
||||
|
||||
def upscale_gfpgan(
|
||||
def correct_gfpgan(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
|
@ -179,7 +149,7 @@ def upscale_gfpgan(
|
|||
|
||||
def upscale_stable_diffusion(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
_stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
|
@ -191,18 +161,12 @@ def upscale_stable_diffusion(
|
|||
generator = torch.manual_seed(params.seed)
|
||||
seed = generator.initial_seed()
|
||||
|
||||
def upscale_stage(_ctx: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image) -> Image:
|
||||
return pipeline(
|
||||
params.prompt,
|
||||
image,
|
||||
generator=torch.manual_seed(seed),
|
||||
num_inference_steps=params.steps,
|
||||
).images[0]
|
||||
|
||||
chain = ChainPipeline(stages=[
|
||||
(upscale_stage, stage)
|
||||
])
|
||||
return chain(ctx, params, source)
|
||||
return pipeline(
|
||||
params.prompt,
|
||||
source,
|
||||
generator=torch.manual_seed(seed),
|
||||
num_inference_steps=params.steps,
|
||||
).images[0]
|
||||
|
||||
|
||||
def run_upscale_correction(
|
||||
|
@ -216,20 +180,21 @@ def run_upscale_correction(
|
|||
print('running upscale pipeline')
|
||||
|
||||
chain = ChainPipeline()
|
||||
kwargs = {'upscale': upscale}
|
||||
|
||||
if upscale.scale > 1:
|
||||
if 'esrgan' in upscale.upscale_model:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale)
|
||||
chain.append((upscale_resrgan, stage, {'upscale': upscale}))
|
||||
chain.append((upscale_resrgan, stage, kwargs))
|
||||
elif 'stable-diffusion' in upscale.upscale_model:
|
||||
mini_tile = min(128, stage.tile_size)
|
||||
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
chain.append((upscale_stable_diffusion, stage, {'upscale': upscale}))
|
||||
chain.append((upscale_stable_diffusion, stage, kwargs))
|
||||
|
||||
if upscale.faces:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale)
|
||||
chain.append((upscale_gfpgan, stage, {'upscale': upscale}))
|
||||
chain.append((correct_gfpgan, stage, kwargs))
|
||||
|
||||
return chain(ctx, params, image)
|
||||
|
|
|
@ -4,51 +4,11 @@ from struct import pack
|
|||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from hashlib import sha256
|
||||
|
||||
|
||||
Param = Union[str, int, float]
|
||||
Point = Tuple[int, int]
|
||||
|
||||
|
||||
class ImageParams:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
scheduler: Any,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str],
|
||||
cfg: float,
|
||||
steps: int,
|
||||
seed: int
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
self.scheduler = scheduler
|
||||
self.prompt = prompt
|
||||
self.negative_prompt = negative_prompt
|
||||
self.cfg = cfg
|
||||
self.steps = steps
|
||||
self.seed = seed
|
||||
|
||||
def tojson(self) -> Dict[str, Param]:
|
||||
return {
|
||||
'model': self.model,
|
||||
'provider': self.provider,
|
||||
'scheduler': self.scheduler.__name__,
|
||||
'seed': self.seed,
|
||||
'prompt': self.prompt,
|
||||
'cfg': self.cfg,
|
||||
'negativePrompt': self.negative_prompt,
|
||||
'steps': self.steps,
|
||||
}
|
||||
|
||||
|
||||
class Border:
|
||||
def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
|
||||
self.left = left
|
||||
self.right = right
|
||||
self.top = top
|
||||
self.bottom = bottom
|
||||
from .params import (
|
||||
ImageParams,
|
||||
Param,
|
||||
Size,
|
||||
)
|
||||
|
||||
|
||||
class ServerContext:
|
||||
|
@ -83,25 +43,11 @@ class ServerContext:
|
|||
# others
|
||||
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
||||
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
||||
block_platforms=environ.get('ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
|
||||
block_platforms=environ.get(
|
||||
'ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
|
||||
)
|
||||
|
||||
|
||||
class Size:
|
||||
def __init__(self, width: int, height: int) -> None:
|
||||
self.width = width
|
||||
self.height = height
|
||||
|
||||
def add_border(self, border: Border):
|
||||
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
|
||||
|
||||
def tojson(self) -> Dict[str, int]:
|
||||
return {
|
||||
'height': self.height,
|
||||
'width': self.width,
|
||||
}
|
||||
|
||||
|
||||
def is_debug() -> bool:
|
||||
return environ.get('DEBUG') is not None
|
||||
|
||||
|
|
Loading…
Reference in New Issue