apply lint
This commit is contained in:
parent
a0dfc060da
commit
6d2dd0a043
|
@ -14,7 +14,12 @@ logger = getLogger(__name__)
|
|||
|
||||
@torch.no_grad()
|
||||
def convert_diffusion_textual_inversion(
|
||||
context: ConversionContext, name: str, base_model: str, inversion: str, format: str, base_token: Optional[str] = None,
|
||||
context: ConversionContext,
|
||||
name: str,
|
||||
base_model: str,
|
||||
inversion: str,
|
||||
format: str,
|
||||
base_token: Optional[str] = None,
|
||||
):
|
||||
dest_path = path.join(context.model_path, f"inversion-{name}")
|
||||
logger.info(
|
||||
|
|
|
@ -10,6 +10,8 @@ from setproctitle import setproctitle
|
|||
from torch.multiprocessing import set_start_method
|
||||
|
||||
from .server.api import register_api_routes
|
||||
from .server.context import ServerContext
|
||||
from .server.hacks import apply_patches
|
||||
from .server.load import (
|
||||
get_available_platforms,
|
||||
load_extras,
|
||||
|
@ -17,8 +19,6 @@ from .server.load import (
|
|||
load_params,
|
||||
load_platforms,
|
||||
)
|
||||
from .server.context import ServerContext
|
||||
from .server.hacks import apply_patches
|
||||
from .server.static import register_static_routes
|
||||
from .server.utils import check_paths
|
||||
from .utils import is_debug
|
||||
|
|
|
@ -31,6 +31,7 @@ from ..utils import (
|
|||
sanitize_name,
|
||||
)
|
||||
from ..worker.pool import DevicePoolExecutor
|
||||
from .context import ServerContext
|
||||
from .load import (
|
||||
get_available_platforms,
|
||||
get_config_params,
|
||||
|
@ -43,7 +44,6 @@ from .load import (
|
|||
get_noise_sources,
|
||||
get_upscaling_models,
|
||||
)
|
||||
from .context import ServerContext
|
||||
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
||||
from .utils import wrap_route
|
||||
|
||||
|
|
|
@ -3,13 +3,12 @@ from glob import glob
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Any, Dict, List, Union
|
||||
from jsonschema import ValidationError, validate
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from jsonschema import ValidationError, validate
|
||||
from yaml import safe_load
|
||||
|
||||
from ..utils import merge
|
||||
from ..image import ( # mask filters; noise sources
|
||||
mask_filter_gaussian_multiply,
|
||||
mask_filter_gaussian_screen,
|
||||
|
@ -23,6 +22,7 @@ from ..image import ( # mask filters; noise sources
|
|||
)
|
||||
from ..params import DeviceParams
|
||||
from ..torch_before_ort import get_available_providers
|
||||
from ..utils import merge
|
||||
from .context import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -146,20 +146,27 @@ def load_extras(context: ServerContext):
|
|||
for model in data[model_type]:
|
||||
if "label" in model:
|
||||
model_name = model["name"]
|
||||
logger.debug("collecting label for model %s from %s", model_name, file)
|
||||
logger.debug(
|
||||
"collecting label for model %s from %s",
|
||||
model_name,
|
||||
file,
|
||||
)
|
||||
labels[model_name] = model["label"]
|
||||
|
||||
except Exception as err:
|
||||
logger.error("error loading extras file: %s", err)
|
||||
|
||||
logger.debug("adding labels to strings: %s", labels)
|
||||
merge(strings, {
|
||||
merge(
|
||||
strings,
|
||||
{
|
||||
"en": {
|
||||
"translation": {
|
||||
"model": labels,
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
extra_strings = strings
|
||||
|
||||
|
@ -170,9 +177,7 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
|||
pattern_path = path.join(context.model_path, pattern)
|
||||
logger.debug("loading models from %s", pattern_path)
|
||||
|
||||
models.extend([
|
||||
get_model_name(f) for f in glob(pattern_path)
|
||||
])
|
||||
models.extend([get_model_name(f) for f in glob(pattern_path)])
|
||||
|
||||
unique_models = list(set(models))
|
||||
unique_models.sort()
|
||||
|
@ -185,25 +190,37 @@ def load_models(context: ServerContext) -> None:
|
|||
global inversion_models
|
||||
global upscaling_models
|
||||
|
||||
diffusion_models = list_model_globs(context, [
|
||||
diffusion_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
"diffusion-*",
|
||||
"stable-diffusion-*",
|
||||
])
|
||||
],
|
||||
)
|
||||
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
||||
|
||||
correction_models = list_model_globs(context, [
|
||||
correction_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
"correction-*",
|
||||
])
|
||||
],
|
||||
)
|
||||
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||
|
||||
inversion_models = list_model_globs(context, [
|
||||
inversion_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
"inversion-*",
|
||||
])
|
||||
],
|
||||
)
|
||||
logger.debug("loaded inversion models from disk: %s", inversion_models)
|
||||
|
||||
upscaling_models = list_model_globs(context, [
|
||||
upscaling_models = list_model_globs(
|
||||
context,
|
||||
[
|
||||
"upscaling-*",
|
||||
])
|
||||
],
|
||||
)
|
||||
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
||||
|
||||
|
||||
|
|
|
@ -7,13 +7,13 @@ from flask import request
|
|||
from ..diffusers.load import pipeline_schedulers
|
||||
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
|
||||
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
|
||||
from .context import ServerContext
|
||||
from .load import (
|
||||
get_available_platforms,
|
||||
get_config_value,
|
||||
get_correction_models,
|
||||
get_upscaling_models,
|
||||
)
|
||||
from .context import ServerContext
|
||||
from .utils import get_model_path
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -110,7 +110,8 @@ def sanitize_name(name):
|
|||
|
||||
def merge(a, b, path=None):
|
||||
"merges b into a"
|
||||
if path is None: path = []
|
||||
if path is None:
|
||||
path = []
|
||||
for key in b:
|
||||
if key in a:
|
||||
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
||||
|
@ -118,7 +119,7 @@ def merge(a, b, path=None):
|
|||
elif a[key] == b[key]:
|
||||
pass # same leaf value
|
||||
else:
|
||||
raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
|
||||
raise Exception("Conflict at %s" % ".".join(path + [str(key)]))
|
||||
else:
|
||||
a[key] = b[key]
|
||||
return a
|
||||
|
|
|
@ -9,8 +9,8 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion
|
|||
[tool.mypy]
|
||||
# ignore_missing_imports = true
|
||||
exclude = [
|
||||
"onnx_web.diffusion.lpw_stable_diffusion_onnx",
|
||||
"onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale"
|
||||
"onnx_web.diffusers.lpw_stable_diffusion_onnx",
|
||||
"onnx_web.diffusers.pipeline_onnx_stable_diffusion_upscale"
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
|
|
|
@ -6,5 +6,5 @@ ignore = E203, W503
|
|||
max-line-length = 160
|
||||
per-file-ignores = __init__.py:F401
|
||||
exclude =
|
||||
onnx_web/diffusion/lpw_stable_diffusion_onnx.py
|
||||
onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
|
||||
onnx_web/diffusers/lpw_stable_diffusion_onnx.py
|
||||
onnx_web/diffusers/pipeline_onnx_stable_diffusion_upscale.py
|
Loading…
Reference in New Issue