From 6ec7777f773dba40f1092b4e38382542dabf69d3 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 4 Jul 2023 10:20:28 -0500 Subject: [PATCH] lint(api): type fixes and hints throughout --- api/onnx_web/chain/__init__.py | 2 +- api/onnx_web/chain/base.py | 24 ++--------------- api/onnx_web/chain/upscale.py | 20 +++++++------- api/onnx_web/chain/upscale_highres.py | 2 +- api/onnx_web/convert/__main__.py | 11 ++++---- api/onnx_web/convert/diffusion/control.py | 9 +++---- api/onnx_web/convert/diffusion/diffusers.py | 2 +- api/onnx_web/prompt/grammar.py | 2 ++ api/onnx_web/prompt/parser.py | 2 +- api/onnx_web/utils.py | 12 +++++---- api/pyproject.toml | 29 ++++++++++++++++++++- 11 files changed, 64 insertions(+), 51 deletions(-) diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index fa247ed7..edda66d2 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,4 +1,4 @@ -from .base import ChainPipeline, PipelineStage, StageCallback, StageParams +from .base import ChainPipeline, PipelineStage, StageParams from .blend_img2img import BlendImg2ImgStage from .blend_inpaint import BlendInpaintStage from .blend_linear import BlendLinearStage diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 23976c69..6d950b5c 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -16,26 +16,6 @@ from .tile import needs_tile, process_tile_order logger = getLogger(__name__) -class StageCallback(Protocol): - """ - Definition for a stage job function. - """ - - def __call__( - self, - job: WorkerContext, - server: ServerContext, - stage: StageParams, - params: ImageParams, - source: Image.Image, - **kwargs: Any - ) -> Image.Image: - """ - Run this stage against a source image. - """ - pass - - PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] @@ -77,7 +57,7 @@ class ChainPipeline: """ self.stages = list(stages or []) - def append(self, stage: PipelineStage): + def append(self, stage: Optional[PipelineStage]): """ DEPRECATED: use `stage` instead @@ -100,7 +80,7 @@ class ChainPipeline: """ return self(job, server, params, source=source, callback=callback, **kwargs) - def stage(self, callback: StageCallback, params: StageParams, **kwargs): + def stage(self, callback: BaseStage, params: StageParams, **kwargs): self.stages.append((callback, params, kwargs)) return self diff --git a/api/onnx_web/chain/upscale.py b/api/onnx_web/chain/upscale.py index d5a5caeb..2a299a5e 100644 --- a/api/onnx_web/chain/upscale.py +++ b/api/onnx_web/chain/upscale.py @@ -42,8 +42,8 @@ def stage_upscale_correction( *, upscale: UpscaleParams, chain: Optional[ChainPipeline] = None, - pre_stages: List[PipelineStage] = None, - post_stages: List[PipelineStage] = None, + pre_stages: Optional[List[PipelineStage]] = None, + post_stages: Optional[List[PipelineStage]] = None, **kwargs, ) -> ChainPipeline: """ @@ -60,14 +60,14 @@ def stage_upscale_correction( chain = ChainPipeline() if pre_stages is not None: - for stage, pre_params, pre_opts in pre_stages: - chain.append((stage, pre_params, pre_opts)) + for pre_stage in pre_stages: + chain.append(pre_stage) upscale_opts = { **kwargs, "upscale": upscale, } - upscale_stage = None + upscale_stage: Optional[PipelineStage] = None if upscale.scale > 1: if "bsrgan" in upscale.upscale_model: bsrgan_params = StageParams( @@ -94,12 +94,14 @@ def stage_upscale_correction( else: logger.warn("unknown upscaling model: %s", upscale.upscale_model) - correct_stage = None + correct_stage: Optional[PipelineStage] = None if upscale.faces: face_params = StageParams( tile_size=stage.tile_size, outscale=upscale.face_outscale ) - if "codeformer" in upscale.correction_model: + if upscale.correction_model is None: + logger.warn("no correction model set, skipping") + elif "codeformer" in upscale.correction_model: correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts) elif "gfpgan" in upscale.correction_model: correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts) @@ -120,7 +122,7 @@ def stage_upscale_correction( logger.warn("unknown upscaling order: %s", upscale.upscale_order) if post_stages is not None: - for stage, post_params, post_opts in post_stages: - chain.append((stage, post_params, post_opts)) + for post_stage in post_stages: + chain.append(post_stage) return chain diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index f0e62d1c..96916e03 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -21,7 +21,7 @@ class UpscaleHighresStage(BaseStage): stage: StageParams, params: ImageParams, source: Image.Image, - *, + *args, highres: HighresParams, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 7e3315c6..36c97429 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -268,7 +268,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): model_errors = [] if args.sources and "sources" in models: - for model in models.get("sources"): + for model in models.get("sources", []): model = tuple_to_source(model) name = model.get("name") @@ -292,7 +292,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): model_errors.append(name) if args.networks and "networks" in models: - for network in models.get("networks"): + for network in models.get("networks", []): name = network["name"] if name in args.skip: @@ -316,6 +316,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): conversion, network, dest, + path.join(conversion.model_path, network_type, name), ) if network_type == "inversion" and network_model == "concept": dest, hf = fetch_model( @@ -342,7 +343,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): model_errors.append(name) if args.diffusion and "diffusion" in models: - for model in models.get("diffusion"): + for model in models.get("diffusion", []): model = tuple_to_diffusion(model) name = model.get("name") @@ -483,7 +484,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): model_errors.append(name) if args.upscaling and "upscaling" in models: - for model in models.get("upscaling"): + for model in models.get("upscaling", []): model = tuple_to_upscaling(model) name = model.get("name") @@ -516,7 +517,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): model_errors.append(name) if args.correction and "correction" in models: - for model in models.get("correction"): + for model in models.get("correction", []): model = tuple_to_correction(model) name = model.get("name") diff --git a/api/onnx_web/convert/diffusion/control.py b/api/onnx_web/convert/diffusion/control.py index ffc9f5c9..7e96801b 100644 --- a/api/onnx_web/convert/diffusion/control.py +++ b/api/onnx_web/convert/diffusion/control.py @@ -1,7 +1,7 @@ from logging import getLogger from os import path from pathlib import Path -from typing import Dict +from typing import Dict, Optional import torch @@ -17,16 +17,15 @@ def convert_diffusion_control( conversion: ConversionContext, model: Dict, source: str, - model_path: str, output_path: str, - opset: int, - attention_slicing: str, + attention_slicing: Optional[str] = None, ): name = model.get("name") source = source or model.get("source") device = conversion.training_device dtype = conversion.torch_dtype() + opset = conversion.opset logger.debug("using Torch dtype %s for ControlNet", dtype) output_path = Path(output_path) @@ -35,7 +34,7 @@ def convert_diffusion_control( logger.info("ONNX model already exists, skipping") return - controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=dtype) + controlnet = ControlNetModel.from_pretrained(source, torch_dtype=dtype) if attention_slicing is not None: logger.info("enabling attention slicing for ControlNet") controlnet.set_attention_slice(attention_slicing) diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py index 48bd7797..f096e935 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -262,7 +262,7 @@ def convert_diffusion_diffusers( conversion: ConversionContext, model: Dict, source: str, - format: str, + format: Optional[str], hf: bool = False, ) -> Tuple[bool, str]: """ diff --git a/api/onnx_web/prompt/grammar.py b/api/onnx_web/prompt/grammar.py index ae28d23e..20127030 100644 --- a/api/onnx_web/prompt/grammar.py +++ b/api/onnx_web/prompt/grammar.py @@ -43,6 +43,8 @@ class PromptPhrase: if isinstance(other, self.__class__): return other.tokens == self.tokens and other.weight == self.weight + return False + class OnnxPromptVisitor(PTNodeVisitor): def __init__(self, defaults=True, weight=0.5, **kwargs): diff --git a/api/onnx_web/prompt/parser.py b/api/onnx_web/prompt/parser.py index 6a0fde22..8bab4dbf 100644 --- a/api/onnx_web/prompt/parser.py +++ b/api/onnx_web/prompt/parser.py @@ -15,7 +15,7 @@ def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray: def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray: - pass + raise NotImplementedError() def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray: diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index caeaab2c..0c1b115d 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -6,7 +6,7 @@ from json import JSONDecodeError from logging import getLogger from os import environ, path from platform import system -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union, TypeVar import torch from yaml import safe_load @@ -43,9 +43,11 @@ def get_and_clamp_int( return min(max(int(args.get(key, default_value)), min_value), max_value) +TElem = TypeVar("TElem") + def get_from_list( - args: Any, key: str, values: Sequence[Any], default_value: Optional[Any] = None -) -> Optional[Any]: + args: Any, key: str, values: Sequence[TElem], default_value: Optional[TElem] = None +) -> Optional[TElem]: selected = args.get(key, default_value) if selected in values: return selected @@ -57,7 +59,7 @@ def get_from_list( return None -def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any: +def get_from_map(args: Any, key: str, values: Dict[str, TElem], default: TElem) -> TElem: selected = args.get(key, default) if selected in values: return values[selected] @@ -65,7 +67,7 @@ def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> A return values[default] -def get_not_empty(args: Any, key: str, default: Any) -> Any: +def get_not_empty(args: Any, key: str, default: TElem) -> TElem: val = args.get(key, default) if val is None or len(val) == 0: diff --git a/api/pyproject.toml b/api/pyproject.toml index b863bbfb..efed7be3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -14,6 +14,7 @@ exclude = [ [[tool.mypy.overrides]] module = [ +"arpeggio", "basicsr.archs.rrdbnet_arch", "basicsr.utils.download_util", "basicsr.utils", @@ -23,19 +24,45 @@ module = [ "codeformer.facelib.utils.misc", "codeformer.facelib.utils", "codeformer.facelib", + "compel", + "controlnet_aux", + "cv2", "diffusers", + "diffusers.configuration_utils", + "diffusers.loaders", + "diffusers.models.attention_processor", + "diffusers.models.autoencoder_kl", + "diffusers.models.cross_attention", + "diffusers.models.embeddings", + "diffusers.models.modeling_utils", + "diffusers.models.unet_2d_blocks", + "diffusers.models.vae", + "diffusers.utils", "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion", + "diffusers.pipelines.onnx_utils", "diffusers.pipelines.paint_by_example", "diffusers.pipelines.stable_diffusion", + "diffusers.pipelines.stable_diffusion.convert_from_ckpt", "diffusers.pipeline_utils", + "diffusers.schedulers", "diffusers.utils.logging", "facexlib.utils", "facexlib", "gfpgan", + "gi.repository", + "huggingface_hub", + "huggingface_hub.file_download", + "huggingface_hub.utils.tqdm", + "mediapipe", "onnxruntime", + "onnxruntime.transformers.float16", + "piexif", + "piexif.helper", "realesrgan", "realesrgan.archs.srvgg_arch", "safetensors", - "transformers" + "timm.models.layers", + "transformers", + "win10toast" ] ignore_missing_imports = true \ No newline at end of file