1
0
Fork 0

feat(api): return all types of models

This commit is contained in:
Sean Sube 2023-01-16 20:10:52 -06:00
parent dba6113c09
commit ee6308a091
3 changed files with 45 additions and 25 deletions

View File

@ -16,6 +16,7 @@ from diffusers import (
from flask import Flask, jsonify, request, send_from_directory, url_for
from flask_cors import CORS
from flask_executor import Executor
from glob import glob
from io import BytesIO
from PIL import Image
from os import makedirs, path, scandir
@ -45,6 +46,7 @@ from .upscale import (
from .utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_list,
get_from_map,
make_output_name,
safer_join,
@ -58,7 +60,6 @@ import json
import numpy as np
# pipeline caching
available_models = []
config_params = {}
# pipeline params
@ -95,14 +96,10 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen,
}
# TODO: load from model_path
upscale_models = [
'RealESRGAN_x4plus',
]
face_models = [
'GFPGANv1.3',
]
# loaded from model_path
diffusion_models = []
correction_models = []
upscaling_models = []
def url_from_rule(rule) -> str:
@ -183,13 +180,16 @@ def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
correction = get_from_list(request.args, 'correction', correction_models)
faces = request.args.get('faces', 'false') == 'true'
return UpscaleParams(
upscale_models[0],
upscaling,
correction_model=correction,
scale=scale,
outscale=outscale,
faces=faces,
face_model=face_models[0],
platform='onnx',
denoise=denoise,
)
@ -204,9 +204,16 @@ def check_paths(context: ServerContext):
def load_models(context: ServerContext):
global available_models
available_models = [f.name for f in scandir(
context.model_path) if f.is_dir()]
global diffusion_models
global correction_models
global upscaling_models
diffusion_models = glob(context.model_path, 'diffusion-*')
diffusion_models.append(glob(context.model_path, 'stable-diffusion-*'))
correction_models = glob(context.model_path, 'correction-*')
upscaling_models = glob(context.model_path, 'upscaling-*')
def load_params(context: ServerContext):
@ -271,7 +278,11 @@ def list_mask_filters():
@app.route('/api/settings/models')
def list_models():
return jsonify(available_models)
return jsonify({
'diffusion': diffusion_models,
'correction': correction_models,
'upscaling': upscaling_models,
})
@app.route('/api/settings/noises')
@ -397,7 +408,7 @@ def inpaint():
return jsonify({
'output': output,
'params': params.tojson(),
'size': upscale.resize(size.with_border(expand)).tojson(),
'size': upscale.resize(size.add_border(expand)).tojson(),
})

View File

@ -86,20 +86,20 @@ class UpscaleParams():
def __init__(
self,
upscale_model: str,
correction_model: Union[str, None] = None,
scale: int = 4,
outscale: int = 1,
denoise: float = 0.5,
faces=True,
face_model: Union[str, None] = None,
platform: str = 'onnx',
half=False
) -> None:
self.upscale_model = upscale_model
self.correction_model = correction_model
self.scale = scale
self.outscale = outscale
self.denoise = denoise
self.faces = faces
self.face_model = face_model
self.platform = platform
self.half = half
@ -158,16 +158,16 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
print('correcting faces with GFPGAN model: %s' % params.face_model)
print('correcting faces with GFPGAN model: %s' % params.correction_model)
if params.face_model is None:
if params.correction_model is None:
print('no face model given, skipping')
return image
if upsampler is None:
upsampler = make_resrgan(ctx, params, tile=512)
face_path = path.join(ctx.model_path, '%s.pth' % (params.face_model))
face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))
# TODO: doesn't have a model param, not sure how to pass ONNX model
face_enhancer = GFPGANer(

View File

@ -1,7 +1,7 @@
from os import environ, path
from time import time
from struct import pack
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, List, Tuple, Union
from hashlib import sha256
@ -89,15 +89,15 @@ class Size:
self.width = width
self.height = height
def add_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
def tojson(self) -> Dict[str, int]:
return {
'height': self.height,
'width': self.width,
}
def with_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value)
@ -107,6 +107,15 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_list(args: Any, key: str, values: List[Any]):
selected = args.get(key, values[0])
if selected in values:
return selected
else:
print('invalid selection: %s' % (selected))
return values[0]
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any):
selected = args.get(key, default)
if selected in values: