feat(api): remove Flask app from global scope
parent 943281feb5
commit 06c74a7a96
@@ -1,6 +1,7 @@
 .coverage
 coverage.xml

+*.log
 *.swp
 *.pyc

@@ -25,4 +25,4 @@ python3 -m onnx_web.convert \
   --token=${HF_TOKEN:-}

 echo "Launching API server..."
-flask --app=onnx_web.serve run --host=0.0.0.0
+flask --app='onnx_web.main:main()' run --host=0.0.0.0
@@ -24,4 +24,4 @@ python3 -m onnx_web.convert \
   --token=${HF_TOKEN:-}

 echo "Launching API server..."
-flask --app=onnx_web.serve run --host=0.0.0.0
+flask --app='onnx_web.main:main()' run --host=0.0.0.0
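Note: the quoted --app='onnx_web.main:main()' form is Flask's CLI factory syntax: the CLI imports onnx_web.main, calls main(), and serves the app it returns, instead of importing a module that builds the app as a side effect. A minimal sketch of the same pattern (the module and function names here are illustrative, not part of this repo):

    # app.py -- minimal app-factory sketch
    from flask import Flask

    def create_app() -> Flask:
        app = Flask(__name__)

        @app.route("/health")
        def health():
            return {"ok": True}

        return app

    # shell equivalent: flask --app='app:create_app()' run --host=0.0.0.0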
@@ -1,5 +1,10 @@
 from . import logging
-from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion
+from .chain import (
+    correct_codeformer,
+    correct_gfpgan,
+    upscale_resrgan,
+    upscale_stable_diffusion,
+)
 from .diffusion.load import get_latents_from_seed, load_pipeline, optimize_pipeline
 from .diffusion.run import (
     run_blend_pipeline,
@@ -13,3 +13,20 @@ from .source_txt2img import source_txt2img
 from .upscale_outpaint import upscale_outpaint
 from .upscale_resrgan import upscale_resrgan
 from .upscale_stable_diffusion import upscale_stable_diffusion
+
+CHAIN_STAGES = {
+    "blend-img2img": blend_img2img,
+    "blend-inpaint": blend_inpaint,
+    "blend-mask": blend_mask,
+    "correct-codeformer": correct_codeformer,
+    "correct-gfpgan": correct_gfpgan,
+    "persist-disk": persist_disk,
+    "persist-s3": persist_s3,
+    "reduce-crop": reduce_crop,
+    "reduce-thumbnail": reduce_thumbnail,
+    "source-noise": source_noise,
+    "source-txt2img": source_txt2img,
+    "upscale-outpaint": upscale_outpaint,
+    "upscale-resrgan": upscale_resrgan,
+    "upscale-stable-diffusion": upscale_stable_diffusion,
+}
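Note: CHAIN_STAGES gives the rest of the server a single lookup table from the `type` field of a chain request to its stage callback; the new server/api.py below uses it as callback = CHAIN_STAGES[stage_data.get("type")]. For example:

    from onnx_web.chain import CHAIN_STAGES

    stage_data = {"type": "upscale-resrgan", "params": {"outscale": 2}}
    callback = CHAIN_STAGES[stage_data["type"]]  # -> upscale_resrgan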
@@ -4,9 +4,11 @@ from typing import Any, Optional, Tuple

 import numpy as np
 from diffusers import (
+    DiffusionPipeline,
+    OnnxRuntimeModel,
+    StableDiffusionPipeline,
     DDIMScheduler,
     DDPMScheduler,
-    DiffusionPipeline,
     DPMSolverMultistepScheduler,
     DPMSolverSinglestepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -17,15 +19,13 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
-    OnnxRuntimeModel,
     PNDMScheduler,
-    StableDiffusionPipeline,
 )

 try:
     from diffusers import DEISMultistepScheduler
 except ImportError:
-    from .stub_scheduler import StubScheduler as DEISMultistepScheduler
+    from ..diffusion.stub_scheduler import StubScheduler as DEISMultistepScheduler

 from ..params import DeviceParams, Size
 from ..server import ServerContext
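Note: the import guard keeps the scheduler table importable on diffusers releases that predate DEISMultistepScheduler; the stub stands in so name lookups still resolve. The general shape of the fallback (StubScheduler itself is not shown in this diff, so this is only a sketch):

    try:
        from diffusers import DEISMultistepScheduler
    except ImportError:
        class DEISMultistepScheduler:  # placeholder, raises only if actually selected
            @classmethod
            def from_config(cls, *args, **kwargs):
                raise NotImplementedError("DEIS is not available in this diffusers version")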
@@ -54,6 +54,10 @@ pipeline_schedulers = {
 }


+def get_pipeline_schedulers():
+    return pipeline_schedulers
+
+
 def get_scheduler_name(scheduler: Any) -> Optional[str]:
     for k, v in pipeline_schedulers.items():
         if scheduler == v or scheduler == v.__name__:
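Note: exposing the table through get_pipeline_schedulers() lets other modules resolve schedulers by name without importing the module-level dict directly, keeping the string-to-class mapping in one place. A lookup sketch (the key "ddim" is illustrative; the actual keys live in pipeline_schedulers above, which this hunk does not show):

    from onnx_web.diffusion.load import get_pipeline_schedulers

    schedulers = get_pipeline_schedulers()
    scheduler_class = schedulers["ddim"]  # name -> scheduler class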
@@ -137,13 +141,14 @@ def load_pipeline(
     server: ServerContext,
     pipeline: DiffusionPipeline,
     model: str,
-    scheduler_type: Any,
+    scheduler_name: str,
     device: DeviceParams,
     lpw: bool,
     inversion: Optional[str],
 ):
     pipe_key = (pipeline, model, device.device, device.provider, lpw, inversion)
-    scheduler_key = (scheduler_type, model)
+    scheduler_key = (scheduler_name, model)
+    scheduler_type = get_pipeline_schedulers()[scheduler_name]

     cache_pipe = server.cache.get("diffusion", pipe_key)

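Note: with this change the scheduler travels through the request path as a plain string and is only resolved to a class here, inside load_pipeline, which also keeps the cache key hashable and stable across requests. A hedged sketch of the caller side (the argument values are illustrative, not taken from this diff):

    # scheduler_name comes straight from the request params; no class imports needed
    pipe = load_pipeline(
        server,
        pipeline_class,  # e.g. an ONNX diffusion pipeline class
        model_path,
        "ddim",          # illustrative scheduler name
        device,
        lpw=False,
        inversion=None,
    )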
@@ -0,0 +1,53 @@
+import gc
+
+from diffusers.utils.logging import disable_progress_bar
+from flask import Flask
+from flask_cors import CORS
+from huggingface_hub.utils.tqdm import disable_progress_bars
+
+from .server.api import register_api_routes
+from .server.static import register_static_routes
+from .server.config import get_available_platforms, load_models, load_params, load_platforms
+from .server.utils import check_paths
+from .server.context import ServerContext
+from .server.hacks import apply_patches
+from .utils import (
+    is_debug,
+)
+from .worker import DevicePoolExecutor
+
+
+def main():
+    context = ServerContext.from_environ()
+    apply_patches(context)
+    check_paths(context)
+    load_models(context)
+    load_params(context)
+    load_platforms(context)
+
+    if is_debug():
+        gc.set_debug(gc.DEBUG_STATS)
+
+    if not context.show_progress:
+        disable_progress_bar()
+        disable_progress_bars()
+
+    app = Flask(__name__)
+    CORS(app, origins=context.cors_origin)
+
+    # any is a fake device, should not be in the pool
+    pool = DevicePoolExecutor([p for p in get_available_platforms() if p.device != "any"])
+
+    # register routes
+    register_static_routes(app, context, pool)
+    register_api_routes(app, context, pool)
+
+    return app #, context, pool
+
+
+if __name__ == "__main__":
+    # app, context, pool = main()
+    app = main()
+    app.run("0.0.0.0", 5000, debug=is_debug())
+    # pool.join()
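Note: because the app is now built inside main() rather than at import time, importing onnx_web no longer loads models or spins up the worker pool as a side effect. That makes the factory easy to drive from tests or a WSGI server; a sketch (assuming the default config and model paths are set up):

    from onnx_web.main import main

    app = main()
    with app.test_client() as client:
        print(client.get("/api").json)  # route listing from introspect()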
@@ -8,7 +8,6 @@ from typing import Any, List, Optional, Tuple

 from PIL import Image

-from .diffusion.load import get_scheduler_name
 from .params import Border, ImageParams, Param, Size, UpscaleParams
 from .server import ServerContext
 from .utils import base_join
@@ -44,7 +43,7 @@ def json_params(
     }

     json["params"]["model"] = path.basename(params.model)
-    json["params"]["scheduler"] = get_scheduler_name(params.scheduler)
+    json["params"]["scheduler"] = params.scheduler

     if border is not None:
         json["border"] = border.tojson()
@@ -71,7 +70,7 @@ def make_output_name(

     hash_value(sha, mode)
     hash_value(sha, params.model)
-    hash_value(sha, params.scheduler.__name__)
+    hash_value(sha, params.scheduler)
     hash_value(sha, params.prompt)
     hash_value(sha, params.negative_prompt)
     hash_value(sha, params.cfg)
@@ -148,7 +148,7 @@ class ImageParams:
     def __init__(
         self,
         model: str,
-        scheduler: Any,
+        scheduler: str,
         prompt: str,
         cfg: float,
         steps: int,
@@ -174,7 +174,7 @@ class ImageParams:
     def tojson(self) -> Dict[str, Optional[Param]]:
         return {
             "model": self.model,
-            "scheduler": self.scheduler.__name__,
+            "scheduler": self.scheduler,
             "prompt": self.prompt,
             "negative_prompt": self.negative_prompt,
             "cfg": self.cfg,
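Note: storing the scheduler name as a string makes ImageParams round-trip cleanly through tojson() and the output hash without reaching for __name__. For instance (argument values are illustrative, and the trailing ImageParams keyword arguments are elided):

    params = ImageParams("stable-diffusion-v1-5", "ddim", "an astronaut riding a horse", 7.5, 25, 42)
    params.tojson()["scheduler"]  # -> "ddim"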
@@ -1,881 +0,0 @@
-import gc
-from functools import cmp_to_key
-from glob import glob
-from io import BytesIO
-from logging import getLogger
-from os import makedirs, path
-from typing import Dict, List, Tuple, Union
-
-import numpy as np
-import torch
-import yaml
-from diffusers.utils.logging import disable_progress_bar
-from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
-from flask_cors import CORS
-from huggingface_hub.utils.tqdm import disable_progress_bars
-from jsonschema import validate
-from onnxruntime import get_available_providers
-from PIL import Image
-
-from .chain import (
-    ChainPipeline,
-    blend_img2img,
-    blend_inpaint,
-    correct_codeformer,
-    correct_gfpgan,
-    persist_disk,
-    persist_s3,
-    reduce_crop,
-    reduce_thumbnail,
-    source_noise,
-    source_txt2img,
-    upscale_outpaint,
-    upscale_resrgan,
-    upscale_stable_diffusion,
-)
-from .diffusion.load import pipeline_schedulers
-from .diffusion.run import (
-    run_blend_pipeline,
-    run_img2img_pipeline,
-    run_inpaint_pipeline,
-    run_txt2img_pipeline,
-    run_upscale_pipeline,
-)
-from .image import (  # mask filters; noise sources
-    mask_filter_gaussian_multiply,
-    mask_filter_gaussian_screen,
-    mask_filter_none,
-    noise_source_fill_edge,
-    noise_source_fill_mask,
-    noise_source_gaussian,
-    noise_source_histogram,
-    noise_source_normal,
-    noise_source_uniform,
-    valid_image,
-)
-from .output import json_params, make_output_name
-from .params import (
-    Border,
-    DeviceParams,
-    ImageParams,
-    Size,
-    StageParams,
-    TileOrder,
-    UpscaleParams,
-)
-from .server import ServerContext, apply_patches
-from .transformers import run_txt2txt_pipeline
-from .utils import (
-    base_join,
-    get_and_clamp_float,
-    get_and_clamp_int,
-    get_from_list,
-    get_from_map,
-    get_not_empty,
-    get_size,
-    is_debug,
-)
-from .worker import DevicePoolExecutor
-
-logger = getLogger(__name__)
-
-# config caching
-config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}
-
-# pipeline params
-platform_providers = {
-    "cpu": "CPUExecutionProvider",
-    "cuda": "CUDAExecutionProvider",
-    "directml": "DmlExecutionProvider",
-    "rocm": "ROCMExecutionProvider",
-}
-
-noise_sources = {
-    "fill-edge": noise_source_fill_edge,
-    "fill-mask": noise_source_fill_mask,
-    "gaussian": noise_source_gaussian,
-    "histogram": noise_source_histogram,
-    "normal": noise_source_normal,
-    "uniform": noise_source_uniform,
-}
-mask_filters = {
-    "none": mask_filter_none,
-    "gaussian-multiply": mask_filter_gaussian_multiply,
-    "gaussian-screen": mask_filter_gaussian_screen,
-}
-chain_stages = {
-    "blend-img2img": blend_img2img,
-    "blend-inpaint": blend_inpaint,
-    "correct-codeformer": correct_codeformer,
-    "correct-gfpgan": correct_gfpgan,
-    "persist-disk": persist_disk,
-    "persist-s3": persist_s3,
-    "reduce-crop": reduce_crop,
-    "reduce-thumbnail": reduce_thumbnail,
-    "source-noise": source_noise,
-    "source-txt2img": source_txt2img,
-    "upscale-outpaint": upscale_outpaint,
-    "upscale-resrgan": upscale_resrgan,
-    "upscale-stable-diffusion": upscale_stable_diffusion,
-}
-
-# Available ORT providers
-available_platforms: List[DeviceParams] = []
-
-# loaded from model_path
-correction_models: List[str] = []
-diffusion_models: List[str] = []
-inversion_models: List[str] = []
-upscaling_models: List[str] = []
-
-
-def get_config_value(key: str, subkey: str = "default", default=None):
-    return config_params.get(key, {}).get(subkey, default)
-
-
-def url_from_rule(rule) -> str:
-    options = {}
-    for arg in rule.arguments:
-        options[arg] = ":%s" % (arg)
-
-    return url_for(rule.endpoint, **options)
-
-
-def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
-    user = request.remote_addr
-
-    # platform stuff
-    device = None
-    device_name = request.args.get("platform")
-
-    if device_name is not None and device_name != "any":
-        for platform in available_platforms:
-            if platform.device == device_name:
-                device = platform
-
-    # pipeline stuff
-    lpw = get_not_empty(request.args, "lpw", "false") == "true"
-    model = get_not_empty(request.args, "model", get_config_value("model"))
-    model_path = get_model_path(model)
-    scheduler = get_from_map(
-        request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler")
-    )
-
-    inversion = request.args.get("inversion", None)
-    inversion_path = None
-    if inversion is not None and inversion.strip() != "":
-        inversion_path = get_model_path(inversion)
-
-    # image params
-    prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
-    negative_prompt = request.args.get("negativePrompt", None)
-
-    if negative_prompt is not None and negative_prompt.strip() == "":
-        negative_prompt = None
-
-    batch = get_and_clamp_int(
-        request.args,
-        "batch",
-        get_config_value("batch"),
-        get_config_value("batch", "max"),
-        get_config_value("batch", "min"),
-    )
-    cfg = get_and_clamp_float(
-        request.args,
-        "cfg",
-        get_config_value("cfg"),
-        get_config_value("cfg", "max"),
-        get_config_value("cfg", "min"),
-    )
-    eta = get_and_clamp_float(
-        request.args,
-        "eta",
-        get_config_value("eta"),
-        get_config_value("eta", "max"),
-        get_config_value("eta", "min"),
-    )
-    steps = get_and_clamp_int(
-        request.args,
-        "steps",
-        get_config_value("steps"),
-        get_config_value("steps", "max"),
-        get_config_value("steps", "min"),
-    )
-    height = get_and_clamp_int(
-        request.args,
-        "height",
-        get_config_value("height"),
-        get_config_value("height", "max"),
-        get_config_value("height", "min"),
-    )
-    width = get_and_clamp_int(
-        request.args,
-        "width",
-        get_config_value("width"),
-        get_config_value("width", "max"),
-        get_config_value("width", "min"),
-    )
-
-    seed = int(request.args.get("seed", -1))
-    if seed == -1:
-        # this one can safely use np.random because it produces a single value
-        seed = np.random.randint(np.iinfo(np.int32).max)
-
-    logger.info(
-        "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
-        user,
-        steps,
-        scheduler.__name__,
-        model_path,
-        device or "any device",
-        width,
-        height,
-        cfg,
-        seed,
-        prompt,
-    )
-
-    params = ImageParams(
-        model_path,
-        scheduler,
-        prompt,
-        cfg,
-        steps,
-        seed,
-        eta=eta,
-        lpw=lpw,
-        negative_prompt=negative_prompt,
-        batch=batch,
-        inversion=inversion_path,
-    )
-    size = Size(width, height)
-    return (device, params, size)
-
-
-def border_from_request() -> Border:
-    left = get_and_clamp_int(
-        request.args, "left", 0, get_config_value("width", "max"), 0
-    )
-    right = get_and_clamp_int(
-        request.args, "right", 0, get_config_value("width", "max"), 0
-    )
-    top = get_and_clamp_int(
-        request.args, "top", 0, get_config_value("height", "max"), 0
-    )
-    bottom = get_and_clamp_int(
-        request.args, "bottom", 0, get_config_value("height", "max"), 0
-    )
-
-    return Border(left, right, top, bottom)
-
-
-def upscale_from_request() -> UpscaleParams:
-    denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
-    scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
-    outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
-    upscaling = get_from_list(request.args, "upscaling", upscaling_models)
-    correction = get_from_list(request.args, "correction", correction_models)
-    faces = get_not_empty(request.args, "faces", "false") == "true"
-    face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
-    face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
-    upscale_order = request.args.get("upscaleOrder", "correction-first")
-
-    return UpscaleParams(
-        upscaling,
-        correction_model=correction,
-        denoise=denoise,
-        faces=faces,
-        face_outscale=face_outscale,
-        face_strength=face_strength,
-        format="onnx",
-        outscale=outscale,
-        scale=scale,
-        upscale_order=upscale_order,
-    )
-
-
-def check_paths(context: ServerContext) -> None:
-    if not path.exists(context.model_path):
-        raise RuntimeError("model path must exist")
-
-    if not path.exists(context.output_path):
-        makedirs(context.output_path)
-
-
-def get_model_name(model: str) -> str:
-    base = path.basename(model)
-    (file, _ext) = path.splitext(base)
-    return file
-
-
-def load_models(context: ServerContext) -> None:
-    global correction_models
-    global diffusion_models
-    global inversion_models
-    global upscaling_models
-
-    diffusion_models = [
-        get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
-    ]
-    diffusion_models.extend(
-        [
-            get_model_name(f)
-            for f in glob(path.join(context.model_path, "stable-diffusion-*"))
-        ]
-    )
-    diffusion_models = list(set(diffusion_models))
-    diffusion_models.sort()
-
-    correction_models = [
-        get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
-    ]
-    correction_models = list(set(correction_models))
-    correction_models.sort()
-
-    inversion_models = [
-        get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
-    ]
-    inversion_models = list(set(inversion_models))
-    inversion_models.sort()
-
-    upscaling_models = [
-        get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
-    ]
-    upscaling_models = list(set(upscaling_models))
-    upscaling_models.sort()
-
-
-def load_params(context: ServerContext) -> None:
-    global config_params
-    params_file = path.join(context.params_path, "params.json")
-    with open(params_file, "r") as f:
-        config_params = yaml.safe_load(f)
-
-        if "platform" in config_params and context.default_platform is not None:
-            logger.info(
-                "Overriding default platform from environment: %s",
-                context.default_platform,
-            )
-            config_platform = config_params.get("platform", {})
-            config_platform["default"] = context.default_platform
-
-
-def load_platforms(context: ServerContext) -> None:
-    global available_platforms
-
-    providers = list(get_available_providers())
-
-    for potential in platform_providers:
-        if (
-            platform_providers[potential] in providers
-            and potential not in context.block_platforms
-        ):
-            if potential == "cuda":
-                for i in range(torch.cuda.device_count()):
-                    available_platforms.append(
-                        DeviceParams(
-                            potential,
-                            platform_providers[potential],
-                            {
-                                "device_id": i,
-                            },
-                            context.optimizations,
-                        )
-                    )
-            else:
-                available_platforms.append(
-                    DeviceParams(
-                        potential,
-                        platform_providers[potential],
-                        None,
-                        context.optimizations,
-                    )
-                )
-
-    if context.any_platform:
-        # the platform should be ignored when the job is scheduled, but set to CPU just in case
-        available_platforms.append(
-            DeviceParams(
-                "any",
-                platform_providers["cpu"],
-                None,
-                context.optimizations,
-            )
-        )
-
-    # make sure CPU is last on the list
-    def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
-        if a.device == b.device:
-            return 0
-
-        # any should be first, if it's available
-        if a.device == "any":
-            return -1
-
-        # cpu should be last, if it's available
-        if a.device == "cpu":
-            return 1
-
-        return -1
-
-    available_platforms = sorted(
-        available_platforms, key=cmp_to_key(any_first_cpu_last)
-    )
-
-    logger.info(
-        "available acceleration platforms: %s",
-        ", ".join([str(p) for p in available_platforms]),
-    )
-
-
-context = ServerContext.from_environ()
-apply_patches(context)
-check_paths(context)
-load_models(context)
-load_params(context)
-load_platforms(context)
-
-if not context.show_progress:
-    disable_progress_bar()
-    disable_progress_bars()
-
-app = Flask(__name__)
-CORS(app, origins=context.cors_origin)
-
-# any is a fake device, should not be in the pool
-executor = DevicePoolExecutor([p for p in available_platforms if p.device != "any"])
-
-if is_debug():
-    gc.set_debug(gc.DEBUG_STATS)
-
-
-def ready_reply(ready: bool, progress: int = 0):
-    return jsonify(
-        {
-            "progress": progress,
-            "ready": ready,
-        }
-    )
-
-
-def error_reply(err: str):
-    response = make_response(
-        jsonify(
-            {
-                "error": err,
-            }
-        )
-    )
-    response.status_code = 400
-    return response
-
-
-def get_model_path(model: str):
-    return base_join(context.model_path, model)
-
-
-def serve_bundle_file(filename="index.html"):
-    return send_from_directory(path.join("..", context.bundle_path), filename)
-
-
-# routes
-
-
-@app.route("/")
-def index():
-    return serve_bundle_file()
-
-
-@app.route("/<path:filename>")
-def index_path(filename):
-    return serve_bundle_file(filename)
-
-
-@app.route("/api")
-def introspect():
-    return {
-        "name": "onnx-web",
-        "routes": [
-            {"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
-            for rule in app.url_map.iter_rules()
-        ],
-    }
-
-
-@app.route("/api/settings/masks")
-def list_mask_filters():
-    return jsonify(list(mask_filters.keys()))
-
-
-@app.route("/api/settings/models")
-def list_models():
-    return jsonify(
-        {
-            "correction": correction_models,
-            "diffusion": diffusion_models,
-            "inversion": inversion_models,
-            "upscaling": upscaling_models,
-        }
-    )
-
-
-@app.route("/api/settings/noises")
-def list_noise_sources():
-    return jsonify(list(noise_sources.keys()))
-
-
-@app.route("/api/settings/params")
-def list_params():
-    return jsonify(config_params)
-
-
-@app.route("/api/settings/platforms")
-def list_platforms():
-    return jsonify([p.device for p in available_platforms])
-
-
-@app.route("/api/settings/schedulers")
-def list_schedulers():
-    return jsonify(list(pipeline_schedulers.keys()))
-
-
-@app.route("/api/img2img", methods=["POST"])
-def img2img():
-    if "source" not in request.files:
-        return error_reply("source image is required")
-
-    source_file = request.files.get("source")
-    source = Image.open(BytesIO(source_file.read())).convert("RGB")
-
-    device, params, size = pipeline_from_request()
-    upscale = upscale_from_request()
-
-    strength = get_and_clamp_float(
-        request.args,
-        "strength",
-        get_config_value("strength"),
-        get_config_value("strength", "max"),
-        get_config_value("strength", "min"),
-    )
-
-    output = make_output_name(context, "img2img", params, size, extras=(strength,))
-    job_name = output[0]
-    logger.info("img2img job queued for: %s", job_name)
-
-    source = valid_image(source, min_dims=size, max_dims=size)
-    executor.submit(
-        job_name,
-        run_img2img_pipeline,
-        context,
-        params,
-        output,
-        upscale,
-        source,
-        strength,
-        needs_device=device,
-    )
-
-    return jsonify(json_params(output, params, size, upscale=upscale))
-
-
-@app.route("/api/txt2img", methods=["POST"])
-def txt2img():
-    device, params, size = pipeline_from_request()
-    upscale = upscale_from_request()
-
-    output = make_output_name(context, "txt2img", params, size)
-    job_name = output[0]
-    logger.info("txt2img job queued for: %s", job_name)
-
-    executor.submit(
-        job_name,
-        run_txt2img_pipeline,
-        context,
-        params,
-        size,
-        output,
-        upscale,
-        needs_device=device,
-    )
-
-    return jsonify(json_params(output, params, size, upscale=upscale))
-
-
-@app.route("/api/inpaint", methods=["POST"])
-def inpaint():
-    if "source" not in request.files:
-        return error_reply("source image is required")
-
-    if "mask" not in request.files:
-        return error_reply("mask image is required")
-
-    source_file = request.files.get("source")
-    source = Image.open(BytesIO(source_file.read())).convert("RGB")
-
-    mask_file = request.files.get("mask")
-    mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
-
-    device, params, size = pipeline_from_request()
-    expand = border_from_request()
-    upscale = upscale_from_request()
-
-    fill_color = get_not_empty(request.args, "fillColor", "white")
-    mask_filter = get_from_map(request.args, "filter", mask_filters, "none")
-    noise_source = get_from_map(request.args, "noise", noise_sources, "histogram")
-    tile_order = get_from_list(
-        request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
-    )
-
-    output = make_output_name(
-        context,
-        "inpaint",
-        params,
-        size,
-        extras=(
-            expand.left,
-            expand.right,
-            expand.top,
-            expand.bottom,
-            mask_filter.__name__,
-            noise_source.__name__,
-            fill_color,
-            tile_order,
-        ),
-    )
-    job_name = output[0]
-    logger.info("inpaint job queued for: %s", job_name)
-
-    source = valid_image(source, min_dims=size, max_dims=size)
-    mask = valid_image(mask, min_dims=size, max_dims=size)
-    executor.submit(
-        job_name,
-        run_inpaint_pipeline,
-        context,
-        params,
-        size,
-        output,
-        upscale,
-        source,
-        mask,
-        expand,
-        noise_source,
-        mask_filter,
-        fill_color,
-        tile_order,
-        needs_device=device,
-    )
-
-    return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
-
-
-@app.route("/api/upscale", methods=["POST"])
-def upscale():
-    if "source" not in request.files:
-        return error_reply("source image is required")
-
-    source_file = request.files.get("source")
-    source = Image.open(BytesIO(source_file.read())).convert("RGB")
-
-    device, params, size = pipeline_from_request()
-    upscale = upscale_from_request()
-
-    output = make_output_name(context, "upscale", params, size)
-    job_name = output[0]
-    logger.info("upscale job queued for: %s", job_name)
-
-    source = valid_image(source, min_dims=size, max_dims=size)
-    executor.submit(
-        job_name,
-        run_upscale_pipeline,
-        context,
-        params,
-        size,
-        output,
-        upscale,
-        source,
-        needs_device=device,
-    )
-
-    return jsonify(json_params(output, params, size, upscale=upscale))
-
-
-@app.route("/api/chain", methods=["POST"])
-def chain():
-    logger.debug(
-        "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
-    )
-    body = request.form.get("chain") or request.files.get("chain")
-    if body is None:
-        return error_reply("chain pipeline must have a body")
-
-    data = yaml.safe_load(body)
-    with open("./schemas/chain.yaml", "r") as f:
-        schema = yaml.safe_load(f.read())
-
-    logger.debug("validating chain request: %s against %s", data, schema)
-    validate(data, schema)
-
-    # get defaults from the regular parameters
-    device, params, size = pipeline_from_request()
-    output = make_output_name(context, "chain", params, size)
-    job_name = output[0]
-
-    pipeline = ChainPipeline()
-    for stage_data in data.get("stages", []):
-        callback = chain_stages[stage_data.get("type")]
-        kwargs = stage_data.get("params", {})
-        logger.info("request stage: %s, %s", callback.__name__, kwargs)
-
-        stage = StageParams(
-            stage_data.get("name", callback.__name__),
-            tile_size=get_size(kwargs.get("tile_size")),
-            outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
-        )
-
-        if "border" in kwargs:
-            border = Border.even(int(kwargs.get("border")))
-            kwargs["border"] = border
-
-        if "upscale" in kwargs:
-            upscale = UpscaleParams(kwargs.get("upscale"))
-            kwargs["upscale"] = upscale
-
-        stage_source_name = "source:%s" % (stage.name)
-        stage_mask_name = "mask:%s" % (stage.name)
-
-        if stage_source_name in request.files:
-            logger.debug(
-                "loading source image %s for pipeline stage %s",
-                stage_source_name,
-                stage.name,
-            )
-            source_file = request.files.get(stage_source_name)
-            source = Image.open(BytesIO(source_file.read())).convert("RGB")
-            source = valid_image(source, max_dims=(size.width, size.height))
-            kwargs["stage_source"] = source
-
-        if stage_mask_name in request.files:
-            logger.debug(
-                "loading mask image %s for pipeline stage %s",
-                stage_mask_name,
-                stage.name,
-            )
-            mask_file = request.files.get(stage_mask_name)
-            mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
-            mask = valid_image(mask, max_dims=(size.width, size.height))
-            kwargs["stage_mask"] = mask
-
-        pipeline.append((callback, stage, kwargs))
-
-    logger.info("running chain pipeline with %s stages", len(pipeline.stages))
-
-    # build and run chain pipeline
-    empty_source = Image.new("RGB", (size.width, size.height))
-    executor.submit(
-        job_name,
-        pipeline,
-        context,
-        params,
-        empty_source,
-        output=output[0],
-        size=size,
-        needs_device=device,
-    )
-
-    return jsonify(json_params(output, params, size))
-
-
-@app.route("/api/blend", methods=["POST"])
-def blend():
-    if "mask" not in request.files:
-        return error_reply("mask image is required")
-
-    mask_file = request.files.get("mask")
-    mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
-    mask = valid_image(mask)
-
-    max_sources = 2
-    sources = []
-
-    for i in range(max_sources):
-        source_file = request.files.get("source:%s" % (i))
-        source = Image.open(BytesIO(source_file.read())).convert("RGBA")
-        source = valid_image(source, mask.size, mask.size)
-        sources.append(source)
-
-    device, params, size = pipeline_from_request()
-    upscale = upscale_from_request()
-
-    output = make_output_name(context, "upscale", params, size)
-    job_name = output[0]
-    logger.info("upscale job queued for: %s", job_name)
-
-    executor.submit(
-        job_name,
-        run_blend_pipeline,
-        context,
-        params,
-        size,
-        output,
-        upscale,
-        sources,
-        mask,
-        needs_device=device,
-    )
-
-    return jsonify(json_params(output, params, size, upscale=upscale))
-
-
-@app.route("/api/txt2txt", methods=["POST"])
-def txt2txt():
-    device, params, size = pipeline_from_request()
-
-    output = make_output_name(context, "upscale", params, size)
-    logger.info("upscale job queued for: %s", output)
-
-    executor.submit(
-        output,
-        run_txt2txt_pipeline,
-        context,
-        params,
-        size,
-        output,
-        needs_device=device,
-    )
-
-    return jsonify(json_params(output, params, size))
-
-
-@app.route("/api/cancel", methods=["PUT"])
-def cancel():
-    output_file = request.args.get("output", None)
-
-    cancel = executor.cancel(output_file)
-
-    return ready_reply(cancel)
-
-
-@app.route("/api/ready")
-def ready():
-    output_file = request.args.get("output", None)
-
-    done, progress = executor.done(output_file)
-
-    if done is None:
-        output = base_join(context.output_path, output_file)
-        if path.exists(output):
-            return ready_reply(True)
-
-    return ready_reply(done, progress=progress)
-
-
-@app.route("/api/status")
-def status():
-    return jsonify(executor.status())
-
-
-@app.route("/output/<path:filename>")
-def output(filename: str):
-    return send_from_directory(
-        path.join("..", context.output_path), filename, as_attachment=False
-    )
@@ -0,0 +1,477 @@
+from io import BytesIO
+from logging import getLogger
+from os import path
+
+import yaml
+from flask import Flask, jsonify, make_response, request, url_for
+from jsonschema import validate
+from PIL import Image
+
+from .context import ServerContext
+from .utils import wrap_route
+from ..worker.pool import DevicePoolExecutor
+
+from .config import (
+    get_available_platforms,
+    get_config_params,
+    get_config_value,
+    get_correction_models,
+    get_diffusion_models,
+    get_inversion_models,
+    get_mask_filters,
+    get_noise_sources,
+    get_upscaling_models,
+)
+from .params import border_from_request, pipeline_from_request, upscale_from_request
+
+from ..chain import (
+    CHAIN_STAGES,
+    ChainPipeline,
+)
+from ..diffusion.load import get_pipeline_schedulers
+from ..diffusion.run import (
+    run_blend_pipeline,
+    run_img2img_pipeline,
+    run_inpaint_pipeline,
+    run_txt2img_pipeline,
+    run_upscale_pipeline,
+)
+from ..image import (  # mask filters; noise sources
+    valid_image,
+)
+from ..output import json_params, make_output_name
+from ..params import (
+    Border,
+    StageParams,
+    TileOrder,
+    UpscaleParams,
+)
+from ..transformers import run_txt2txt_pipeline
+from ..utils import (
+    base_join,
+    get_and_clamp_float,
+    get_and_clamp_int,
+    get_from_list,
+    get_from_map,
+    get_not_empty,
+    get_size,
+)
+
+logger = getLogger(__name__)
+
+
+def ready_reply(ready: bool, progress: int = 0):
+    return jsonify(
+        {
+            "progress": progress,
+            "ready": ready,
+        }
+    )
+
+
+def error_reply(err: str):
+    response = make_response(
+        jsonify(
+            {
+                "error": err,
+            }
+        )
+    )
+    response.status_code = 400
+    return response
+
+
+def url_from_rule(rule) -> str:
+    options = {}
+    for arg in rule.arguments:
+        options[arg] = ":%s" % (arg)
+
+    return url_for(rule.endpoint, **options)
+
+
+def introspect(context: ServerContext, app: Flask):
+    return {
+        "name": "onnx-web",
+        "routes": [
+            {"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
+            for rule in app.url_map.iter_rules()
+        ],
+    }
+
+
+def list_mask_filters(context: ServerContext):
+    return jsonify(list(get_mask_filters().keys()))
+
+
+def list_models(context: ServerContext):
+    return jsonify(
+        {
+            "correction": get_correction_models(),
+            "diffusion": get_diffusion_models(),
+            "inversion": get_inversion_models(),
+            "upscaling": get_upscaling_models(),
+        }
+    )
+
+
+def list_noise_sources(context: ServerContext):
+    return jsonify(list(get_noise_sources().keys()))
+
+
+def list_params(context: ServerContext):
+    return jsonify(get_config_params())
+
+
+def list_platforms(context: ServerContext):
+    return jsonify([p.device for p in get_available_platforms()])
+
+
+def list_schedulers(context: ServerContext):
+    return jsonify(list(get_pipeline_schedulers().keys()))
+
+
+def img2img(context: ServerContext, pool: DevicePoolExecutor):
+    if "source" not in request.files:
+        return error_reply("source image is required")
+
+    source_file = request.files.get("source")
+    source = Image.open(BytesIO(source_file.read())).convert("RGB")
+
+    device, params, size = pipeline_from_request(context)
+    upscale = upscale_from_request()
+
+    strength = get_and_clamp_float(
+        request.args,
+        "strength",
+        get_config_value("strength"),
+        get_config_value("strength", "max"),
+        get_config_value("strength", "min"),
+    )
+
+    output = make_output_name(context, "img2img", params, size, extras=(strength,))
+    job_name = output[0]
+    logger.info("img2img job queued for: %s", job_name)
+
+    source = valid_image(source, min_dims=size, max_dims=size)
+    pool.submit(
+        job_name,
+        run_img2img_pipeline,
+        context,
+        params,
+        output,
+        upscale,
+        source,
+        strength,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size, upscale=upscale))
+
+
+def txt2img(context: ServerContext, pool: DevicePoolExecutor):
+    device, params, size = pipeline_from_request(context)
+    upscale = upscale_from_request()
+
+    output = make_output_name(context, "txt2img", params, size)
+    job_name = output[0]
+    logger.info("txt2img job queued for: %s", job_name)
+
+    pool.submit(
+        job_name,
+        run_txt2img_pipeline,
+        context,
+        params,
+        size,
+        output,
+        upscale,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size, upscale=upscale))
+
+
+def inpaint(context: ServerContext, pool: DevicePoolExecutor):
+    if "source" not in request.files:
+        return error_reply("source image is required")
+
+    if "mask" not in request.files:
+        return error_reply("mask image is required")
+
+    source_file = request.files.get("source")
+    source = Image.open(BytesIO(source_file.read())).convert("RGB")
+
+    mask_file = request.files.get("mask")
+    mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
+
+    device, params, size = pipeline_from_request(context)
+    expand = border_from_request()
+    upscale = upscale_from_request()
+
+    fill_color = get_not_empty(request.args, "fillColor", "white")
+    mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none")
+    noise_source = get_from_map(request.args, "noise", get_noise_sources(), "histogram")
+    tile_order = get_from_list(
+        request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]
+    )
+
+    output = make_output_name(
+        context,
+        "inpaint",
+        params,
+        size,
+        extras=(
+            expand.left,
+            expand.right,
+            expand.top,
+            expand.bottom,
+            mask_filter.__name__,
+            noise_source.__name__,
+            fill_color,
+            tile_order,
+        ),
+    )
+    job_name = output[0]
+    logger.info("inpaint job queued for: %s", job_name)
+
+    source = valid_image(source, min_dims=size, max_dims=size)
+    mask = valid_image(mask, min_dims=size, max_dims=size)
+    pool.submit(
+        job_name,
+        run_inpaint_pipeline,
+        context,
+        params,
+        size,
+        output,
+        upscale,
+        source,
+        mask,
+        expand,
+        noise_source,
+        mask_filter,
+        fill_color,
+        tile_order,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size, upscale=upscale, border=expand))
+
+
+def upscale(context: ServerContext, pool: DevicePoolExecutor):
+    if "source" not in request.files:
+        return error_reply("source image is required")
+
+    source_file = request.files.get("source")
+    source = Image.open(BytesIO(source_file.read())).convert("RGB")
+
+    device, params, size = pipeline_from_request(context)
+    upscale = upscale_from_request()
+
+    output = make_output_name(context, "upscale", params, size)
+    job_name = output[0]
+    logger.info("upscale job queued for: %s", job_name)
+
+    source = valid_image(source, min_dims=size, max_dims=size)
+    pool.submit(
+        job_name,
+        run_upscale_pipeline,
+        context,
+        params,
+        size,
+        output,
+        upscale,
+        source,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size, upscale=upscale))
+
+
+def chain(context: ServerContext, pool: DevicePoolExecutor):
+    logger.debug(
+        "chain pipeline request: %s, %s", request.form.keys(), request.files.keys()
+    )
+    body = request.form.get("chain") or request.files.get("chain")
+    if body is None:
+        return error_reply("chain pipeline must have a body")
+
+    data = yaml.safe_load(body)
+    with open("./schemas/chain.yaml", "r") as f:
+        schema = yaml.safe_load(f.read())
+
+    logger.debug("validating chain request: %s against %s", data, schema)
+    validate(data, schema)
+
+    # get defaults from the regular parameters
+    device, params, size = pipeline_from_request(context)
+    output = make_output_name(context, "chain", params, size)
+    job_name = output[0]
+
+    pipeline = ChainPipeline()
+    for stage_data in data.get("stages", []):
+        callback = CHAIN_STAGES[stage_data.get("type")]
+        kwargs = stage_data.get("params", {})
+        logger.info("request stage: %s, %s", callback.__name__, kwargs)
+
+        stage = StageParams(
+            stage_data.get("name", callback.__name__),
+            tile_size=get_size(kwargs.get("tile_size")),
+            outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
+        )
+
+        if "border" in kwargs:
+            border = Border.even(int(kwargs.get("border")))
+            kwargs["border"] = border
+
+        if "upscale" in kwargs:
+            upscale = UpscaleParams(kwargs.get("upscale"))
+            kwargs["upscale"] = upscale
+
+        stage_source_name = "source:%s" % (stage.name)
+        stage_mask_name = "mask:%s" % (stage.name)
+
+        if stage_source_name in request.files:
+            logger.debug(
+                "loading source image %s for pipeline stage %s",
+                stage_source_name,
+                stage.name,
+            )
+            source_file = request.files.get(stage_source_name)
+            source = Image.open(BytesIO(source_file.read())).convert("RGB")
+            source = valid_image(source, max_dims=(size.width, size.height))
+            kwargs["stage_source"] = source
+
+        if stage_mask_name in request.files:
+            logger.debug(
+                "loading mask image %s for pipeline stage %s",
+                stage_mask_name,
+                stage.name,
+            )
+            mask_file = request.files.get(stage_mask_name)
+            mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
+            mask = valid_image(mask, max_dims=(size.width, size.height))
+            kwargs["stage_mask"] = mask
+
+        pipeline.append((callback, stage, kwargs))
+
+    logger.info("running chain pipeline with %s stages", len(pipeline.stages))
+
+    # build and run chain pipeline
+    empty_source = Image.new("RGB", (size.width, size.height))
+    pool.submit(
+        job_name,
+        pipeline,
+        context,
+        params,
+        empty_source,
+        output=output[0],
+        size=size,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size))
+
+
+def blend(context: ServerContext, pool: DevicePoolExecutor):
+    if "mask" not in request.files:
+        return error_reply("mask image is required")
+
+    mask_file = request.files.get("mask")
+    mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
+    mask = valid_image(mask)
+
+    max_sources = 2
+    sources = []
+
+    for i in range(max_sources):
+        source_file = request.files.get("source:%s" % (i))
+        source = Image.open(BytesIO(source_file.read())).convert("RGBA")
+        source = valid_image(source, mask.size, mask.size)
+        sources.append(source)
+
+    device, params, size = pipeline_from_request(context)
+    upscale = upscale_from_request()
+
+    output = make_output_name(context, "upscale", params, size)
+    job_name = output[0]
+    logger.info("upscale job queued for: %s", job_name)
+
+    pool.submit(
+        job_name,
+        run_blend_pipeline,
+        context,
+        params,
+        size,
+        output,
+        upscale,
+        sources,
+        mask,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size, upscale=upscale))
+
+
+def txt2txt(context: ServerContext, pool: DevicePoolExecutor):
+    device, params, size = pipeline_from_request(context)
+
+    output = make_output_name(context, "upscale", params, size)
+    logger.info("upscale job queued for: %s", output)
+
+    pool.submit(
+        output,
+        run_txt2txt_pipeline,
+        context,
+        params,
+        size,
+        output,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size))
+
+
+def cancel(context: ServerContext, pool: DevicePoolExecutor):
+    output_file = request.args.get("output", None)
+
+    cancel = pool.cancel(output_file)
+
+    return ready_reply(cancel)
+
+
+def ready(context: ServerContext, pool: DevicePoolExecutor):
+    output_file = request.args.get("output", None)
+
+    done, progress = pool.done(output_file)
+
+    if done is None:
+        output = base_join(context.output_path, output_file)
+        if path.exists(output):
+            return ready_reply(True)
+
+    return ready_reply(done, progress=progress)
+
+
+def status(context: ServerContext, pool: DevicePoolExecutor):
+    return jsonify(pool.status())
+
+
+def register_api_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
+    return [
+        app.route("/api")(wrap_route(introspect, context, app=app)),
+        app.route("/api/settings/masks")(wrap_route(list_mask_filters, context)),
+        app.route("/api/settings/models")(wrap_route(list_models, context)),
+        app.route("/api/settings/noises")(wrap_route(list_noise_sources, context)),
+        app.route("/api/settings/params")(wrap_route(list_params, context)),
+        app.route("/api/settings/platforms")(wrap_route(list_platforms, context)),
+        app.route("/api/settings/schedulers")(wrap_route(list_schedulers, context)),
+        app.route("/api/img2img", methods=["POST"])(wrap_route(img2img, context, pool=pool)),
+        app.route("/api/txt2img", methods=["POST"])(wrap_route(txt2img, context, pool=pool)),
+        app.route("/api/txt2txt", methods=["POST"])(wrap_route(txt2txt, context, pool=pool)),
+        app.route("/api/inpaint", methods=["POST"])(wrap_route(inpaint, context, pool=pool)),
+        app.route("/api/upscale", methods=["POST"])(wrap_route(upscale, context, pool=pool)),
+        app.route("/api/chain", methods=["POST"])(wrap_route(chain, context, pool=pool)),
+        app.route("/api/blend", methods=["POST"])(wrap_route(blend, context, pool=pool)),
+        app.route("/api/cancel", methods=["PUT"])(wrap_route(cancel, context, pool=pool)),
+        app.route("/api/ready")(wrap_route(ready, context, pool=pool)),
+        app.route("/api/status")(wrap_route(status, context, pool=pool)),
+    ]
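Note: wrap_route is imported from .utils but its definition is not part of this diff. Presumably it partially applies the shared context (and pool) to each handler while copying the function's metadata, so app.route() still sees a distinct __name__ per view and endpoints stay unique. A plausible sketch, under that assumption:

    from functools import partial, update_wrapper

    def wrap_route(func, *args, **kwargs):
        # bind the server context/pool ahead of Flask's own call arguments,
        # preserving func.__name__ so app.route() derives a unique endpoint
        wrapped = partial(func, *args, **kwargs)
        return update_wrapper(wrapped, func)

Incidentally, list(rule.methods).sort() in introspect() returns None, so "methods" is always null in the route listing; sorted(rule.methods) would return the list. The expression is carried over unchanged from the old serve.py.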
@@ -0,0 +1,224 @@
from functools import cmp_to_key
from glob import glob
from logging import getLogger
from os import path
from typing import Dict, List, Union

import torch
import yaml
from onnxruntime import get_available_providers

from .context import ServerContext
from ..image import (  # mask filters; noise sources
    mask_filter_gaussian_multiply,
    mask_filter_gaussian_screen,
    mask_filter_none,
    noise_source_fill_edge,
    noise_source_fill_mask,
    noise_source_gaussian,
    noise_source_histogram,
    noise_source_normal,
    noise_source_uniform,
)
from ..params import DeviceParams

logger = getLogger(__name__)

# config caching
config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}

# pipeline params
platform_providers = {
    "cpu": "CPUExecutionProvider",
    "cuda": "CUDAExecutionProvider",
    "directml": "DmlExecutionProvider",
    "rocm": "ROCMExecutionProvider",
}
noise_sources = {
    "fill-edge": noise_source_fill_edge,
    "fill-mask": noise_source_fill_mask,
    "gaussian": noise_source_gaussian,
    "histogram": noise_source_histogram,
    "normal": noise_source_normal,
    "uniform": noise_source_uniform,
}
mask_filters = {
    "none": mask_filter_none,
    "gaussian-multiply": mask_filter_gaussian_multiply,
    "gaussian-screen": mask_filter_gaussian_screen,
}


# available ORT providers
available_platforms: List[DeviceParams] = []

# loaded from model_path
correction_models: List[str] = []
diffusion_models: List[str] = []
inversion_models: List[str] = []
upscaling_models: List[str] = []


def get_config_params():
    return config_params


def get_available_platforms():
    return available_platforms


def get_correction_models():
    return correction_models


def get_diffusion_models():
    return diffusion_models


def get_inversion_models():
    return inversion_models


def get_upscaling_models():
    return upscaling_models


def get_mask_filters():
    return mask_filters


def get_noise_sources():
    return noise_sources


def get_config_value(key: str, subkey: str = "default", default=None):
    return config_params.get(key, {}).get(subkey, default)


def get_model_name(model: str) -> str:
    base = path.basename(model)
    (file, _ext) = path.splitext(base)
    return file

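get_model_name reduces a model path to its display name by taking the basename and stripping any extension; for example (paths hypothetical):

# get_model_name("/models/diffusion-knollingcase.onnx") -> "diffusion-knollingcase"
# get_model_name("/models/upscaling-real-esrgan-x4-plus") -> "upscaling-real-esrgan-x4-plus"
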
def load_models(context: ServerContext) -> None:
    global correction_models
    global diffusion_models
    global inversion_models
    global upscaling_models

    diffusion_models = [
        get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*"))
    ]
    diffusion_models.extend(
        [
            get_model_name(f)
            for f in glob(path.join(context.model_path, "stable-diffusion-*"))
        ]
    )
    diffusion_models = list(set(diffusion_models))
    diffusion_models.sort()

    correction_models = [
        get_model_name(f) for f in glob(path.join(context.model_path, "correction-*"))
    ]
    correction_models = list(set(correction_models))
    correction_models.sort()

    inversion_models = [
        get_model_name(f) for f in glob(path.join(context.model_path, "inversion-*"))
    ]
    inversion_models = list(set(inversion_models))
    inversion_models.sort()

    upscaling_models = [
        get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*"))
    ]
    upscaling_models = list(set(upscaling_models))
    upscaling_models.sort()

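For example, given a hypothetical model_path containing diffusion-sd-v1-5, stable-diffusion-onnx-v1-5, and correction-gfpgan-v1-3, this would leave diffusion_models as ["diffusion-sd-v1-5", "stable-diffusion-onnx-v1-5"] and correction_models as ["correction-gfpgan-v1-3"], deduplicated and sorted.
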
def load_params(context: ServerContext) -> None:
    global config_params
    params_file = path.join(context.params_path, "params.json")
    with open(params_file, "r") as f:
        # JSON is a subset of YAML, so the YAML loader can read params.json
        config_params = yaml.safe_load(f)

    if "platform" in config_params and context.default_platform is not None:
        logger.info(
            "Overriding default platform from environment: %s",
            context.default_platform,
        )
        config_platform = config_params.get("platform", {})
        config_platform["default"] = context.default_platform

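get_config_value(key, subkey) reads nested entries from this structure, so each parameter is expected to carry at least default, min, and max subkeys. A hypothetical fragment of the loaded data, with made-up values:

# illustrative shape only; the repo's params.json defines the real values
config_params = {
    "cfg": {"default": 6.0, "min": 1.0, "max": 30.0},
    "steps": {"default": 25, "min": 1, "max": 200},
    "platform": {"default": "amd"},
}
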
def load_platforms(context: ServerContext) -> None:
    global available_platforms

    providers = list(get_available_providers())

    for potential in platform_providers:
        if (
            platform_providers[potential] in providers
            and potential not in context.block_platforms
        ):
            if potential == "cuda":
                for i in range(torch.cuda.device_count()):
                    available_platforms.append(
                        DeviceParams(
                            potential,
                            platform_providers[potential],
                            {
                                "device_id": i,
                            },
                            context.optimizations,
                        )
                    )
            else:
                available_platforms.append(
                    DeviceParams(
                        potential,
                        platform_providers[potential],
                        None,
                        context.optimizations,
                    )
                )

    if context.any_platform:
        # the platform should be ignored when the job is scheduled, but set to CPU just in case
        available_platforms.append(
            DeviceParams(
                "any",
                platform_providers["cpu"],
                None,
                context.optimizations,
            )
        )

    # make sure any is first and CPU is last on the list; the comparator checks
    # both arguments so that it stays consistent regardless of argument order
    def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
        if a.device == b.device:
            return 0

        # any should be first, if it's available
        if a.device == "any":
            return -1
        if b.device == "any":
            return 1

        # cpu should be last, if it's available
        if a.device == "cpu":
            return 1
        if b.device == "cpu":
            return -1

        return -1

    available_platforms = sorted(
        available_platforms, key=cmp_to_key(any_first_cpu_last)
    )

    logger.info(
        "available acceleration platforms: %s",
        ", ".join([str(p) for p in available_platforms]),
    )
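For example, on a host exposing two CUDA devices with any_platform enabled, the sorted list comes out as any, cuda (device_id 0), cuda (device_id 1), cpu, so the catch-all entry is tried first and the CPU fallback always sits last.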
@@ -0,0 +1,183 @@
from logging import getLogger
from typing import Tuple

import numpy as np
from flask import request

from .config import (
    get_available_platforms,
    get_config_value,
    get_correction_models,
    get_upscaling_models,
)
from .context import ServerContext
from .utils import get_model_path
from ..diffusion.load import pipeline_schedulers
from ..params import (
    Border,
    DeviceParams,
    ImageParams,
    Size,
    UpscaleParams,
)
from ..utils import (
    get_and_clamp_float,
    get_and_clamp_int,
    get_from_list,
    get_not_empty,
)

logger = getLogger(__name__)


def pipeline_from_request(context: ServerContext) -> Tuple[DeviceParams, ImageParams, Size]:
    user = request.remote_addr

    # platform stuff
    device = None
    device_name = request.args.get("platform")

    if device_name is not None and device_name != "any":
        for platform in get_available_platforms():
            if platform.device == device_name:
                device = platform

    # pipeline stuff
    lpw = get_not_empty(request.args, "lpw", "false") == "true"
    model = get_not_empty(request.args, "model", get_config_value("model"))
    model_path = get_model_path(context, model)
    scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys())

    if scheduler is None:
        scheduler = get_config_value("scheduler")

    inversion = request.args.get("inversion", None)
    inversion_path = None
    if inversion is not None and inversion.strip() != "":
        inversion_path = get_model_path(context, inversion)

    # image params
    prompt = get_not_empty(request.args, "prompt", get_config_value("prompt"))
    negative_prompt = request.args.get("negativePrompt", None)

    if negative_prompt is not None and negative_prompt.strip() == "":
        negative_prompt = None

    batch = get_and_clamp_int(
        request.args,
        "batch",
        get_config_value("batch"),
        get_config_value("batch", "max"),
        get_config_value("batch", "min"),
    )
    cfg = get_and_clamp_float(
        request.args,
        "cfg",
        get_config_value("cfg"),
        get_config_value("cfg", "max"),
        get_config_value("cfg", "min"),
    )
    eta = get_and_clamp_float(
        request.args,
        "eta",
        get_config_value("eta"),
        get_config_value("eta", "max"),
        get_config_value("eta", "min"),
    )
    steps = get_and_clamp_int(
        request.args,
        "steps",
        get_config_value("steps"),
        get_config_value("steps", "max"),
        get_config_value("steps", "min"),
    )
    height = get_and_clamp_int(
        request.args,
        "height",
        get_config_value("height"),
        get_config_value("height", "max"),
        get_config_value("height", "min"),
    )
    width = get_and_clamp_int(
        request.args,
        "width",
        get_config_value("width"),
        get_config_value("width", "max"),
        get_config_value("width", "min"),
    )

    seed = int(request.args.get("seed", -1))
    if seed == -1:
        # this one can safely use np.random because it produces a single value
        seed = np.random.randint(np.iinfo(np.int32).max)

    logger.info(
        "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
        user,
        steps,
        scheduler,
        model_path,
        device or "any device",
        width,
        height,
        cfg,
        seed,
        prompt,
    )

    params = ImageParams(
        model_path,
        scheduler,
        prompt,
        cfg,
        steps,
        seed,
        eta=eta,
        lpw=lpw,
        negative_prompt=negative_prompt,
        batch=batch,
        inversion=inversion_path,
    )
    size = Size(width, height)
    return (device, params, size)


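A hypothetical request showing how the query string maps onto the returned tuple (the values are examples, not the repo's defaults):

# GET /api/txt2img?model=stable-diffusion-onnx-v1-5&scheduler=ddim
#     &prompt=a+stone+bridge&cfg=6&steps=25&seed=42&width=512&height=512&platform=cuda
# -> device: the DeviceParams entry registered for "cuda"
# -> params: ImageParams("<model_path>/stable-diffusion-onnx-v1-5", "ddim",
#            "a stone bridge", 6.0, 25, 42, ...)
# -> size:   Size(512, 512)
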
def border_from_request() -> Border:
    left = get_and_clamp_int(
        request.args, "left", 0, get_config_value("width", "max"), 0
    )
    right = get_and_clamp_int(
        request.args, "right", 0, get_config_value("width", "max"), 0
    )
    top = get_and_clamp_int(
        request.args, "top", 0, get_config_value("height", "max"), 0
    )
    bottom = get_and_clamp_int(
        request.args, "bottom", 0, get_config_value("height", "max"), 0
    )

    return Border(left, right, top, bottom)


def upscale_from_request() -> UpscaleParams:
    denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0)
    scale = get_and_clamp_int(request.args, "scale", 1, 4, 1)
    outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1)
    upscaling = get_from_list(request.args, "upscaling", get_upscaling_models())
    correction = get_from_list(request.args, "correction", get_correction_models())
    faces = get_not_empty(request.args, "faces", "false") == "true"
    face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
    face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
    upscale_order = request.args.get("upscaleOrder", "correction-first")

    return UpscaleParams(
        upscaling,
        correction_model=correction,
        denoise=denoise,
        faces=faces,
        face_outscale=face_outscale,
        face_strength=face_strength,
        format="onnx",
        outscale=outscale,
        scale=scale,
        upscale_order=upscale_order,
    )
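The clamping helpers from ..utils are called with the argument order (args, key, default, max, min) throughout this file. A plausible sketch consistent with those call sites, though the real implementation may differ:

# assumed shape of get_and_clamp_float, inferred from its call sites above
def get_and_clamp_float(args, key, default_value, max_value, min_value=0.0):
    # read the value if present, fall back to the default, and clamp into range
    return min(max(float(args.get(key, default_value)), min_value), max_value)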
@@ -0,0 +1,34 @@
from os import path

from flask import Flask, send_from_directory

from .context import ServerContext
from .utils import wrap_route
from ..worker.pool import DevicePoolExecutor


def serve_bundle_file(context: ServerContext, filename="index.html"):
    return send_from_directory(path.join("..", context.bundle_path), filename)


# non-API routes
def index(context: ServerContext):
    return serve_bundle_file(context)


def index_path(context: ServerContext, filename: str):
    return serve_bundle_file(context, filename)


def output(context: ServerContext, filename: str):
    return send_from_directory(
        path.join("..", context.output_path), filename, as_attachment=False
    )


def register_static_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor):
    return [
        app.route("/")(wrap_route(index, context)),
        app.route("/<path:filename>")(wrap_route(index_path, context)),
        app.route("/output/<path:filename>")(wrap_route(output, context)),
    ]
@@ -0,0 +1,32 @@
from functools import partial, update_wrapper
from os import makedirs, path
from typing import Callable, Dict, List, Tuple

from flask import Flask

from onnx_web.utils import base_join
from onnx_web.worker.pool import DevicePoolExecutor

from .context import ServerContext


def check_paths(context: ServerContext) -> None:
    if not path.exists(context.model_path):
        raise RuntimeError("model path must exist")

    if not path.exists(context.output_path):
        makedirs(context.output_path)


def get_model_path(context: ServerContext, model: str):
    return base_join(context.model_path, model)


def register_routes(app: Flask, context: ServerContext, pool: DevicePoolExecutor, routes: List[Tuple[str, Dict, Callable]]):
    # not implemented yet: routes are still registered directly by the
    # register_*_routes factories
    pass


def wrap_route(func, *args, **kwargs):
    # bind the server context (and optionally the worker pool) to a handler,
    # copying __name__ and other metadata so Flask derives a unique endpoint name
    partial_func = partial(func, *args, **kwargs)
    update_wrapper(partial_func, func)
    return partial_func
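A small sketch of what wrap_route yields, using the status handler from the API routes as a stand-in; the update_wrapper call is what lets two wrapped handlers register under distinct Flask endpoint names:

# illustrative only; context and pool come from the surrounding setup code
handler = wrap_route(status, context, pool=pool)
# handler.__name__ == "status", so app.route() derives the endpoint "status"
# calling handler() inside a request invokes status(context, pool=pool)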