apply lint
This commit is contained in:
parent
a0dfc060da
commit
6d2dd0a043
|
@ -14,7 +14,12 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_textual_inversion(
|
def convert_diffusion_textual_inversion(
|
||||||
context: ConversionContext, name: str, base_model: str, inversion: str, format: str, base_token: Optional[str] = None,
|
context: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
base_model: str,
|
||||||
|
inversion: str,
|
||||||
|
format: str,
|
||||||
|
base_token: Optional[str] = None,
|
||||||
):
|
):
|
||||||
dest_path = path.join(context.model_path, f"inversion-{name}")
|
dest_path = path.join(context.model_path, f"inversion-{name}")
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -10,6 +10,8 @@ from setproctitle import setproctitle
|
||||||
from torch.multiprocessing import set_start_method
|
from torch.multiprocessing import set_start_method
|
||||||
|
|
||||||
from .server.api import register_api_routes
|
from .server.api import register_api_routes
|
||||||
|
from .server.context import ServerContext
|
||||||
|
from .server.hacks import apply_patches
|
||||||
from .server.load import (
|
from .server.load import (
|
||||||
get_available_platforms,
|
get_available_platforms,
|
||||||
load_extras,
|
load_extras,
|
||||||
|
@ -17,8 +19,6 @@ from .server.load import (
|
||||||
load_params,
|
load_params,
|
||||||
load_platforms,
|
load_platforms,
|
||||||
)
|
)
|
||||||
from .server.context import ServerContext
|
|
||||||
from .server.hacks import apply_patches
|
|
||||||
from .server.static import register_static_routes
|
from .server.static import register_static_routes
|
||||||
from .server.utils import check_paths
|
from .server.utils import check_paths
|
||||||
from .utils import is_debug
|
from .utils import is_debug
|
||||||
|
|
|
@ -31,6 +31,7 @@ from ..utils import (
|
||||||
sanitize_name,
|
sanitize_name,
|
||||||
)
|
)
|
||||||
from ..worker.pool import DevicePoolExecutor
|
from ..worker.pool import DevicePoolExecutor
|
||||||
|
from .context import ServerContext
|
||||||
from .load import (
|
from .load import (
|
||||||
get_available_platforms,
|
get_available_platforms,
|
||||||
get_config_params,
|
get_config_params,
|
||||||
|
@ -43,7 +44,6 @@ from .load import (
|
||||||
get_noise_sources,
|
get_noise_sources,
|
||||||
get_upscaling_models,
|
get_upscaling_models,
|
||||||
)
|
)
|
||||||
from .context import ServerContext
|
|
||||||
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
from .params import border_from_request, pipeline_from_request, upscale_from_request
|
||||||
from .utils import wrap_route
|
from .utils import wrap_route
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,12 @@ from glob import glob
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
from jsonschema import ValidationError, validate
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
|
from jsonschema import ValidationError, validate
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
from ..utils import merge
|
|
||||||
from ..image import ( # mask filters; noise sources
|
from ..image import ( # mask filters; noise sources
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
|
@ -23,6 +22,7 @@ from ..image import ( # mask filters; noise sources
|
||||||
)
|
)
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..torch_before_ort import get_available_providers
|
from ..torch_before_ort import get_available_providers
|
||||||
|
from ..utils import merge
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -146,20 +146,27 @@ def load_extras(context: ServerContext):
|
||||||
for model in data[model_type]:
|
for model in data[model_type]:
|
||||||
if "label" in model:
|
if "label" in model:
|
||||||
model_name = model["name"]
|
model_name = model["name"]
|
||||||
logger.debug("collecting label for model %s from %s", model_name, file)
|
logger.debug(
|
||||||
|
"collecting label for model %s from %s",
|
||||||
|
model_name,
|
||||||
|
file,
|
||||||
|
)
|
||||||
labels[model_name] = model["label"]
|
labels[model_name] = model["label"]
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.error("error loading extras file: %s", err)
|
logger.error("error loading extras file: %s", err)
|
||||||
|
|
||||||
logger.debug("adding labels to strings: %s", labels)
|
logger.debug("adding labels to strings: %s", labels)
|
||||||
merge(strings, {
|
merge(
|
||||||
|
strings,
|
||||||
|
{
|
||||||
"en": {
|
"en": {
|
||||||
"translation": {
|
"translation": {
|
||||||
"model": labels,
|
"model": labels,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
extra_strings = strings
|
extra_strings = strings
|
||||||
|
|
||||||
|
@ -170,9 +177,7 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
|
||||||
pattern_path = path.join(context.model_path, pattern)
|
pattern_path = path.join(context.model_path, pattern)
|
||||||
logger.debug("loading models from %s", pattern_path)
|
logger.debug("loading models from %s", pattern_path)
|
||||||
|
|
||||||
models.extend([
|
models.extend([get_model_name(f) for f in glob(pattern_path)])
|
||||||
get_model_name(f) for f in glob(pattern_path)
|
|
||||||
])
|
|
||||||
|
|
||||||
unique_models = list(set(models))
|
unique_models = list(set(models))
|
||||||
unique_models.sort()
|
unique_models.sort()
|
||||||
|
@ -185,25 +190,37 @@ def load_models(context: ServerContext) -> None:
|
||||||
global inversion_models
|
global inversion_models
|
||||||
global upscaling_models
|
global upscaling_models
|
||||||
|
|
||||||
diffusion_models = list_model_globs(context, [
|
diffusion_models = list_model_globs(
|
||||||
|
context,
|
||||||
|
[
|
||||||
"diffusion-*",
|
"diffusion-*",
|
||||||
"stable-diffusion-*",
|
"stable-diffusion-*",
|
||||||
])
|
],
|
||||||
|
)
|
||||||
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
|
||||||
|
|
||||||
correction_models = list_model_globs(context, [
|
correction_models = list_model_globs(
|
||||||
|
context,
|
||||||
|
[
|
||||||
"correction-*",
|
"correction-*",
|
||||||
])
|
],
|
||||||
|
)
|
||||||
logger.debug("loaded correction models from disk: %s", correction_models)
|
logger.debug("loaded correction models from disk: %s", correction_models)
|
||||||
|
|
||||||
inversion_models = list_model_globs(context, [
|
inversion_models = list_model_globs(
|
||||||
|
context,
|
||||||
|
[
|
||||||
"inversion-*",
|
"inversion-*",
|
||||||
])
|
],
|
||||||
|
)
|
||||||
logger.debug("loaded inversion models from disk: %s", inversion_models)
|
logger.debug("loaded inversion models from disk: %s", inversion_models)
|
||||||
|
|
||||||
upscaling_models = list_model_globs(context, [
|
upscaling_models = list_model_globs(
|
||||||
|
context,
|
||||||
|
[
|
||||||
"upscaling-*",
|
"upscaling-*",
|
||||||
])
|
],
|
||||||
|
)
|
||||||
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
logger.debug("loaded upscaling models from disk: %s", upscaling_models)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,13 +7,13 @@ from flask import request
|
||||||
from ..diffusers.load import pipeline_schedulers
|
from ..diffusers.load import pipeline_schedulers
|
||||||
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
|
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
|
||||||
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
|
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
|
||||||
|
from .context import ServerContext
|
||||||
from .load import (
|
from .load import (
|
||||||
get_available_platforms,
|
get_available_platforms,
|
||||||
get_config_value,
|
get_config_value,
|
||||||
get_correction_models,
|
get_correction_models,
|
||||||
get_upscaling_models,
|
get_upscaling_models,
|
||||||
)
|
)
|
||||||
from .context import ServerContext
|
|
||||||
from .utils import get_model_path
|
from .utils import get_model_path
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -110,7 +110,8 @@ def sanitize_name(name):
|
||||||
|
|
||||||
def merge(a, b, path=None):
|
def merge(a, b, path=None):
|
||||||
"merges b into a"
|
"merges b into a"
|
||||||
if path is None: path = []
|
if path is None:
|
||||||
|
path = []
|
||||||
for key in b:
|
for key in b:
|
||||||
if key in a:
|
if key in a:
|
||||||
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
if isinstance(a[key], dict) and isinstance(b[key], dict):
|
||||||
|
@ -118,7 +119,7 @@ def merge(a, b, path=None):
|
||||||
elif a[key] == b[key]:
|
elif a[key] == b[key]:
|
||||||
pass # same leaf value
|
pass # same leaf value
|
||||||
else:
|
else:
|
||||||
raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
|
raise Exception("Conflict at %s" % ".".join(path + [str(key)]))
|
||||||
else:
|
else:
|
||||||
a[key] = b[key]
|
a[key] = b[key]
|
||||||
return a
|
return a
|
||||||
|
|
|
@ -9,8 +9,8 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
# ignore_missing_imports = true
|
# ignore_missing_imports = true
|
||||||
exclude = [
|
exclude = [
|
||||||
"onnx_web.diffusion.lpw_stable_diffusion_onnx",
|
"onnx_web.diffusers.lpw_stable_diffusion_onnx",
|
||||||
"onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale"
|
"onnx_web.diffusers.pipeline_onnx_stable_diffusion_upscale"
|
||||||
]
|
]
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
|
|
|
@ -6,5 +6,5 @@ ignore = E203, W503
|
||||||
max-line-length = 160
|
max-line-length = 160
|
||||||
per-file-ignores = __init__.py:F401
|
per-file-ignores = __init__.py:F401
|
||||||
exclude =
|
exclude =
|
||||||
onnx_web/diffusion/lpw_stable_diffusion_onnx.py
|
onnx_web/diffusers/lpw_stable_diffusion_onnx.py
|
||||||
onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
|
onnx_web/diffusers/pipeline_onnx_stable_diffusion_upscale.py
|
Loading…
Reference in New Issue