1
0
Fork 0

apply lint

This commit is contained in:
Sean Sube 2023-03-05 07:19:48 -06:00
parent a0dfc060da
commit 6d2dd0a043
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 60 additions and 37 deletions

View File

@@ -14,7 +14,12 @@ logger = getLogger(__name__)
@torch.no_grad()
def convert_diffusion_textual_inversion(
context: ConversionContext, name: str, base_model: str, inversion: str, format: str, base_token: Optional[str] = None,
context: ConversionContext,
name: str,
base_model: str,
inversion: str,
format: str,
base_token: Optional[str] = None,
):
dest_path = path.join(context.model_path, f"inversion-{name}")
logger.info(

View File

@@ -10,6 +10,8 @@ from setproctitle import setproctitle
from torch.multiprocessing import set_start_method
from .server.api import register_api_routes
from .server.context import ServerContext
from .server.hacks import apply_patches
from .server.load import (
get_available_platforms,
load_extras,
@@ -17,8 +19,6 @@ from .server.load import (
load_params,
load_platforms,
)
from .server.context import ServerContext
from .server.hacks import apply_patches
from .server.static import register_static_routes
from .server.utils import check_paths
from .utils import is_debug

View File

@@ -31,6 +31,7 @@ from ..utils import (
sanitize_name,
)
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .load import (
get_available_platforms,
get_config_params,
@@ -43,7 +44,6 @@ from .load import (
get_noise_sources,
get_upscaling_models,
)
from .context import ServerContext
from .params import border_from_request, pipeline_from_request, upscale_from_request
from .utils import wrap_route

View File

@@ -3,13 +3,12 @@ from glob import glob
from logging import getLogger
from os import path
from typing import Any, Dict, List, Union
from jsonschema import ValidationError, validate
import torch
import yaml
from jsonschema import ValidationError, validate
from yaml import safe_load
from ..utils import merge
from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
@@ -23,6 +22,7 @@ from ..image import ( # mask filters; noise sources
)
from ..params import DeviceParams
from ..torch_before_ort import get_available_providers
from ..utils import merge
from .context import ServerContext
logger = getLogger(__name__)
@@ -146,20 +146,27 @@ def load_extras(context: ServerContext):
for model in data[model_type]:
if "label" in model:
model_name = model["name"]
logger.debug("collecting label for model %s from %s", model_name, file)
logger.debug(
"collecting label for model %s from %s",
model_name,
file,
)
labels[model_name] = model["label"]
except Exception as err:
logger.error("error loading extras file: %s", err)
logger.debug("adding labels to strings: %s", labels)
merge(strings, {
"en": {
"translation": {
"model": labels,
merge(
strings,
{
"en": {
"translation": {
"model": labels,
}
}
}
})
},
)
extra_strings = strings
@@ -170,9 +177,7 @@ def list_model_globs(context: ServerContext, globs: List[str]) -> List[str]:
pattern_path = path.join(context.model_path, pattern)
logger.debug("loading models from %s", pattern_path)
models.extend([
get_model_name(f) for f in glob(pattern_path)
])
models.extend([get_model_name(f) for f in glob(pattern_path)])
unique_models = list(set(models))
unique_models.sort()
@@ -185,25 +190,37 @@ def load_models(context: ServerContext) -> None:
global inversion_models
global upscaling_models
diffusion_models = list_model_globs(context, [
"diffusion-*",
"stable-diffusion-*",
])
diffusion_models = list_model_globs(
context,
[
"diffusion-*",
"stable-diffusion-*",
],
)
logger.debug("loaded diffusion models from disk: %s", diffusion_models)
correction_models = list_model_globs(context, [
"correction-*",
])
correction_models = list_model_globs(
context,
[
"correction-*",
],
)
logger.debug("loaded correction models from disk: %s", correction_models)
inversion_models = list_model_globs(context, [
"inversion-*",
])
inversion_models = list_model_globs(
context,
[
"inversion-*",
],
)
logger.debug("loaded inversion models from disk: %s", inversion_models)
upscaling_models = list_model_globs(context, [
"upscaling-*",
])
upscaling_models = list_model_globs(
context,
[
"upscaling-*",
],
)
logger.debug("loaded upscaling models from disk: %s", upscaling_models)

View File

@@ -7,13 +7,13 @@ from flask import request
from ..diffusers.load import pipeline_schedulers
from ..params import Border, DeviceParams, ImageParams, Size, UpscaleParams
from ..utils import get_and_clamp_float, get_and_clamp_int, get_from_list, get_not_empty
from .context import ServerContext
from .load import (
get_available_platforms,
get_config_value,
get_correction_models,
get_upscaling_models,
)
from .context import ServerContext
from .utils import get_model_path
logger = getLogger(__name__)

View File

@@ -110,15 +110,16 @@ def sanitize_name(name):
def merge(a, b, path=None):
"merges b into a"
if path is None: path = []
if path is None:
path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
merge(a[key], b[key], path + [str(key)])
elif a[key] == b[key]:
pass # same leaf value
pass # same leaf value
else:
raise Exception("Conflict at %s" % '.'.join(path + [str(key)]))
raise Exception("Conflict at %s" % ".".join(path + [str(key)]))
else:
a[key] = b[key]
return a

View File

@@ -9,8 +9,8 @@ skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion
[tool.mypy]
# ignore_missing_imports = true
exclude = [
"onnx_web.diffusion.lpw_stable_diffusion_onnx",
"onnx_web.diffusion.pipeline_onnx_stable_diffusion_upscale"
"onnx_web.diffusers.lpw_stable_diffusion_onnx",
"onnx_web.diffusers.pipeline_onnx_stable_diffusion_upscale"
]
[[tool.mypy.overrides]]

View File

@@ -6,5 +6,5 @@ ignore = E203, W503
max-line-length = 160
per-file-ignores = __init__.py:F401
exclude =
onnx_web/diffusion/lpw_stable_diffusion_onnx.py
onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py
onnx_web/diffusers/lpw_stable_diffusion_onnx.py
onnx_web/diffusers/pipeline_onnx_stable_diffusion_upscale.py