lint(api): type fixes and hints throughout
This commit is contained in:
parent
5d13629ee8
commit
6ec7777f77
|
@ -1,4 +1,4 @@
|
|||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||
from .base import ChainPipeline, PipelineStage, StageParams
|
||||
from .blend_img2img import BlendImg2ImgStage
|
||||
from .blend_inpaint import BlendInpaintStage
|
||||
from .blend_linear import BlendLinearStage
|
||||
|
|
|
@ -16,26 +16,6 @@ from .tile import needs_tile, process_tile_order
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class StageCallback(Protocol):
|
||||
"""
|
||||
Definition for a stage job function.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
job: WorkerContext,
|
||||
server: ServerContext,
|
||||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
**kwargs: Any
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Run this stage against a source image.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||
|
||||
|
||||
|
@ -77,7 +57,7 @@ class ChainPipeline:
|
|||
"""
|
||||
self.stages = list(stages or [])
|
||||
|
||||
def append(self, stage: PipelineStage):
|
||||
def append(self, stage: Optional[PipelineStage]):
|
||||
"""
|
||||
DEPRECATED: use `stage` instead
|
||||
|
||||
|
@ -100,7 +80,7 @@ class ChainPipeline:
|
|||
"""
|
||||
return self(job, server, params, source=source, callback=callback, **kwargs)
|
||||
|
||||
def stage(self, callback: StageCallback, params: StageParams, **kwargs):
|
||||
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||
self.stages.append((callback, params, kwargs))
|
||||
return self
|
||||
|
||||
|
|
|
@ -42,8 +42,8 @@ def stage_upscale_correction(
|
|||
*,
|
||||
upscale: UpscaleParams,
|
||||
chain: Optional[ChainPipeline] = None,
|
||||
pre_stages: List[PipelineStage] = None,
|
||||
post_stages: List[PipelineStage] = None,
|
||||
pre_stages: Optional[List[PipelineStage]] = None,
|
||||
post_stages: Optional[List[PipelineStage]] = None,
|
||||
**kwargs,
|
||||
) -> ChainPipeline:
|
||||
"""
|
||||
|
@ -60,14 +60,14 @@ def stage_upscale_correction(
|
|||
chain = ChainPipeline()
|
||||
|
||||
if pre_stages is not None:
|
||||
for stage, pre_params, pre_opts in pre_stages:
|
||||
chain.append((stage, pre_params, pre_opts))
|
||||
for pre_stage in pre_stages:
|
||||
chain.append(pre_stage)
|
||||
|
||||
upscale_opts = {
|
||||
**kwargs,
|
||||
"upscale": upscale,
|
||||
}
|
||||
upscale_stage = None
|
||||
upscale_stage: Optional[PipelineStage] = None
|
||||
if upscale.scale > 1:
|
||||
if "bsrgan" in upscale.upscale_model:
|
||||
bsrgan_params = StageParams(
|
||||
|
@ -94,12 +94,14 @@ def stage_upscale_correction(
|
|||
else:
|
||||
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||
|
||||
correct_stage = None
|
||||
correct_stage: Optional[PipelineStage] = None
|
||||
if upscale.faces:
|
||||
face_params = StageParams(
|
||||
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
||||
)
|
||||
if "codeformer" in upscale.correction_model:
|
||||
if upscale.correction_model is None:
|
||||
logger.warn("no correction model set, skipping")
|
||||
elif "codeformer" in upscale.correction_model:
|
||||
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
|
||||
elif "gfpgan" in upscale.correction_model:
|
||||
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
|
||||
|
@ -120,7 +122,7 @@ def stage_upscale_correction(
|
|||
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
|
||||
|
||||
if post_stages is not None:
|
||||
for stage, post_params, post_opts in post_stages:
|
||||
chain.append((stage, post_params, post_opts))
|
||||
for post_stage in post_stages:
|
||||
chain.append(post_stage)
|
||||
|
||||
return chain
|
||||
|
|
|
@ -21,7 +21,7 @@ class UpscaleHighresStage(BaseStage):
|
|||
stage: StageParams,
|
||||
params: ImageParams,
|
||||
source: Image.Image,
|
||||
*,
|
||||
*args,
|
||||
highres: HighresParams,
|
||||
upscale: UpscaleParams,
|
||||
stage_source: Optional[Image.Image] = None,
|
||||
|
|
|
@ -268,7 +268,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model_errors = []
|
||||
|
||||
if args.sources and "sources" in models:
|
||||
for model in models.get("sources"):
|
||||
for model in models.get("sources", []):
|
||||
model = tuple_to_source(model)
|
||||
name = model.get("name")
|
||||
|
||||
|
@ -292,7 +292,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model_errors.append(name)
|
||||
|
||||
if args.networks and "networks" in models:
|
||||
for network in models.get("networks"):
|
||||
for network in models.get("networks", []):
|
||||
name = network["name"]
|
||||
|
||||
if name in args.skip:
|
||||
|
@ -316,6 +316,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
conversion,
|
||||
network,
|
||||
dest,
|
||||
path.join(conversion.model_path, network_type, name),
|
||||
)
|
||||
if network_type == "inversion" and network_model == "concept":
|
||||
dest, hf = fetch_model(
|
||||
|
@ -342,7 +343,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model_errors.append(name)
|
||||
|
||||
if args.diffusion and "diffusion" in models:
|
||||
for model in models.get("diffusion"):
|
||||
for model in models.get("diffusion", []):
|
||||
model = tuple_to_diffusion(model)
|
||||
name = model.get("name")
|
||||
|
||||
|
@ -483,7 +484,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model_errors.append(name)
|
||||
|
||||
if args.upscaling and "upscaling" in models:
|
||||
for model in models.get("upscaling"):
|
||||
for model in models.get("upscaling", []):
|
||||
model = tuple_to_upscaling(model)
|
||||
name = model.get("name")
|
||||
|
||||
|
@ -516,7 +517,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
model_errors.append(name)
|
||||
|
||||
if args.correction and "correction" in models:
|
||||
for model in models.get("correction"):
|
||||
for model in models.get("correction", []):
|
||||
model = tuple_to_correction(model)
|
||||
name = model.get("name")
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -17,16 +17,15 @@ def convert_diffusion_control(
|
|||
conversion: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
opset: int,
|
||||
attention_slicing: str,
|
||||
attention_slicing: Optional[str] = None,
|
||||
):
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
|
||||
device = conversion.training_device
|
||||
dtype = conversion.torch_dtype()
|
||||
opset = conversion.opset
|
||||
logger.debug("using Torch dtype %s for ControlNet", dtype)
|
||||
|
||||
output_path = Path(output_path)
|
||||
|
@ -35,7 +34,7 @@ def convert_diffusion_control(
|
|||
logger.info("ONNX model already exists, skipping")
|
||||
return
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=dtype)
|
||||
controlnet = ControlNetModel.from_pretrained(source, torch_dtype=dtype)
|
||||
if attention_slicing is not None:
|
||||
logger.info("enabling attention slicing for ControlNet")
|
||||
controlnet.set_attention_slice(attention_slicing)
|
||||
|
|
|
@ -262,7 +262,7 @@ def convert_diffusion_diffusers(
|
|||
conversion: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
format: str,
|
||||
format: Optional[str],
|
||||
hf: bool = False,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
|
|
|
@ -43,6 +43,8 @@ class PromptPhrase:
|
|||
if isinstance(other, self.__class__):
|
||||
return other.tokens == self.tokens and other.weight == self.weight
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class OnnxPromptVisitor(PTNodeVisitor):
|
||||
def __init__(self, defaults=True, weight=0.5, **kwargs):
|
||||
|
|
|
@ -15,7 +15,7 @@ def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray:
|
|||
|
||||
|
||||
def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
|
||||
pass
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
|
||||
|
|
|
@ -6,7 +6,7 @@ from json import JSONDecodeError
|
|||
from logging import getLogger
|
||||
from os import environ, path
|
||||
from platform import system
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union, TypeVar
|
||||
|
||||
import torch
|
||||
from yaml import safe_load
|
||||
|
@ -43,9 +43,11 @@ def get_and_clamp_int(
|
|||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||
|
||||
|
||||
TElem = TypeVar("TElem")
|
||||
|
||||
def get_from_list(
|
||||
args: Any, key: str, values: Sequence[Any], default_value: Optional[Any] = None
|
||||
) -> Optional[Any]:
|
||||
args: Any, key: str, values: Sequence[TElem], default_value: Optional[TElem] = None
|
||||
) -> Optional[TElem]:
|
||||
selected = args.get(key, default_value)
|
||||
if selected in values:
|
||||
return selected
|
||||
|
@ -57,7 +59,7 @@ def get_from_list(
|
|||
return None
|
||||
|
||||
|
||||
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any:
|
||||
def get_from_map(args: Any, key: str, values: Dict[str, TElem], default: TElem) -> TElem:
|
||||
selected = args.get(key, default)
|
||||
if selected in values:
|
||||
return values[selected]
|
||||
|
@ -65,7 +67,7 @@ def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> A
|
|||
return values[default]
|
||||
|
||||
|
||||
def get_not_empty(args: Any, key: str, default: Any) -> Any:
|
||||
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
|
||||
val = args.get(key, default)
|
||||
|
||||
if val is None or len(val) == 0:
|
||||
|
|
|
@ -14,6 +14,7 @@ exclude = [
|
|||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"arpeggio",
|
||||
"basicsr.archs.rrdbnet_arch",
|
||||
"basicsr.utils.download_util",
|
||||
"basicsr.utils",
|
||||
|
@ -23,19 +24,45 @@ module = [
|
|||
"codeformer.facelib.utils.misc",
|
||||
"codeformer.facelib.utils",
|
||||
"codeformer.facelib",
|
||||
"compel",
|
||||
"controlnet_aux",
|
||||
"cv2",
|
||||
"diffusers",
|
||||
"diffusers.configuration_utils",
|
||||
"diffusers.loaders",
|
||||
"diffusers.models.attention_processor",
|
||||
"diffusers.models.autoencoder_kl",
|
||||
"diffusers.models.cross_attention",
|
||||
"diffusers.models.embeddings",
|
||||
"diffusers.models.modeling_utils",
|
||||
"diffusers.models.unet_2d_blocks",
|
||||
"diffusers.models.vae",
|
||||
"diffusers.utils",
|
||||
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
|
||||
"diffusers.pipelines.onnx_utils",
|
||||
"diffusers.pipelines.paint_by_example",
|
||||
"diffusers.pipelines.stable_diffusion",
|
||||
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
|
||||
"diffusers.pipeline_utils",
|
||||
"diffusers.schedulers",
|
||||
"diffusers.utils.logging",
|
||||
"facexlib.utils",
|
||||
"facexlib",
|
||||
"gfpgan",
|
||||
"gi.repository",
|
||||
"huggingface_hub",
|
||||
"huggingface_hub.file_download",
|
||||
"huggingface_hub.utils.tqdm",
|
||||
"mediapipe",
|
||||
"onnxruntime",
|
||||
"onnxruntime.transformers.float16",
|
||||
"piexif",
|
||||
"piexif.helper",
|
||||
"realesrgan",
|
||||
"realesrgan.archs.srvgg_arch",
|
||||
"safetensors",
|
||||
"transformers"
|
||||
"timm.models.layers",
|
||||
"transformers",
|
||||
"win10toast"
|
||||
]
|
||||
ignore_missing_imports = true
|
Loading…
Reference in New Issue