From ee6308a0918a865bfde2615d9a75e86f8796bb89 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 16 Jan 2023 20:10:52 -0600 Subject: [PATCH] feat(api): return all types of models --- api/onnx_web/serve.py | 43 ++++++++++++++++++++++++++--------------- api/onnx_web/upscale.py | 10 +++++----- api/onnx_web/utils.py | 17 ++++++++++++---- 3 files changed, 45 insertions(+), 25 deletions(-) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index ceed0f39..996e5956 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -16,6 +16,7 @@ from diffusers import ( from flask import Flask, jsonify, request, send_from_directory, url_for from flask_cors import CORS from flask_executor import Executor +from glob import glob from io import BytesIO from PIL import Image from os import makedirs, path, scandir @@ -45,6 +46,7 @@ from .upscale import ( from .utils import ( get_and_clamp_float, get_and_clamp_int, + get_from_list, get_from_map, make_output_name, safer_join, @@ -58,7 +60,6 @@ import json import numpy as np # pipeline caching -available_models = [] config_params = {} # pipeline params @@ -95,14 +96,10 @@ mask_filters = { 'gaussian-screen': mask_filter_gaussian_screen, } -# TODO: load from model_path -upscale_models = [ - 'RealESRGAN_x4plus', -] - -face_models = [ - 'GFPGANv1.3', -] +# loaded from model_path +diffusion_models = [] +correction_models = [] +upscaling_models = [] def url_from_rule(rule) -> str: @@ -183,13 +180,16 @@ def upscale_from_request() -> UpscaleParams: denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0) scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1) outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1) + upscaling = get_from_list(request.args, 'upscaling', upscaling_models) + correction = get_from_list(request.args, 'correction', correction_models) faces = request.args.get('faces', 'false') == 'true' + return UpscaleParams( - upscale_models[0], + upscaling, + correction_model=correction, scale=scale, outscale=outscale, faces=faces, - face_model=face_models[0], platform='onnx', denoise=denoise, ) @@ -204,9 +204,16 @@ def check_paths(context: ServerContext): def load_models(context: ServerContext): - global available_models - available_models = [f.name for f in scandir( - context.model_path) if f.is_dir()] + global diffusion_models + global correction_models + global upscaling_models + + diffusion_models = glob(context.model_path, 'diffusion-*') + diffusion_models.append(glob(context.model_path, 'stable-diffusion-*')) + + correction_models = glob(context.model_path, 'correction-*') + upscaling_models = glob(context.model_path, 'upscaling-*') + def load_params(context: ServerContext): @@ -271,7 +278,11 @@ def list_mask_filters(): @app.route('/api/settings/models') def list_models(): - return jsonify(available_models) + return jsonify({ + 'diffusion': diffusion_models, + 'correction': correction_models, + 'upscaling': upscaling_models, + }) @app.route('/api/settings/noises') @@ -397,7 +408,7 @@ def inpaint(): return jsonify({ 'output': output, 'params': params.tojson(), - 'size': upscale.resize(size.with_border(expand)).tojson(), + 'size': upscale.resize(size.add_border(expand)).tojson(), }) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 928ee78f..439e437a 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -86,20 +86,20 @@ class UpscaleParams(): def __init__( self, upscale_model: str, + correction_model: Union[str, None] = None, scale: int = 4, outscale: int = 1, denoise: float = 0.5, faces=True, - face_model: Union[str, None] = None, platform: str = 'onnx', half=False ) -> None: self.upscale_model = upscale_model + self.correction_model = correction_model self.scale = scale self.outscale = outscale self.denoise = denoise self.faces = faces - self.face_model = face_model self.platform = platform self.half = half @@ -158,16 +158,16 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image: - print('correcting faces with GFPGAN model: %s' % params.face_model) + print('correcting faces with GFPGAN model: %s' % params.correction_model) - if params.face_model is None: + if params.correction_model is None: print('no face model given, skipping') return image if upsampler is None: upsampler = make_resrgan(ctx, params, tile=512) - face_path = path.join(ctx.model_path, '%s.pth' % (params.face_model)) + face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model)) # TODO: doesn't have a model param, not sure how to pass ONNX model face_enhancer = GFPGANer( diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 8f8f0483..35088147 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -1,7 +1,7 @@ from os import environ, path from time import time from struct import pack -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, List, Tuple, Union from hashlib import sha256 @@ -89,15 +89,15 @@ class Size: self.width = width self.height = height + def add_border(self, border: Border): + return Size(border.left + self.width + border.right, border.top + self.height + border.right) + def tojson(self) -> Dict[str, int]: return { 'height': self.height, 'width': self.width, } - def with_border(self, border: Border): - return Size(border.left + self.width + border.right, border.top + self.height + border.right) - def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float: return min(max(float(args.get(key, default_value)), min_value), max_value) @@ -107,6 +107,15 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m return min(max(int(args.get(key, default_value)), min_value), max_value) +def get_from_list(args: Any, key: str, values: List[Any]): + selected = args.get(key, values[0]) + if selected in values: + return selected + else: + print('invalid selection: %s' % (selected)) + return values[0] + + def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any): selected = args.get(key, default) if selected in values: