feat(api): start implementing chain pipelines
This commit is contained in:
parent
260f7a29f4
commit
71ff3bb1c4
|
@ -18,7 +18,7 @@ from .image import (
|
|||
noise_source_uniform,
|
||||
)
|
||||
from .upscale import (
|
||||
make_resrgan,
|
||||
load_resrgan,
|
||||
run_upscale_correction,
|
||||
upscale_gfpgan,
|
||||
upscale_resrgan,
|
||||
|
@ -30,8 +30,8 @@ from .utils import (
|
|||
get_from_list,
|
||||
get_from_map,
|
||||
get_not_empty,
|
||||
safer_join,
|
||||
BaseParams,
|
||||
base_join,
|
||||
ImageParams,
|
||||
Border,
|
||||
Point,
|
||||
ServerContext,
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
from PIL import Image
|
||||
from os import path
|
||||
from typing import Any, List, Optional, Protocol, Tuple
|
||||
|
||||
from .image import (
|
||||
process_tiles,
|
||||
)
|
||||
from .utils import (
|
||||
ImageParams,
|
||||
ServerContext,
|
||||
)
|
||||
|
||||
|
||||
class StageParams:
|
||||
'''
|
||||
Parameters for a pipeline stage, assuming they can be chained.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tile_size: int = 512,
|
||||
outscale: int = 1,
|
||||
) -> None:
|
||||
self.tile_size = tile_size
|
||||
self.outscale = outscale
|
||||
|
||||
|
||||
class StageCallback(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
**kwargs: Any
|
||||
) -> Image.Image:
|
||||
pass
|
||||
|
||||
|
||||
PipelineStage = Tuple[StageCallback, StageParams, Optional[Any]]
|
||||
|
||||
|
||||
class ChainPipeline:
|
||||
'''
|
||||
Run many stages in series, passing the image results from each to the next, and processing
|
||||
tiles as needed.
|
||||
'''
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[PipelineStage],
|
||||
):
|
||||
'''
|
||||
Create a new pipeline that will run the given stages.
|
||||
'''
|
||||
self.stages = stages
|
||||
|
||||
def append(self, stage: PipelineStage):
|
||||
'''
|
||||
Append an additional stage to this pipeline.
|
||||
'''
|
||||
self.stages.append(stage)
|
||||
|
||||
def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image) -> Image.Image:
|
||||
'''
|
||||
TODO: handle List[Image] outputs
|
||||
'''
|
||||
print('running pipeline on source image with dimensions %sx%s' %
|
||||
source.size)
|
||||
image = source
|
||||
|
||||
for stage_fn, stage_params, stage_kwargs in self.stages:
|
||||
print('running pipeline stage on result image with dimensions %sx%s' %
|
||||
image.size)
|
||||
if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
|
||||
print('source image larger than tile size, tiling stage',
|
||||
stage_params.tile_size)
|
||||
|
||||
def stage_tile(tile: Image.Image) -> Image.Image:
|
||||
tile = stage_fn(ctx, stage_params, tile,
|
||||
params, **stage_kwargs)
|
||||
tile.save(path.join(ctx.output_path, 'last-tile.png'))
|
||||
return tile
|
||||
|
||||
image = process_tiles(
|
||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||
else:
|
||||
print('source image within tile size, run stage')
|
||||
image = stage_fn(ctx, stage_params, image,
|
||||
params, **stage_kwargs)
|
||||
|
||||
print('finished running pipeline stage, result size: %sx%s' % image.size)
|
||||
|
||||
print('finished running pipeline, result size: %sx%s' % image.size)
|
||||
return image
|
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||
from shutil import copyfile, rmtree
|
||||
from sys import exit
|
||||
from torch.onnx import export
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
|
@ -18,7 +18,7 @@ warnings.filterwarnings('ignore', '.*The shape inference of prim::Constant type
|
|||
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
|
||||
warnings.filterwarnings('ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')
|
||||
|
||||
Models = Dict[str, List[Tuple[str, str, Union[int, None]]]]
|
||||
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]
|
||||
|
||||
# recommended models
|
||||
base_models: Models = {
|
||||
|
|
|
@ -6,7 +6,7 @@ from diffusers import (
|
|||
OnnxStableDiffusionInpaintPipeline,
|
||||
)
|
||||
from PIL import Image, ImageChops
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import gc
|
||||
import numpy as np
|
||||
|
@ -21,10 +21,11 @@ from .upscale import (
|
|||
)
|
||||
from .utils import (
|
||||
is_debug,
|
||||
safer_join,
|
||||
BaseParams,
|
||||
base_join,
|
||||
ImageParams,
|
||||
Border,
|
||||
ServerContext,
|
||||
StageParams,
|
||||
Size,
|
||||
)
|
||||
|
||||
|
@ -45,7 +46,7 @@ def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
|
|||
return image_latents
|
||||
|
||||
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Union[str, None] = None):
|
||||
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_scheduler
|
||||
global last_pipeline_options
|
||||
|
@ -95,7 +96,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
|
|||
|
||||
def run_txt2img_pipeline(
|
||||
ctx: ServerContext,
|
||||
params: BaseParams,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams
|
||||
|
@ -117,9 +118,9 @@ def run_txt2img_pipeline(
|
|||
num_inference_steps=params.steps,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(ctx, upscale, image)
|
||||
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
||||
del image
|
||||
|
@ -130,7 +131,7 @@ def run_txt2img_pipeline(
|
|||
|
||||
def run_img2img_pipeline(
|
||||
ctx: ServerContext,
|
||||
params: BaseParams,
|
||||
params: ImageParams,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image,
|
||||
|
@ -151,9 +152,9 @@ def run_img2img_pipeline(
|
|||
strength=strength,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(ctx, upscale, image)
|
||||
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
||||
del image
|
||||
|
@ -164,7 +165,8 @@ def run_img2img_pipeline(
|
|||
|
||||
def run_inpaint_pipeline(
|
||||
ctx: ServerContext,
|
||||
params: BaseParams,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
|
@ -192,9 +194,9 @@ def run_inpaint_pipeline(
|
|||
mask_filter=mask_filter)
|
||||
|
||||
if is_debug():
|
||||
source_image.save(safer_join(ctx.output_path, 'last-source.png'))
|
||||
mask_image.save(safer_join(ctx.output_path, 'last-mask.png'))
|
||||
noise_image.save(safer_join(ctx.output_path, 'last-noise.png'))
|
||||
source_image.save(base_join(ctx.output_path, 'last-source.png'))
|
||||
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
|
||||
noise_image.save(base_join(ctx.output_path, 'last-noise.png'))
|
||||
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
|
@ -215,9 +217,9 @@ def run_inpaint_pipeline(
|
|||
else:
|
||||
print('output image size does not match source, skipping post-blend')
|
||||
|
||||
image = run_upscale_correction(ctx, upscale, image)
|
||||
image = run_upscale_correction(ctx, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
||||
del image
|
||||
|
@ -228,15 +230,15 @@ def run_inpaint_pipeline(
|
|||
|
||||
def run_upscale_pipeline(
|
||||
ctx: ServerContext,
|
||||
_params: BaseParams,
|
||||
params: ImageParams,
|
||||
_size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image
|
||||
):
|
||||
image = run_upscale_correction(ctx, upscale, source_image)
|
||||
image = run_upscale_correction(ctx, StageParams(), params, source_image, upscale=upscale)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
dest = base_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
||||
del image
|
||||
|
|
|
@ -205,7 +205,7 @@ def process_tiles(
|
|||
idx = (y * tiles_x) + x
|
||||
left = x * tile
|
||||
top = y * tile
|
||||
print('processing tile %s of %s, %s.%s', idx, total, x, y)
|
||||
print('processing tile %s of %s, %s.%s' % (idx, total, y, x))
|
||||
tile_image = source.crop((left, top, left + tile, top + tile))
|
||||
|
||||
for filter in filters:
|
||||
|
|
|
@ -52,8 +52,8 @@ from .utils import (
|
|||
get_from_map,
|
||||
get_not_empty,
|
||||
make_output_name,
|
||||
safer_join,
|
||||
BaseParams,
|
||||
base_join,
|
||||
ImageParams,
|
||||
Border,
|
||||
ServerContext,
|
||||
Size,
|
||||
|
@ -124,7 +124,7 @@ def url_from_rule(rule) -> str:
|
|||
return url_for(rule.endpoint, **options)
|
||||
|
||||
|
||||
def pipeline_from_request() -> Tuple[BaseParams, Size]:
|
||||
def pipeline_from_request() -> Tuple[ImageParams, Size]:
|
||||
user = request.remote_addr
|
||||
|
||||
# pipeline stuff
|
||||
|
@ -171,7 +171,7 @@ def pipeline_from_request() -> Tuple[BaseParams, Size]:
|
|||
print("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s" %
|
||||
(user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt))
|
||||
|
||||
params = BaseParams(model_path, provider, scheduler, prompt,
|
||||
params = ImageParams(model_path, provider, scheduler, prompt,
|
||||
negative_prompt, cfg, steps, seed)
|
||||
size = Size(width, height)
|
||||
return (params, size)
|
||||
|
@ -288,7 +288,7 @@ if is_debug():
|
|||
# TODO: these two use context
|
||||
|
||||
def get_model_path(model: str):
|
||||
return safer_join(context.model_path, model)
|
||||
return base_join(context.model_path, model)
|
||||
|
||||
|
||||
def ready_reply(ready: bool):
|
||||
|
@ -523,6 +523,12 @@ def upscale():
|
|||
})
|
||||
|
||||
|
||||
@app.route('/api/chain', methods=['POST'])
|
||||
def chain():
|
||||
print('TODO: run chain pipeline')
|
||||
return jsonify({})
|
||||
|
||||
|
||||
@app.route('/api/ready')
|
||||
def ready():
|
||||
output_file = request.args.get('output', None)
|
||||
|
@ -530,7 +536,7 @@ def ready():
|
|||
done = executor.futures.done(output_file)
|
||||
|
||||
if done is None:
|
||||
file = safer_join(context.output_path, output_file)
|
||||
file = base_join(context.output_path, output_file)
|
||||
if path.exists(file):
|
||||
return ready_reply(True)
|
||||
|
||||
|
|
|
@ -8,34 +8,32 @@ from gfpgan import GFPGANer
|
|||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Literal, Union
|
||||
from typing import Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .image import (
|
||||
process_tiles
|
||||
from .chain import (
|
||||
ChainPipeline,
|
||||
StageParams,
|
||||
)
|
||||
from .onnx import (
|
||||
ONNXNet,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .utils import (
|
||||
ImageParams,
|
||||
ServerContext,
|
||||
Size,
|
||||
)
|
||||
|
||||
# TODO: these should all be params or config
|
||||
pre_pad = 0
|
||||
tile_pad = 10
|
||||
|
||||
|
||||
class UpscaleParams():
|
||||
def __init__(
|
||||
self,
|
||||
upscale_model: str,
|
||||
provider: str,
|
||||
correction_model: Union[str, None] = None,
|
||||
correction_model: Optional[str] = None,
|
||||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_strength: float = 0.5,
|
||||
|
@ -43,6 +41,8 @@ class UpscaleParams():
|
|||
half=False,
|
||||
outscale: int = 1,
|
||||
scale: int = 4,
|
||||
pre_pad: int = 0,
|
||||
tile_pad: int = 10,
|
||||
) -> None:
|
||||
self.upscale_model = upscale_model
|
||||
self.provider = provider
|
||||
|
@ -53,13 +53,18 @@ class UpscaleParams():
|
|||
self.format = format
|
||||
self.half = half
|
||||
self.outscale = outscale
|
||||
self.pre_pad = pre_pad
|
||||
self.scale = scale
|
||||
self.tile_pad = tile_pad
|
||||
|
||||
def resize(self, size: Size) -> Size:
|
||||
return Size(size.width * self.outscale, size.height * self.outscale)
|
||||
|
||||
|
||||
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||
'''
|
||||
TODO: cache this instance
|
||||
'''
|
||||
model_file = '%s.%s' % (params.upscale_model, params.format)
|
||||
model_path = path.join(ctx.model_path, model_file)
|
||||
if not path.isfile(model_path):
|
||||
|
@ -71,7 +76,6 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
|||
elif params.format == 'pth':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||
else:
|
||||
raise Exception('unknown platform %s' % params.format)
|
||||
|
||||
dni_weight = None
|
||||
|
@ -88,88 +92,143 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
|||
dni_weight=dni_weight,
|
||||
model=model,
|
||||
tile=tile,
|
||||
tile_pad=tile_pad,
|
||||
pre_pad=pre_pad,
|
||||
tile_pad=params.tile_pad,
|
||||
pre_pad=params.pre_pad,
|
||||
half=params.half)
|
||||
|
||||
return upsampler
|
||||
|
||||
|
||||
def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Image) -> Image:
|
||||
print('upscaling image with Real ESRGAN', params.scale)
|
||||
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
|
||||
'''
|
||||
TODO: cache this instance
|
||||
'''
|
||||
if upscale.format == 'onnx':
|
||||
model_path = path.join(ctx.model_path, upscale.upscale_model)
|
||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_path,
|
||||
vae=AutoencoderKL.from_pretrained(
|
||||
model_path, subfolder='vae_encoder'),
|
||||
low_res_scheduler=DDPMScheduler.from_pretrained(
|
||||
model_path, subfolder='scheduler'),
|
||||
)
|
||||
else:
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
'stabilityai/stable-diffusion-x4-upscaler')
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def upscale_resrgan(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image:
|
||||
print('upscaling image with Real ESRGAN', upscale.scale)
|
||||
|
||||
output = np.array(source_image)
|
||||
upsampler = make_resrgan(ctx, params, tile=512)
|
||||
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
|
||||
|
||||
output, _ = upsampler.enhance(output, outscale=params.outscale)
|
||||
output, _ = upsampler.enhance(output, outscale=upscale.outscale)
|
||||
|
||||
output = Image.fromarray(output, 'RGB')
|
||||
print('final output image size', output.size)
|
||||
return output
|
||||
|
||||
|
||||
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
|
||||
print('correcting faces with GFPGAN model: %s' % params.correction_model)
|
||||
|
||||
if params.correction_model is None:
|
||||
def upscale_gfpgan(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
upsampler: Optional[RealESRGANer] = None,
|
||||
) -> Image:
|
||||
if upscale.correction_model is None:
|
||||
print('no face model given, skipping')
|
||||
return image
|
||||
|
||||
if upsampler is None:
|
||||
upsampler = make_resrgan(ctx, params)
|
||||
print('correcting faces with GFPGAN model: %s' % upscale.correction_model)
|
||||
|
||||
face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))
|
||||
if upsampler is None:
|
||||
upsampler = load_resrgan(ctx, upscale)
|
||||
|
||||
face_path = path.join(ctx.model_path, '%s.pth' %
|
||||
(upscale.correction_model))
|
||||
|
||||
# TODO: doesn't have a model param, not sure how to pass ONNX model
|
||||
face_enhancer = GFPGANer(
|
||||
model_path=face_path,
|
||||
upscale=params.outscale,
|
||||
upscale=upscale.outscale,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=upsampler)
|
||||
|
||||
_, _, output = face_enhancer.enhance(
|
||||
image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength)
|
||||
image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def upscale_stable_diffusion(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
||||
def upscale_stable_diffusion(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image:
|
||||
print('upscaling with Stable Diffusion')
|
||||
model_path = '../models/%s' % params.upscale_model
|
||||
# ValueError: Pipeline <class 'onnx_web.onnx.pipeline_onnx_stable_diffusion_upscale.OnnxStableDiffusionUpscalePipeline'>
|
||||
# expected {'vae', 'unet', 'text_encoder', 'tokenizer', 'scheduler', 'low_res_scheduler'},
|
||||
# but only {'scheduler', 'tokenizer', 'text_encoder', 'unet'} were passed.
|
||||
# pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
# model_path,
|
||||
# vae=AutoencoderKL.from_pretrained(model_path, subfolder='vae_encoder'),
|
||||
# low_res_scheduler=DDPMScheduler.from_pretrained(model_path, subfolder='scheduler'),
|
||||
# )
|
||||
# result = pipeline('', image=image)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
pipeline = load_stable_diffusion(ctx, upscale)
|
||||
generator = torch.manual_seed(params.seed)
|
||||
seed = generator.initial_seed()
|
||||
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained('stabilityai/stable-diffusion-x4-upscaler')
|
||||
upscale = lambda i: pipeline(
|
||||
'an astronaut eating a hamburger',
|
||||
image=i,
|
||||
generator=torch.manual_seed(initial_seed),
|
||||
).images[0]
|
||||
result = process_tiles(image, 128, 4, [upscale])
|
||||
return result
|
||||
def upscale_stage(_ctx: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image) -> Image:
|
||||
return pipeline(
|
||||
params.prompt,
|
||||
image=image,
|
||||
generator=torch.manual_seed(seed),
|
||||
num_inference_steps=params.steps,
|
||||
).images[0]
|
||||
|
||||
chain = ChainPipeline(stages=[
|
||||
(upscale_stage, stage)
|
||||
])
|
||||
return chain(ctx, params, source)
|
||||
|
||||
|
||||
def run_upscale_correction(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image:
|
||||
def run_upscale_correction(
|
||||
ctx: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
) -> Image.Image:
|
||||
print('running upscale pipeline')
|
||||
|
||||
if params.scale > 1:
|
||||
if 'esrgan' in params.upscale_model:
|
||||
image = upscale_resrgan(ctx, params, image)
|
||||
elif 'stable-diffusion' in params.upscale_model:
|
||||
image = upscale_stable_diffusion(ctx, params, image)
|
||||
if upscale.scale > 1:
|
||||
if 'esrgan' in upscale.upscale_model:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale)
|
||||
image = upscale_resrgan(ctx, stage, params, image, upscale=upscale)
|
||||
elif 'stable-diffusion' in upscale.upscale_model:
|
||||
mini_tile = max(128, stage.tile_size)
|
||||
stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
image = upscale_stable_diffusion(
|
||||
ctx, stage, params, image, upscale=upscale)
|
||||
|
||||
if params.faces:
|
||||
image = upscale_gfpgan(ctx, params, image)
|
||||
if upscale.faces:
|
||||
stage = StageParams(tile_size=stage.tile_size,
|
||||
outscale=upscale.outscale)
|
||||
image = upscale_gfpgan(ctx, stage, params, image, upscale=upscale)
|
||||
|
||||
return image
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from os import environ, path
|
||||
from time import time
|
||||
from struct import pack
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from hashlib import sha256
|
||||
|
||||
|
||||
|
@ -9,14 +9,14 @@ Param = Union[str, int, float]
|
|||
Point = Tuple[int, int]
|
||||
|
||||
|
||||
class BaseParams:
|
||||
class ImageParams:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
scheduler: Any,
|
||||
prompt: str,
|
||||
negative_prompt: Union[None, str],
|
||||
negative_prompt: Optional[str],
|
||||
cfg: float,
|
||||
steps: int,
|
||||
seed: int
|
||||
|
@ -114,7 +114,7 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
|
|||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||
|
||||
|
||||
def get_from_list(args: Any, key: str, values: List[Any]) -> Union[Any, None]:
|
||||
def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
|
||||
selected = args.get(key, None)
|
||||
if selected in values:
|
||||
return selected
|
||||
|
@ -158,9 +158,9 @@ def hash_value(sha, param: Param):
|
|||
|
||||
def make_output_name(
|
||||
mode: str,
|
||||
params: BaseParams,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
extras: Union[None, Tuple[Param]] = None
|
||||
extras: Optional[Tuple[Param]] = None
|
||||
) -> str:
|
||||
now = int(time())
|
||||
sha = sha256()
|
||||
|
@ -184,6 +184,6 @@ def make_output_name(
|
|||
return '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
|
||||
|
||||
|
||||
def safer_join(base: str, tail: str) -> str:
|
||||
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||
return path.join(base, safer_path)
|
||||
def base_join(base: str, tail: str) -> str:
|
||||
tail_path = path.relpath(path.normpath(path.join('/', tail)), '/')
|
||||
return path.join(base, tail_path)
|
||||
|
|
Loading…
Reference in New Issue