feat(api): start implementing chain pipelines
This commit is contained in:
parent
260f7a29f4
commit
71ff3bb1c4
|
@ -18,7 +18,7 @@ from .image import (
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
from .upscale import (
|
from .upscale import (
|
||||||
make_resrgan,
|
load_resrgan,
|
||||||
run_upscale_correction,
|
run_upscale_correction,
|
||||||
upscale_gfpgan,
|
upscale_gfpgan,
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
|
@ -30,8 +30,8 @@ from .utils import (
|
||||||
get_from_list,
|
get_from_list,
|
||||||
get_from_map,
|
get_from_map,
|
||||||
get_not_empty,
|
get_not_empty,
|
||||||
safer_join,
|
base_join,
|
||||||
BaseParams,
|
ImageParams,
|
||||||
Border,
|
Border,
|
||||||
Point,
|
Point,
|
||||||
ServerContext,
|
ServerContext,
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
from PIL import Image
|
||||||
|
from os import path
|
||||||
|
from typing import Any, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
|
from .image import (
|
||||||
|
process_tiles,
|
||||||
|
)
|
||||||
|
from .utils import (
|
||||||
|
ImageParams,
|
||||||
|
ServerContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StageParams:
    '''
    Per-stage settings for a stage that can be chained with others.

    tile_size: edge length used when deciding whether an image must be tiled.
    outscale: scale factor associated with the stage's output.
    '''

    def __init__(
        self,
        tile_size: int = 512,
        outscale: int = 1,
    ) -> None:
        # plain value holder: both settings are stored unchanged
        self.outscale = outscale
        self.tile_size = tile_size
|
||||||
|
|
||||||
|
|
||||||
|
class StageCallback(Protocol):
    '''
    Structural type for a callable that can be used as a pipeline stage.
    '''

    def __call__(
        self,
        ctx: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        **kwargs: Any
    ) -> Image.Image:
        '''
        Take the source image and return the image produced by this stage.

        Positional order is (ctx, stage, params, source); any stage-specific
        options arrive as keyword arguments.
        '''
        ...
|
||||||
|
|
||||||
|
|
||||||
|
# One pipeline entry: (stage callback, stage parameters, extra keyword arguments).
# ChainPipeline expands the third element with ** when it calls the stage, so
# in practice it needs to be a mapping of keyword arguments for the callback.
PipelineStage = Tuple[StageCallback, StageParams, Optional[Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ChainPipeline:
    '''
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    '''

    def __init__(
        self,
        stages: 'List[PipelineStage]',
    ):
        '''
        Create a new pipeline that will run the given stages.

        The list is copied, so append() never modifies the list the caller passed in.
        '''
        self.stages = list(stages)

    def append(self, stage: 'PipelineStage'):
        '''
        Append an additional stage to this pipeline.
        '''
        self.stages.append(stage)

    def __call__(self, ctx: 'ServerContext', params: 'ImageParams', source: 'Image.Image') -> 'Image.Image':
        '''
        Run every stage in order and return the image produced by the last one.

        Each stage is a (callback, stage params, kwargs) tuple. The kwargs element
        may be None or left out entirely, in which case no extra keyword arguments
        are passed. When the current image is wider or taller than the stage's
        tile_size, the stage is run through process_tiles instead of being called
        on the whole image.

        TODO: handle List[Image] outputs
        '''
        print('running pipeline on source image with dimensions %sx%s' %
              source.size)
        image = source

        for stage_fn, stage_params, *rest in self.stages:
            # tolerate both a missing kwargs element and an explicit None
            stage_kwargs = (rest[0] if rest else None) or {}

            print('running pipeline stage on result image with dimensions %sx%s' %
                  image.size)
            if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
                print('source image larger than tile size, tiling stage',
                      stage_params.tile_size)

                def stage_tile(tile: 'Image.Image') -> 'Image.Image':
                    # argument order follows StageCallback: ctx, stage, params, source
                    tile = stage_fn(ctx, stage_params, params, tile,
                                    **stage_kwargs)
                    # keeps the most recent tile on disk, overwritten on every call
                    tile.save(path.join(ctx.output_path, 'last-tile.png'))
                    return tile

                image = process_tiles(
                    image, stage_params.tile_size, stage_params.outscale, [stage_tile])
            else:
                print('source image within tile size, run stage')
                # argument order follows StageCallback: ctx, stage, params, source
                image = stage_fn(ctx, stage_params, params, image,
                                 **stage_kwargs)

            print('finished running pipeline stage, result size: %sx%s' % image.size)

        print('finished running pipeline, result size: %sx%s' % image.size)
        return image
|
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||||
from shutil import copyfile, rmtree
|
from shutil import copyfile, rmtree
|
||||||
from sys import exit
|
from sys import exit
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -18,7 +18,7 @@ warnings.filterwarnings('ignore', '.*The shape inference of prim::Constant type
|
||||||
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
|
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
|
||||||
warnings.filterwarnings('ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
|
warnings.filterwarnings('ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
|
||||||
|
|
||||||
Models = Dict[str, List[Tuple[str, str, Union[int, None]]]]
|
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||||
|
|
||||||
# recommended models
|
# recommended models
|
||||||
base_models: Models = {
|
base_models: Models = {
|
||||||
|
|
|
@ -6,7 +6,7 @@ from diffusers import (
|
||||||
OnnxStableDiffusionInpaintPipeline,
|
OnnxStableDiffusionInpaintPipeline,
|
||||||
)
|
)
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
from typing import Any, Union
|
from typing import Any, Optional
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -21,10 +21,11 @@ from .upscale import (
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
is_debug,
|
is_debug,
|
||||||
safer_join,
|
base_join,
|
||||||
BaseParams,
|
ImageParams,
|
||||||
Border,
|
Border,
|
||||||
ServerContext,
|
ServerContext,
|
||||||
|
StageParams,
|
||||||
Size,
|
Size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,7 +46,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
||||||
return image_latents
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None] = None):
|
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_scheduler
|
global last_pipeline_scheduler
|
||||||
global last_pipeline_options
|
global last_pipeline_options
|
||||||
|
@ -95,7 +96,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
ctx: ServerContext,
|
ctx: ServerContext,
|
||||||
params: BaseParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
output: str,
|
output: str,
|
||||||
upscale: UpscaleParams
|
upscale: UpscaleParams
|
||||||
|
@ -117,9 +118,9 @@ def run_txt2img_pipeline(
|
||||||
num_inference_steps=params.steps,
|
num_inference_steps=params.steps,
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(ctx, upscale, image)
|
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||||
|
|
||||||
dest = safer_join(ctx.output_path, output)
|
dest = base_join(ctx.output_path, output)
|
||||||
image.save(dest)
|
image.save(dest)
|
||||||
|
|
||||||
del image
|
del image
|
||||||
|
@ -130,7 +131,7 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
def run_img2img_pipeline(
|
def run_img2img_pipeline(
|
||||||
ctx: ServerContext,
|
ctx: ServerContext,
|
||||||
params: BaseParams,
|
params: ImageParams,
|
||||||
output: str,
|
output: str,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
source_image: Image,
|
source_image: Image,
|
||||||
|
@ -151,9 +152,9 @@ def run_img2img_pipeline(
|
||||||
strength=strength,
|
strength=strength,
|
||||||
)
|
)
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(ctx, upscale, image)
|
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||||
|
|
||||||
dest = safer_join(ctx.output_path, output)
|
dest = base_join(ctx.output_path, output)
|
||||||
image.save(dest)
|
image.save(dest)
|
||||||
|
|
||||||
del image
|
del image
|
||||||
|
@ -164,7 +165,8 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
def run_inpaint_pipeline(
|
def run_inpaint_pipeline(
|
||||||
ctx: ServerContext,
|
ctx: ServerContext,
|
||||||
params: BaseParams,
|
stage: StageParams,
|
||||||
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
output: str,
|
output: str,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
|
@ -192,9 +194,9 @@ def run_inpaint_pipeline(
|
||||||
mask_filter=mask_filter)
|
mask_filter=mask_filter)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
source_image.save(safer_join(ctx.output_path, 'last-source.png'))
|
source_image.save(base_join(ctx.output_path, 'last-source.png'))
|
||||||
mask_image.save(safer_join(ctx.output_path, 'last-mask.png'))
|
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
|
||||||
noise_image.save(safer_join(ctx.output_path, 'last-noise.png'))
|
noise_image.save(base_join(ctx.output_path, 'last-noise.png'))
|
||||||
|
|
||||||
result = pipe(
|
result = pipe(
|
||||||
params.prompt,
|
params.prompt,
|
||||||
|
@ -215,9 +217,9 @@ def run_inpaint_pipeline(
|
||||||
else:
|
else:
|
||||||
print('output image size does not match source, skipping post-blend')
|
print('output image size does not match source, skipping post-blend')
|
||||||
|
|
||||||
image = run_upscale_correction(ctx, upscale, image)
|
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||||
|
|
||||||
dest = safer_join(ctx.output_path, output)
|
dest = base_join(ctx.output_path, output)
|
||||||
image.save(dest)
|
image.save(dest)
|
||||||
|
|
||||||
del image
|
del image
|
||||||
|
@ -228,15 +230,15 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
def run_upscale_pipeline(
|
def run_upscale_pipeline(
|
||||||
ctx: ServerContext,
|
ctx: ServerContext,
|
||||||
_params: BaseParams,
|
params: ImageParams,
|
||||||
_size: Size,
|
_size: Size,
|
||||||
output: str,
|
output: str,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
source_image: Image
|
source_image: Image
|
||||||
):
|
):
|
||||||
image = run_upscale_correction(ctx, upscale, source_image)
|
image = run_upscale_correction(ctx, StageParams(), params, source_image, upscale=upscale)
|
||||||
|
|
||||||
dest = safer_join(ctx.output_path, output)
|
dest = base_join(ctx.output_path, output)
|
||||||
image.save(dest)
|
image.save(dest)
|
||||||
|
|
||||||
del image
|
del image
|
||||||
|
|
|
@ -205,7 +205,7 @@ def process_tiles(
|
||||||
idx = (y * tiles_x) + x
|
idx = (y * tiles_x) + x
|
||||||
left = x * tile
|
left = x * tile
|
||||||
top = y * tile
|
top = y * tile
|
||||||
print('processing tile %s of %s, %s.%s', idx, total, x, y)
|
print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
|
||||||
tile_image = source.crop((left, top, left + tile, top + tile))
|
tile_image = source.crop((left, top, left + tile, top + tile))
|
||||||
|
|
||||||
for filter in filters:
|
for filter in filters:
|
||||||
|
|
|
@ -52,8 +52,8 @@ from .utils import (
|
||||||
get_from_map,
|
get_from_map,
|
||||||
get_not_empty,
|
get_not_empty,
|
||||||
make_output_name,
|
make_output_name,
|
||||||
safer_join,
|
base_join,
|
||||||
BaseParams,
|
ImageParams,
|
||||||
Border,
|
Border,
|
||||||
ServerContext,
|
ServerContext,
|
||||||
Size,
|
Size,
|
||||||
|
@ -124,7 +124,7 @@ def url_from_rule(rule) -> str:
|
||||||
return url_for(rule.endpoint, **options)
|
return url_for(rule.endpoint, **options)
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request() -> Tuple[BaseParams, Size]:
|
def pipeline_from_request() -> Tuple[ImageParams, Size]:
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
|
@ -171,7 +171,7 @@ def pipeline_from_request() -> Tuple[BaseParams, Size]:
|
||||||
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||||
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))
|
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))
|
||||||
|
|
||||||
params = BaseParams(model_path, provider, scheduler, prompt,
|
params = ImageParams(model_path, provider, scheduler, prompt,
|
||||||
negative_prompt, cfg, steps, seed)
|
negative_prompt, cfg, steps, seed)
|
||||||
size = Size(width, height)
|
size = Size(width, height)
|
||||||
return (params, size)
|
return (params, size)
|
||||||
|
@ -288,7 +288,7 @@ if is_debug():
|
||||||
# TODO: these two use context
|
# TODO: these two use context
|
||||||
|
|
||||||
def get_model_path(model: str):
|
def get_model_path(model: str):
|
||||||
return safer_join(context.model_path, model)
|
return base_join(context.model_path, model)
|
||||||
|
|
||||||
|
|
||||||
def ready_reply(ready: bool):
|
def ready_reply(ready: bool):
|
||||||
|
@ -523,6 +523,12 @@ def upscale():
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/chain', methods=['POST'])
def chain():
    '''
    Placeholder for the chain pipeline endpoint: prints a TODO message and
    returns the jsonify() of an empty dict.
    '''
    print('TODO: run chain pipeline')
    reply = {}
    return jsonify(reply)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/ready')
|
@app.route('/api/ready')
|
||||||
def ready():
|
def ready():
|
||||||
output_file = request.args.get('output', None)
|
output_file = request.args.get('output', None)
|
||||||
|
@ -530,7 +536,7 @@ def ready():
|
||||||
done = executor.futures.done(output_file)
|
done = executor.futures.done(output_file)
|
||||||
|
|
||||||
if done is None:
|
if done is None:
|
||||||
file = safer_join(context.output_path, output_file)
|
file = base_join(context.output_path, output_file)
|
||||||
if path.exists(file):
|
if path.exists(file):
|
||||||
return ready_reply(True)
|
return ready_reply(True)
|
||||||
|
|
||||||
|
|
|
@ -8,34 +8,32 @@ from gfpgan import GFPGANer
|
||||||
from os import path
|
from os import path
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
from typing import Literal, Union
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .image import (
|
from .chain import (
|
||||||
process_tiles
|
ChainPipeline,
|
||||||
|
StageParams,
|
||||||
)
|
)
|
||||||
from .onnx import (
|
from .onnx import (
|
||||||
ONNXNet,
|
ONNXNet,
|
||||||
OnnxStableDiffusionUpscalePipeline,
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
|
ImageParams,
|
||||||
ServerContext,
|
ServerContext,
|
||||||
Size,
|
Size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: these should all be params or config
|
|
||||||
pre_pad = 0
|
|
||||||
tile_pad = 10
|
|
||||||
|
|
||||||
|
|
||||||
class UpscaleParams():
|
class UpscaleParams():
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
upscale_model: str,
|
upscale_model: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
correction_model: Union[str, None] = None,
|
correction_model: Optional[str] = None,
|
||||||
denoise: float = 0.5,
|
denoise: float = 0.5,
|
||||||
faces=True,
|
faces=True,
|
||||||
face_strength: float = 0.5,
|
face_strength: float = 0.5,
|
||||||
|
@ -43,6 +41,8 @@ class UpscaleParams():
|
||||||
half=False,
|
half=False,
|
||||||
outscale: int = 1,
|
outscale: int = 1,
|
||||||
scale: int = 4,
|
scale: int = 4,
|
||||||
|
pre_pad: int = 0,
|
||||||
|
tile_pad: int = 10,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.upscale_model = upscale_model
|
self.upscale_model = upscale_model
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
@ -53,13 +53,18 @@ class UpscaleParams():
|
||||||
self.format = format
|
self.format = format
|
||||||
self.half = half
|
self.half = half
|
||||||
self.outscale = outscale
|
self.outscale = outscale
|
||||||
|
self.pre_pad = pre_pad
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
self.tile_pad = tile_pad
|
||||||
|
|
||||||
def resize(self, size: Size) -> Size:
|
def resize(self, size: Size) -> Size:
|
||||||
return Size(size.width * self.outscale, size.height * self.outscale)
|
return Size(size.width * self.outscale, size.height * self.outscale)
|
||||||
|
|
||||||
|
|
||||||
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||||
|
'''
|
||||||
|
TODO: cache this instance
|
||||||
|
'''
|
||||||
model_file = '%s.%s' % (params.upscale_model, params.format)
|
model_file = '%s.%s' % (params.upscale_model, params.format)
|
||||||
model_path = path.join(ctx.model_path, model_file)
|
model_path = path.join(ctx.model_path, model_file)
|
||||||
if not path.isfile(model_path):
|
if not path.isfile(model_path):
|
||||||
|
@ -71,7 +76,6 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||||
elif params.format == 'pth':
|
elif params.format == 'pth':
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||||
else:
|
|
||||||
raise Exception('unknown platform %s' % params.format)
|
raise Exception('unknown platform %s' % params.format)
|
||||||
|
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
|
@ -88,88 +92,143 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||||
dni_weight=dni_weight,
|
dni_weight=dni_weight,
|
||||||
model=model,
|
model=model,
|
||||||
tile=tile,
|
tile=tile,
|
||||||
tile_pad=tile_pad,
|
tile_pad=params.tile_pad,
|
||||||
pre_pad=pre_pad,
|
pre_pad=params.pre_pad,
|
||||||
half=params.half)
|
half=params.half)
|
||||||
|
|
||||||
return upsampler
|
return upsampler
|
||||||
|
|
||||||
|
|
||||||
def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Image) -> Image:
|
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
|
||||||
print('upscaling image with Real ESRGAN', params.scale)
|
'''
|
||||||
|
TODO: cache this instance
|
||||||
|
'''
|
||||||
|
if upscale.format == 'onnx':
|
||||||
|
model_path = path.join(ctx.model_path, upscale.upscale_model)
|
||||||
|
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
||||||
|
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
||||||
|
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
||||||
|
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
vae=AutoencoderKL.from_pretrained(
|
||||||
|
model_path, subfolder='vae_encoder'),
|
||||||
|
low_res_scheduler=DDPMScheduler.from_pretrained(
|
||||||
|
model_path, subfolder='scheduler'),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
||||||
|
'stabilityai/stable-diffusion-x4-upscaler')
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_resrgan(
|
||||||
|
ctx: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
|
params: ImageParams,
|
||||||
|
source_image: Image.Image,
|
||||||
|
*,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
) -> Image:
|
||||||
|
print('upscaling image with Real ESRGAN', upscale.scale)
|
||||||
|
|
||||||
output = np.array(source_image)
|
output = np.array(source_image)
|
||||||
upsampler = make_resrgan(ctx, params, tile=512)
|
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
|
||||||
|
|
||||||
output, _ = upsampler.enhance(output, outscale=params.outscale)
|
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||||
|
|
||||||
output = Image.fromarray(output, 'RGB')
|
output = Image.fromarray(output, 'RGB')
|
||||||
print('final output image size', output.size)
|
print('final output image size', output.size)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
|
def upscale_gfpgan(
|
||||||
print('correcting faces with GFPGAN model: %s' % params.correction_model)
|
ctx: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
if params.correction_model is None:
|
params: ImageParams,
|
||||||
|
image: Image.Image,
|
||||||
|
*,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
upsampler: Optional[RealESRGANer] = None,
|
||||||
|
) -> Image:
|
||||||
|
if upscale.correction_model is None:
|
||||||
print('no face model given, skipping')
|
print('no face model given, skipping')
|
||||||
return image
|
return image
|
||||||
|
|
||||||
if upsampler is None:
|
print('correcting faces with GFPGAN model: %s' % upscale.correction_model)
|
||||||
upsampler = make_resrgan(ctx, params)
|
|
||||||
|
|
||||||
face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))
|
if upsampler is None:
|
||||||
|
upsampler = load_resrgan(ctx, upscale)
|
||||||
|
|
||||||
|
face_path = path.join(ctx.model_path, '%s.pth' %
|
||||||
|
(upscale.correction_model))
|
||||||
|
|
||||||
# TODO: doesn't have a model param, not sure how to pass ONNX model
|
# TODO: doesn't have a model param, not sure how to pass ONNX model
|
||||||
face_enhancer = GFPGANer(
|
face_enhancer = GFPGANer(
|
||||||
model_path=face_path,
|
model_path=face_path,
|
||||||
upscale=params.outscale,
|
upscale=upscale.outscale,
|
||||||
arch='clean',
|
arch='clean',
|
||||||
channel_multiplier=2,
|
channel_multiplier=2,
|
||||||
bg_upsampler=upsampler)
|
bg_upsampler=upsampler)
|
||||||
|
|
||||||
_, _, output = face_enhancer.enhance(
|
_, _, output = face_enhancer.enhance(
|
||||||
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
|
image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
def upscale_stable_diffusion(
|
||||||
|
ctx: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
|
params: ImageParams,
|
||||||
|
source: Image.Image,
|
||||||
|
*,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
) -> Image:
|
||||||
print('upscaling with Stable Diffusion')
|
print('upscaling with Stable Diffusion')
|
||||||
model_path = '../models/%s' % params.upscale_model
|
|
||||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
|
||||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
|
||||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
|
||||||
# pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
|
||||||
# model_path,
|
|
||||||
# vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
|
|
||||||
# low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
|
|
||||||
# )
|
|
||||||
# result = pipeline('', image=image)
|
|
||||||
|
|
||||||
generator = torch.manual_seed(0)
|
pipeline = load_stable_diffusion(ctx, upscale)
|
||||||
|
generator = torch.manual_seed(params.seed)
|
||||||
seed = generator.initial_seed()
|
seed = generator.initial_seed()
|
||||||
|
|
||||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained('stabilityai/stable-diffusion-x4-upscaler')
|
def upscale_stage(_ctx: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image) -> Image:
|
||||||
upscale = lambda i: pipeline(
|
return pipeline(
|
||||||
'an astronaut eating a hamburger',
|
params.prompt,
|
||||||
image=i,
|
image=image,
|
||||||
generator=torch.manual_seed(initial_seed),
|
generator=torch.manual_seed(seed),
|
||||||
).images[0]
|
num_inference_steps=params.steps,
|
||||||
result = process_tiles(image, 128, 4, [upscale])
|
).images[0]
|
||||||
return result
|
|
||||||
|
chain = ChainPipeline(stages=[
|
||||||
|
(upscale_stage, stage)
|
||||||
|
])
|
||||||
|
return chain(ctx, params, source)
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_correction(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
def run_upscale_correction(
|
||||||
|
ctx: ServerContext,
|
||||||
|
stage: StageParams,
|
||||||
|
params: ImageParams,
|
||||||
|
image: Image.Image,
|
||||||
|
*,
|
||||||
|
upscale: UpscaleParams,
|
||||||
|
) -> Image.Image:
|
||||||
print('running upscale pipeline')
|
print('running upscale pipeline')
|
||||||
|
|
||||||
if params.scale > 1:
|
if upscale.scale > 1:
|
||||||
if 'esrgan' in params.upscale_model:
|
if 'esrgan' in upscale.upscale_model:
|
||||||
image = upscale_resrgan(ctx, params, image)
|
stage = StageParams(tile_size=stage.tile_size,
|
||||||
elif 'stable-diffusion' in params.upscale_model:
|
outscale=upscale.outscale)
|
||||||
image = upscale_stable_diffusion(ctx, params, image)
|
image = upscale_resrgan(ctx, stage, params, image, upscale=upscale)
|
||||||
|
elif 'stable-diffusion' in upscale.upscale_model:
|
||||||
|
mini_tile = max(128, stage.tile_size)
|
||||||
|
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||||
|
image = upscale_stable_diffusion(
|
||||||
|
ctx, stage, params, image, upscale=upscale)
|
||||||
|
|
||||||
if params.faces:
|
if upscale.faces:
|
||||||
image = upscale_gfpgan(ctx, params, image)
|
stage = StageParams(tile_size=stage.tile_size,
|
||||||
|
outscale=upscale.outscale)
|
||||||
|
image = upscale_gfpgan(ctx, stage, params, image, upscale=upscale)
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from time import time
|
from time import time
|
||||||
from struct import pack
|
from struct import pack
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,14 +9,14 @@ Param = Union[str, int, float]
|
||||||
Point = Tuple[int, int]
|
Point = Tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
class BaseParams:
|
class ImageParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
scheduler: Any,
|
scheduler: Any,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: Union[None, str],
|
negative_prompt: Optional[str],
|
||||||
cfg: float,
|
cfg: float,
|
||||||
steps: int,
|
steps: int,
|
||||||
seed: int
|
seed: int
|
||||||
|
@ -114,7 +114,7 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
|
||||||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
def get_from_list(args: Any, key: str, values: List[Any]) -> Union[Any, None]:
|
def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
|
||||||
selected = args.get(key, None)
|
selected = args.get(key, None)
|
||||||
if selected in values:
|
if selected in values:
|
||||||
return selected
|
return selected
|
||||||
|
@ -158,9 +158,9 @@ def hash_value(sha, param: Param):
|
||||||
|
|
||||||
def make_output_name(
|
def make_output_name(
|
||||||
mode: str,
|
mode: str,
|
||||||
params: BaseParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
extras: Union[None, Tuple[Param]] = None
|
extras: Optional[Tuple[Param]] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
now = int(time())
|
now = int(time())
|
||||||
sha = sha256()
|
sha = sha256()
|
||||||
|
@ -184,6 +184,6 @@ def make_output_name(
|
||||||
return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
|
return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
|
||||||
|
|
||||||
|
|
||||||
def safer_join(base: str, tail: str) -> str:
|
def base_join(base: str, tail: str) -> str:
|
||||||
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||||
return path.join(base, safer_path)
|
return path.join(base, tail_path)
|
||||||
|
|
Loading…
Reference in New Issue