1
0
Fork 0

feat(api): return all types of models

This commit is contained in:
Sean Sube 2023-01-16 20:10:52 -06:00
parent dba6113c09
commit ee6308a091
3 changed files with 45 additions and 25 deletions

View File

@ -16,6 +16,7 @@ from diffusers import (
from flask import Flask, jsonify, request, send_from_directory, url_for from flask import Flask, jsonify, request, send_from_directory, url_for
from flask_cors import CORS from flask_cors import CORS
from flask_executor import Executor from flask_executor import Executor
from glob import glob
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from os import makedirs, path, scandir from os import makedirs, path, scandir
@ -45,6 +46,7 @@ from .upscale import (
from .utils import ( from .utils import (
get_and_clamp_float, get_and_clamp_float,
get_and_clamp_int, get_and_clamp_int,
get_from_list,
get_from_map, get_from_map,
make_output_name, make_output_name,
safer_join, safer_join,
@ -58,7 +60,6 @@ import json
import numpy as np import numpy as np
# pipeline caching # pipeline caching
available_models = []
config_params = {} config_params = {}
# pipeline params # pipeline params
@ -95,14 +96,10 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen, 'gaussian-screen': mask_filter_gaussian_screen,
} }
# TODO: load from model_path # loaded from model_path
upscale_models = [ diffusion_models = []
'RealESRGAN_x4plus', correction_models = []
] upscaling_models = []
face_models = [
'GFPGANv1.3',
]
def url_from_rule(rule) -> str: def url_from_rule(rule) -> str:
@ -183,13 +180,16 @@ def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0) denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1) scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1) outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
upscaling = get_from_list(request.args, 'upscaling', upscaling_models)
correction = get_from_list(request.args, 'correction', correction_models)
faces = request.args.get('faces', 'false') == 'true' faces = request.args.get('faces', 'false') == 'true'
return UpscaleParams( return UpscaleParams(
upscale_models[0], upscaling,
correction_model=correction,
scale=scale, scale=scale,
outscale=outscale, outscale=outscale,
faces=faces, faces=faces,
face_model=face_models[0],
platform='onnx', platform='onnx',
denoise=denoise, denoise=denoise,
) )
@ -204,9 +204,16 @@ def check_paths(context: ServerContext):
def load_models(context: ServerContext): def load_models(context: ServerContext):
global available_models global diffusion_models
available_models = [f.name for f in scandir( global correction_models
context.model_path) if f.is_dir()] global upscaling_models
diffusion_models = glob(context.model_path, 'diffusion-*')
diffusion_models.append(glob(context.model_path, 'stable-diffusion-*'))
correction_models = glob(context.model_path, 'correction-*')
upscaling_models = glob(context.model_path, 'upscaling-*')
def load_params(context: ServerContext): def load_params(context: ServerContext):
@ -271,7 +278,11 @@ def list_mask_filters():
@app.route('/api/settings/models') @app.route('/api/settings/models')
def list_models(): def list_models():
return jsonify(available_models) return jsonify({
'diffusion': diffusion_models,
'correction': correction_models,
'upscaling': upscaling_models,
})
@app.route('/api/settings/noises') @app.route('/api/settings/noises')
@ -397,7 +408,7 @@ def inpaint():
return jsonify({ return jsonify({
'output': output, 'output': output,
'params': params.tojson(), 'params': params.tojson(),
'size': upscale.resize(size.with_border(expand)).tojson(), 'size': upscale.resize(size.add_border(expand)).tojson(),
}) })

View File

@ -86,20 +86,20 @@ class UpscaleParams():
def __init__( def __init__(
self, self,
upscale_model: str, upscale_model: str,
correction_model: Union[str, None] = None,
scale: int = 4, scale: int = 4,
outscale: int = 1, outscale: int = 1,
denoise: float = 0.5, denoise: float = 0.5,
faces=True, faces=True,
face_model: Union[str, None] = None,
platform: str = 'onnx', platform: str = 'onnx',
half=False half=False
) -> None: ) -> None:
self.upscale_model = upscale_model self.upscale_model = upscale_model
self.correction_model = correction_model
self.scale = scale self.scale = scale
self.outscale = outscale self.outscale = outscale
self.denoise = denoise self.denoise = denoise
self.faces = faces self.faces = faces
self.face_model = face_model
self.platform = platform self.platform = platform
self.half = half self.half = half
@ -158,16 +158,16 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image: def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
print('correcting faces with GFPGAN model: %s' % params.face_model) print('correcting faces with GFPGAN model: %s' % params.correction_model)
if params.face_model is None: if params.correction_model is None:
print('no face model given, skipping') print('no face model given, skipping')
return image return image
if upsampler is None: if upsampler is None:
upsampler = make_resrgan(ctx, params, tile=512) upsampler = make_resrgan(ctx, params, tile=512)
face_path = path.join(ctx.model_path, '%s.pth' % (params.face_model)) face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))
# TODO: doesn't have a model param, not sure how to pass ONNX model # TODO: doesn't have a model param, not sure how to pass ONNX model
face_enhancer = GFPGANer( face_enhancer = GFPGANer(

View File

@ -1,7 +1,7 @@
from os import environ, path from os import environ, path
from time import time from time import time
from struct import pack from struct import pack
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, List, Tuple, Union
from hashlib import sha256 from hashlib import sha256
@ -89,15 +89,15 @@ class Size:
self.width = width self.width = width
self.height = height self.height = height
def add_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
def tojson(self) -> Dict[str, int]: def tojson(self) -> Dict[str, int]:
return { return {
'height': self.height, 'height': self.height,
'width': self.width, 'width': self.width,
} }
def with_border(self, border: Border):
return Size(border.left + self.width + border.right, border.top + self.height + border.right)
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float: def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value) return min(max(float(args.get(key, default_value)), min_value), max_value)
@ -107,6 +107,15 @@ def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, m
return min(max(int(args.get(key, default_value)), min_value), max_value) return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_list(args: Any, key: str, values: List[Any]):
selected = args.get(key, values[0])
if selected in values:
return selected
else:
print('invalid selection: %s' % (selected))
return values[0]
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any): def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any):
selected = args.get(key, default) selected = args.get(key, default)
if selected in values: if selected in values: