diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index 117c258f..9dcfda50 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -14,7 +14,12 @@ logger = getLogger(__name__) @torch.no_grad() def convert_diffusion_textual_inversion( - context: ConversionContext, name: str, base_model: str, inversion: str, format: str, base_token: Optional[str] = None, + context: ConversionContext, + name: str, + base_model: str, + inversion: str, + format: str, + base_token: Optional[str] = None, ): dest_path = path.join(context.model_path, f"inversion-{name}") logger.info( diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index d267cef4..df56781c 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -10,6 +10,8 @@ from setproctitle import setproctitle from torch.multiprocessing import set_start_method from .server.api import register_api_routes +from .server.context import ServerContext +from .server.hacks import apply_patches from .server.load import ( get_available_platforms, load_extras, @@ -17,8 +19,6 @@ from .server.load import ( load_params, load_platforms, ) -from .server.context import ServerContext -from .server.hacks import apply_patches from .server.static import register_static_routes from .server.utils import check_paths from .utils import is_debug diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index d6dea6eb..5cad5604 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -31,6 +31,7 @@ from ..utils import ( sanitize_name, ) from ..worker.pool import DevicePoolExecutor +from .context import ServerContext from .load import ( get_available_platforms, get_config_params, @@ -43,7 +44,6 @@ from .load import ( get_noise_sources, get_upscaling_models, ) -from .context import ServerContext from .params import border_from_request, pipeline_from_request, upscale_from_request from .utils import wrap_route diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 83095912..8cd20432 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -3,13 +3,12 @@ from glob import glob from logging import getLogger from os import path from typing import Any, Dict, List, Union -from jsonschema import ValidationError, validate import torch import yaml +from jsonschema import ValidationError, validate from yaml import safe_load -from ..utils import merge from ..image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, @@ -23,6 +22,7 @@ from ..image import ( # mask filters; noise sources ) from ..params import DeviceParams from ..torch_before_ort import get_available_providers +from ..utils import merge from .context import ServerContext logger = getLogger(__name__) @@ -146,20 +146,27 @@ def load_extras(context: ServerContext): for model in data[model_type]: if "label" in model: model_name = model["name"] - logger.debug("collecting label for model %s from %s", model_name, file) + logger.debug( + "collecting label for model %s from %s", + model_name, + file, + ) labels[model_name] = model["label"] except Exception as err: logger.error("error loading extras file: %s", err) logger.debug("adding labels to strings: %s", labels) - merge(strings, { - "en": { - "translation": { - "model": labels, + merge( + strings, + { + "en": { + "translation": { + "model": labels, + } } - } - }) + }, + ) extra_strings = strings @@ -170,9 +177,7 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]: pattern_path = path.join(context.model_path, pattern) logger.debug("loading models from %s", pattern_path) - models.extend([ - get_model_name(f) for f in glob(pattern_path) - ]) + models.extend([get_model_name(f) for f in glob(pattern_path)]) unique_models = list(set(models)) unique_models.sort() @@ -185,25 +190,37 @@ def load_models(context: ServerContext) -> None: global inversion_models global upscaling_models - diffusion_models = list_model_globs(context, [ - "diffusion-*", - "stable-diffusion-*", - ]) + diffusion_models = list_model_globs( + context, + [ + "diffusion-*", + "stable-diffusion-*", + ], + ) logger.debug("loaded diffusion models from disk: %s", diffusion_models) - correction_models = list_model_globs(context, [ - "correction-*", - ]) + correction_models = list_model_globs( + context, + [ + "correction-*", + ], + ) logger.debug("loaded correction models from disk: %s", correction_models) - inversion_models = list_model_globs(context, [ - "inversion-*", - ]) + inversion_models = list_model_globs( + context, + [ + "inversion-*", + ], + ) logger.debug("loaded inversion models from disk: %s", inversion_models) - upscaling_models = list_model_globs(context, [ - "upscaling-*", - ]) + upscaling_models = list_model_globs( + context, + [ + "upscaling-*", + ], + ) logger.debug("loaded upscaling models from disk: %s", upscaling_models) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index ca492393..6797177d 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -7,13 +7,13 @@ from flask import request from ..diffusers.load import pipeline_schedulers from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty +from .context import ServerContext from .load import ( get_available_platforms, get_config_value, get_correction_models, get_upscaling_models, ) -from .context import ServerContext from .utils import get_model_path logger = getLogger(__name__) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 52445c6d..54c913e8 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -110,15 +110,16 @@ def sanitize_name(name): def merge(a, b, path=None): "merges b into a" - if path is None: path = [] + if path is None: + path = [] for key in b: if key in a: if isinstance(a[key], dict) and isinstance(b[key], dict): merge(a[key], b[key], path + [str(key)]) elif a[key] == b[key]: - pass # same leaf value + pass # same leaf value else: - raise Exception("Conflict at %s" % '.'.join(path + [str(key)])) + raise Exception("Conflict at %s" % ".".join(path + [str(key)])) else: a[key] = b[key] return a diff --git a/api/pyproject.toml b/api/pyproject.toml index e337c1f2..aabd4334 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -9,8 +9,8 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion [tool.mypy] # ignore_missing_imports = true exclude = [ - "onnx_web.diffusion.lpw_stable_diffusion_onnx", - "onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale" + "onnx_web.diffusers.lpw_stable_diffusion_onnx", + "onnx_web.diffusers.pipeline_onnx_stable_diffusion_upscale" ] [[tool.mypy.overrides]] diff --git a/api/setup.cfg b/api/setup.cfg index 9986d4f6..c006aef5 100644 --- a/api/setup.cfg +++ b/api/setup.cfg @@ -6,5 +6,5 @@ ignore = E203, W503 max-line-length = 160 per-file-ignores = __init__.py:F401 exclude = - onnx_web/diffusion/lpw_stable_diffusion_onnx.py - onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py \ No newline at end of file + onnx_web/diffusers/lpw_stable_diffusion_onnx.py + onnx_web/diffusers/pipeline_onnx_stable_diffusion_upscale.py \ No newline at end of file