feat(api): return all types of models
This commit is contained in:
parent
dba6113c09
commit
ee6308a091
|
@ -16,6 +16,7 @@ from diffusers import (
|
|||
from flask import Flask, jsonify, request, send_from_directory, url_for
|
||||
from flask_cors import CORS
|
||||
from flask_executor import Executor
|
||||
from glob import glob
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from os import makedirs, path, scandir
|
||||
|
@ -45,6 +46,7 @@ from .upscale import (
|
|||
from .utils import (
|
||||
get_and_clamp_float,
|
||||
get_and_clamp_int,
|
||||
get_from_list,
|
||||
get_from_map,
|
||||
make_output_name,
|
||||
safer_join,
|
||||
|
@ -58,7 +60,6 @@ import json
|
|||
import numpy as np
|
||||
|
||||
# pipeline caching
|
||||
available_models = []
|
||||
config_params = {}
|
||||
|
||||
# pipeline params
|
||||
|
@ -95,14 +96,10 @@ mask_filters = {
|
|||
'gaussian-screen': mask_filter_gaussian_screen,
|
||||
}
|
||||
|
||||
# TODO: load from model_path
|
||||
upscale_models = [
|
||||
'RealESRGAN_x4plus',
|
||||
]
|
||||
|
||||
face_models = [
|
||||
'GFPGANv1.3',
|
||||
]
|
||||
# loaded from model_path
|
||||
diffusion_models = []
|
||||
correction_models = []
|
||||
upscaling_models = []
|
||||
|
||||
|
||||
def url_from_rule(rule) -> str:
|
||||
|
@ -183,13 +180,16 @@ def upscale_from_request() -> UpscaleParams:
|
|||
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
|
||||
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
|
||||
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
|
||||
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
|
||||
correction = get_from_list(request.args, 'correction', correction_models)
|
||||
faces = request.args.get('faces', 'false') == 'true'
|
||||
|
||||
return UpscaleParams(
|
||||
upscale_models[0],
|
||||
upscaling,
|
||||
correction_model=correction,
|
||||
scale=scale,
|
||||
outscale=outscale,
|
||||
faces=faces,
|
||||
face_model=face_models[0],
|
||||
platform='onnx',
|
||||
denoise=denoise,
|
||||
)
|
||||
|
@ -204,9 +204,16 @@ def check_paths(context: ServerContext):
|
|||
|
||||
|
||||
def load_models(context: ServerContext):
|
||||
global available_models
|
||||
available_models = [f.name for f in scandir(
|
||||
context.model_path) if f.is_dir()]
|
||||
global diffusion_models
|
||||
global correction_models
|
||||
global upscaling_models
|
||||
|
||||
diffusion_models = glob(context.model_path, 'diffusion-*')
|
||||
diffusion_models.append(glob(context.model_path, 'stable-diffusion-*'))
|
||||
|
||||
correction_models = glob(context.model_path, 'correction-*')
|
||||
upscaling_models = glob(context.model_path, 'upscaling-*')
|
||||
|
||||
|
||||
|
||||
def load_params(context: ServerContext):
|
||||
|
@ -271,7 +278,11 @@ def list_mask_filters():
|
|||
|
||||
@app.route('/api/settings/models')
|
||||
def list_models():
|
||||
return jsonify(available_models)
|
||||
return jsonify({
|
||||
'diffusion': diffusion_models,
|
||||
'correction': correction_models,
|
||||
'upscaling': upscaling_models,
|
||||
})
|
||||
|
||||
|
||||
@app.route('/api/settings/noises')
|
||||
|
@ -397,7 +408,7 @@ def inpaint():
|
|||
return jsonify({
|
||||
'output': output,
|
||||
'params': params.tojson(),
|
||||
'size': upscale.resize(size.with_border(expand)).tojson(),
|
||||
'size': upscale.resize(size.add_border(expand)).tojson(),
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -86,20 +86,20 @@ class UpscaleParams():
|
|||
def __init__(
|
||||
self,
|
||||
upscale_model: str,
|
||||
correction_model: Union[str, None] = None,
|
||||
scale: int = 4,
|
||||
outscale: int = 1,
|
||||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_model: Union[str, None] = None,
|
||||
platform: str = 'onnx',
|
||||
half=False
|
||||
) -> None:
|
||||
self.upscale_model = upscale_model
|
||||
self.correction_model = correction_model
|
||||
self.scale = scale
|
||||
self.outscale = outscale
|
||||
self.denoise = denoise
|
||||
self.faces = faces
|
||||
self.face_model = face_model
|
||||
self.platform = platform
|
||||
self.half = half
|
||||
|
||||
|
@ -158,16 +158,16 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
|
|||
|
||||
|
||||
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
|
||||
print('correcting faces with GFPGAN model: %s' % params.face_model)
|
||||
print('correcting faces with GFPGAN model: %s' % params.correction_model)
|
||||
|
||||
if params.face_model is None:
|
||||
if params.correction_model is None:
|
||||
print('no face model given, skipping')
|
||||
return image
|
||||
|
||||
if upsampler is None:
|
||||
upsampler = make_resrgan(ctx, params, tile=512)
|
||||
|
||||
face_path = path.join(ctx.model_path, '%s.pth' % (params.face_model))
|
||||
face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))
|
||||
|
||||
# TODO: doesn't have a model param, not sure how to pass ONNX model
|
||||
face_enhancer = GFPGANer(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from os import environ, path
|
||||
from time import time
|
||||
from struct import pack
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from hashlib import sha256
|
||||
|
||||
|
||||
|
@ -89,15 +89,15 @@ class Size:
|
|||
self.width = width
|
||||
self.height = height
|
||||
|
||||
def add_border(self, border: Border):
|
||||
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
|
||||
|
||||
def tojson(self) -> Dict[str, int]:
|
||||
return {
|
||||
'height': self.height,
|
||||
'width': self.width,
|
||||
}
|
||||
|
||||
def with_border(self, border: Border):
|
||||
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
|
||||
|
||||
|
||||
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
|
||||
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
||||
|
@ -107,6 +107,15 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
|
|||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||
|
||||
|
||||
def get_from_list(args: Any, key: str, values: List[Any]):
|
||||
selected = args.get(key, values[0])
|
||||
if selected in values:
|
||||
return selected
|
||||
else:
|
||||
print('invalid selection: %s' % (selected))
|
||||
return values[0]
|
||||
|
||||
|
||||
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any):
|
||||
selected = args.get(key, default)
|
||||
if selected in values:
|
||||
|
|
Loading…
Reference in New Issue