1
0
Fork 0

lint(api): type fixes and hints throughout

This commit is contained in:
Sean Sube 2023-07-04 10:20:28 -05:00
parent 5d13629ee8
commit 6ec7777f77
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 64 additions and 51 deletions

View File

@@ -1,4 +1,4 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .base import ChainPipeline, PipelineStage, StageParams
from .blend_img2img import BlendImg2ImgStage
from .blend_inpaint import BlendInpaintStage
from .blend_linear import BlendLinearStage

View File

@@ -16,26 +16,6 @@ from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
class StageCallback(Protocol):
"""
Definition for a stage job function.
"""
def __call__(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
**kwargs: Any
) -> Image.Image:
"""
Run this stage against a source image.
"""
pass
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
@@ -77,7 +57,7 @@ class ChainPipeline:
"""
self.stages = list(stages or [])
def append(self, stage: PipelineStage):
def append(self, stage: Optional[PipelineStage]):
"""
DEPRECATED: use `stage` instead
@@ -100,7 +80,7 @@ class ChainPipeline:
"""
return self(job, server, params, source=source, callback=callback, **kwargs)
def stage(self, callback: StageCallback, params: StageParams, **kwargs):
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs))
return self

View File

@@ -42,8 +42,8 @@ def stage_upscale_correction(
*,
upscale: UpscaleParams,
chain: Optional[ChainPipeline] = None,
pre_stages: List[PipelineStage] = None,
post_stages: List[PipelineStage] = None,
pre_stages: Optional[List[PipelineStage]] = None,
post_stages: Optional[List[PipelineStage]] = None,
**kwargs,
) -> ChainPipeline:
"""
@@ -60,14 +60,14 @@ def stage_upscale_correction(
chain = ChainPipeline()
if pre_stages is not None:
for stage, pre_params, pre_opts in pre_stages:
chain.append((stage, pre_params, pre_opts))
for pre_stage in pre_stages:
chain.append(pre_stage)
upscale_opts = {
**kwargs,
"upscale": upscale,
}
upscale_stage = None
upscale_stage: Optional[PipelineStage] = None
if upscale.scale > 1:
if "bsrgan" in upscale.upscale_model:
bsrgan_params = StageParams(
@@ -94,12 +94,14 @@ def stage_upscale_correction(
else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
correct_stage = None
correct_stage: Optional[PipelineStage] = None
if upscale.faces:
face_params = StageParams(
tile_size=stage.tile_size, outscale=upscale.face_outscale
)
if "codeformer" in upscale.correction_model:
if upscale.correction_model is None:
logger.warn("no correction model set, skipping")
elif "codeformer" in upscale.correction_model:
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
elif "gfpgan" in upscale.correction_model:
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
@@ -120,7 +122,7 @@ def stage_upscale_correction(
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
if post_stages is not None:
for stage, post_params, post_opts in post_stages:
chain.append((stage, post_params, post_opts))
for post_stage in post_stages:
chain.append(post_stage)
return chain

View File

@@ -21,7 +21,7 @@ class UpscaleHighresStage(BaseStage):
stage: StageParams,
params: ImageParams,
source: Image.Image,
*,
*args,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,

View File

@@ -268,7 +268,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors = []
if args.sources and "sources" in models:
for model in models.get("sources"):
for model in models.get("sources", []):
model = tuple_to_source(model)
name = model.get("name")
@@ -292,7 +292,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name)
if args.networks and "networks" in models:
for network in models.get("networks"):
for network in models.get("networks", []):
name = network["name"]
if name in args.skip:
@@ -316,6 +316,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
conversion,
network,
dest,
path.join(conversion.model_path, network_type, name),
)
if network_type == "inversion" and network_model == "concept":
dest, hf = fetch_model(
@@ -342,7 +343,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name)
if args.diffusion and "diffusion" in models:
for model in models.get("diffusion"):
for model in models.get("diffusion", []):
model = tuple_to_diffusion(model)
name = model.get("name")
@@ -483,7 +484,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name)
if args.upscaling and "upscaling" in models:
for model in models.get("upscaling"):
for model in models.get("upscaling", []):
model = tuple_to_upscaling(model)
name = model.get("name")
@@ -516,7 +517,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name)
if args.correction and "correction" in models:
for model in models.get("correction"):
for model in models.get("correction", []):
model = tuple_to_correction(model)
name = model.get("name")

View File

@@ -1,7 +1,7 @@
from logging import getLogger
from os import path
from pathlib import Path
from typing import Dict
from typing import Dict, Optional
import torch
@@ -17,16 +17,15 @@ def convert_diffusion_control(
conversion: ConversionContext,
model: Dict,
source: str,
model_path: str,
output_path: str,
opset: int,
attention_slicing: str,
attention_slicing: Optional[str] = None,
):
name = model.get("name")
source = source or model.get("source")
device = conversion.training_device
dtype = conversion.torch_dtype()
opset = conversion.opset
logger.debug("using Torch dtype %s for ControlNet", dtype)
output_path = Path(output_path)
@@ -35,7 +34,7 @@ def convert_diffusion_control(
logger.info("ONNX model already exists, skipping")
return
controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=dtype)
controlnet = ControlNetModel.from_pretrained(source, torch_dtype=dtype)
if attention_slicing is not None:
logger.info("enabling attention slicing for ControlNet")
controlnet.set_attention_slice(attention_slicing)

View File

@@ -262,7 +262,7 @@ def convert_diffusion_diffusers(
conversion: ConversionContext,
model: Dict,
source: str,
format: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""

View File

@@ -43,6 +43,8 @@ class PromptPhrase:
if isinstance(other, self.__class__):
return other.tokens == self.tokens and other.weight == self.weight
return False
class OnnxPromptVisitor(PTNodeVisitor):
def __init__(self, defaults=True, weight=0.5, **kwargs):

View File

@@ -15,7 +15,7 @@ def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray:
def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
pass
raise NotImplementedError()
def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:

View File

@@ -6,7 +6,7 @@ from json import JSONDecodeError
from logging import getLogger
from os import environ, path
from platform import system
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union, TypeVar
import torch
from yaml import safe_load
@@ -43,9 +43,11 @@ def get_and_clamp_int(
return min(max(int(args.get(key, default_value)), min_value), max_value)
TElem = TypeVar("TElem")
def get_from_list(
args: Any, key: str, values: Sequence[Any], default_value: Optional[Any] = None
) -> Optional[Any]:
args: Any, key: str, values: Sequence[TElem], default_value: Optional[TElem] = None
) -> Optional[TElem]:
selected = args.get(key, default_value)
if selected in values:
return selected
@@ -57,7 +59,7 @@ def get_from_list(
return None
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any:
def get_from_map(args: Any, key: str, values: Dict[str, TElem], default: TElem) -> TElem:
selected = args.get(key, default)
if selected in values:
return values[selected]
@@ -65,7 +67,7 @@ def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> A
return values[default]
def get_not_empty(args: Any, key: str, default: Any) -> Any:
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
val = args.get(key, default)
if val is None or len(val) == 0:

View File

@@ -14,6 +14,7 @@ exclude = [
[[tool.mypy.overrides]]
module = [
"arpeggio",
"basicsr.archs.rrdbnet_arch",
"basicsr.utils.download_util",
"basicsr.utils",
@@ -23,19 +24,45 @@ module = [
"codeformer.facelib.utils.misc",
"codeformer.facelib.utils",
"codeformer.facelib",
"compel",
"controlnet_aux",
"cv2",
"diffusers",
"diffusers.configuration_utils",
"diffusers.loaders",
"diffusers.models.attention_processor",
"diffusers.models.autoencoder_kl",
"diffusers.models.cross_attention",
"diffusers.models.embeddings",
"diffusers.models.modeling_utils",
"diffusers.models.unet_2d_blocks",
"diffusers.models.vae",
"diffusers.utils",
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
"diffusers.pipelines.onnx_utils",
"diffusers.pipelines.paint_by_example",
"diffusers.pipelines.stable_diffusion",
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
"diffusers.pipeline_utils",
"diffusers.schedulers",
"diffusers.utils.logging",
"facexlib.utils",
"facexlib",
"gfpgan",
"gi.repository",
"huggingface_hub",
"huggingface_hub.file_download",
"huggingface_hub.utils.tqdm",
"mediapipe",
"onnxruntime",
"onnxruntime.transformers.float16",
"piexif",
"piexif.helper",
"realesrgan",
"realesrgan.archs.srvgg_arch",
"safetensors",
"transformers"
"timm.models.layers",
"transformers",
"win10toast"
]
ignore_missing_imports = true