1
0
Fork 0

lint(api): type fixes and hints throughout

This commit is contained in:
Sean Sube 2023-07-04 10:20:28 -05:00
parent 5d13629ee8
commit 6ec7777f77
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
11 changed files with 64 additions and 51 deletions

View File

@ -1,4 +1,4 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams from .base import ChainPipeline, PipelineStage, StageParams
from .blend_img2img import BlendImg2ImgStage from .blend_img2img import BlendImg2ImgStage
from .blend_inpaint import BlendInpaintStage from .blend_inpaint import BlendInpaintStage
from .blend_linear import BlendLinearStage from .blend_linear import BlendLinearStage

View File

@ -16,26 +16,6 @@ from .tile import needs_tile, process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)
class StageCallback(Protocol):
"""
Definition for a stage job function.
"""
def __call__(
self,
job: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source: Image.Image,
**kwargs: Any
) -> Image.Image:
"""
Run this stage against a source image.
"""
pass
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
@ -77,7 +57,7 @@ class ChainPipeline:
""" """
self.stages = list(stages or []) self.stages = list(stages or [])
def append(self, stage: PipelineStage): def append(self, stage: Optional[PipelineStage]):
""" """
DEPRECATED: use `stage` instead DEPRECATED: use `stage` instead
@ -100,7 +80,7 @@ class ChainPipeline:
""" """
return self(job, server, params, source=source, callback=callback, **kwargs) return self(job, server, params, source=source, callback=callback, **kwargs)
def stage(self, callback: StageCallback, params: StageParams, **kwargs): def stage(self, callback: BaseStage, params: StageParams, **kwargs):
self.stages.append((callback, params, kwargs)) self.stages.append((callback, params, kwargs))
return self return self

View File

@ -42,8 +42,8 @@ def stage_upscale_correction(
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
chain: Optional[ChainPipeline] = None, chain: Optional[ChainPipeline] = None,
pre_stages: List[PipelineStage] = None, pre_stages: Optional[List[PipelineStage]] = None,
post_stages: List[PipelineStage] = None, post_stages: Optional[List[PipelineStage]] = None,
**kwargs, **kwargs,
) -> ChainPipeline: ) -> ChainPipeline:
""" """
@ -60,14 +60,14 @@ def stage_upscale_correction(
chain = ChainPipeline() chain = ChainPipeline()
if pre_stages is not None: if pre_stages is not None:
for stage, pre_params, pre_opts in pre_stages: for pre_stage in pre_stages:
chain.append((stage, pre_params, pre_opts)) chain.append(pre_stage)
upscale_opts = { upscale_opts = {
**kwargs, **kwargs,
"upscale": upscale, "upscale": upscale,
} }
upscale_stage = None upscale_stage: Optional[PipelineStage] = None
if upscale.scale > 1: if upscale.scale > 1:
if "bsrgan" in upscale.upscale_model: if "bsrgan" in upscale.upscale_model:
bsrgan_params = StageParams( bsrgan_params = StageParams(
@ -94,12 +94,14 @@ def stage_upscale_correction(
else: else:
logger.warn("unknown upscaling model: %s", upscale.upscale_model) logger.warn("unknown upscaling model: %s", upscale.upscale_model)
correct_stage = None correct_stage: Optional[PipelineStage] = None
if upscale.faces: if upscale.faces:
face_params = StageParams( face_params = StageParams(
tile_size=stage.tile_size, outscale=upscale.face_outscale tile_size=stage.tile_size, outscale=upscale.face_outscale
) )
if "codeformer" in upscale.correction_model: if upscale.correction_model is None:
logger.warn("no correction model set, skipping")
elif "codeformer" in upscale.correction_model:
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts) correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
elif "gfpgan" in upscale.correction_model: elif "gfpgan" in upscale.correction_model:
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts) correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
@ -120,7 +122,7 @@ def stage_upscale_correction(
logger.warn("unknown upscaling order: %s", upscale.upscale_order) logger.warn("unknown upscaling order: %s", upscale.upscale_order)
if post_stages is not None: if post_stages is not None:
for stage, post_params, post_opts in post_stages: for post_stage in post_stages:
chain.append((stage, post_params, post_opts)) chain.append(post_stage)
return chain return chain

View File

@ -21,7 +21,7 @@ class UpscaleHighresStage(BaseStage):
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source: Image.Image, source: Image.Image,
*, *args,
highres: HighresParams, highres: HighresParams,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,

View File

@ -268,7 +268,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors = [] model_errors = []
if args.sources and "sources" in models: if args.sources and "sources" in models:
for model in models.get("sources"): for model in models.get("sources", []):
model = tuple_to_source(model) model = tuple_to_source(model)
name = model.get("name") name = model.get("name")
@ -292,7 +292,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name) model_errors.append(name)
if args.networks and "networks" in models: if args.networks and "networks" in models:
for network in models.get("networks"): for network in models.get("networks", []):
name = network["name"] name = network["name"]
if name in args.skip: if name in args.skip:
@ -316,6 +316,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
conversion, conversion,
network, network,
dest, dest,
path.join(conversion.model_path, network_type, name),
) )
if network_type == "inversion" and network_model == "concept": if network_type == "inversion" and network_model == "concept":
dest, hf = fetch_model( dest, hf = fetch_model(
@ -342,7 +343,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name) model_errors.append(name)
if args.diffusion and "diffusion" in models: if args.diffusion and "diffusion" in models:
for model in models.get("diffusion"): for model in models.get("diffusion", []):
model = tuple_to_diffusion(model) model = tuple_to_diffusion(model)
name = model.get("name") name = model.get("name")
@ -483,7 +484,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name) model_errors.append(name)
if args.upscaling and "upscaling" in models: if args.upscaling and "upscaling" in models:
for model in models.get("upscaling"): for model in models.get("upscaling", []):
model = tuple_to_upscaling(model) model = tuple_to_upscaling(model)
name = model.get("name") name = model.get("name")
@ -516,7 +517,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
model_errors.append(name) model_errors.append(name)
if args.correction and "correction" in models: if args.correction and "correction" in models:
for model in models.get("correction"): for model in models.get("correction", []):
model = tuple_to_correction(model) model = tuple_to_correction(model)
name = model.get("name") name = model.get("name")

View File

@ -1,7 +1,7 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, Optional
import torch import torch
@ -17,16 +17,15 @@ def convert_diffusion_control(
conversion: ConversionContext, conversion: ConversionContext,
model: Dict, model: Dict,
source: str, source: str,
model_path: str,
output_path: str, output_path: str,
opset: int, attention_slicing: Optional[str] = None,
attention_slicing: str,
): ):
name = model.get("name") name = model.get("name")
source = source or model.get("source") source = source or model.get("source")
device = conversion.training_device device = conversion.training_device
dtype = conversion.torch_dtype() dtype = conversion.torch_dtype()
opset = conversion.opset
logger.debug("using Torch dtype %s for ControlNet", dtype) logger.debug("using Torch dtype %s for ControlNet", dtype)
output_path = Path(output_path) output_path = Path(output_path)
@ -35,7 +34,7 @@ def convert_diffusion_control(
logger.info("ONNX model already exists, skipping") logger.info("ONNX model already exists, skipping")
return return
controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=dtype) controlnet = ControlNetModel.from_pretrained(source, torch_dtype=dtype)
if attention_slicing is not None: if attention_slicing is not None:
logger.info("enabling attention slicing for ControlNet") logger.info("enabling attention slicing for ControlNet")
controlnet.set_attention_slice(attention_slicing) controlnet.set_attention_slice(attention_slicing)

View File

@ -262,7 +262,7 @@ def convert_diffusion_diffusers(
conversion: ConversionContext, conversion: ConversionContext,
model: Dict, model: Dict,
source: str, source: str,
format: str, format: Optional[str],
hf: bool = False, hf: bool = False,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
""" """

View File

@ -43,6 +43,8 @@ class PromptPhrase:
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return other.tokens == self.tokens and other.weight == self.weight return other.tokens == self.tokens and other.weight == self.weight
return False
class OnnxPromptVisitor(PTNodeVisitor): class OnnxPromptVisitor(PTNodeVisitor):
def __init__(self, defaults=True, weight=0.5, **kwargs): def __init__(self, defaults=True, weight=0.5, **kwargs):

View File

@ -15,7 +15,7 @@ def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray:
def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray: def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
pass raise NotImplementedError()
def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray: def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:

View File

@ -6,7 +6,7 @@ from json import JSONDecodeError
from logging import getLogger from logging import getLogger
from os import environ, path from os import environ, path
from platform import system from platform import system
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Optional, Sequence, Union, TypeVar
import torch import torch
from yaml import safe_load from yaml import safe_load
@ -43,9 +43,11 @@ def get_and_clamp_int(
return min(max(int(args.get(key, default_value)), min_value), max_value) return min(max(int(args.get(key, default_value)), min_value), max_value)
TElem = TypeVar("TElem")
def get_from_list( def get_from_list(
args: Any, key: str, values: Sequence[Any], default_value: Optional[Any] = None args: Any, key: str, values: Sequence[TElem], default_value: Optional[TElem] = None
) -> Optional[Any]: ) -> Optional[TElem]:
selected = args.get(key, default_value) selected = args.get(key, default_value)
if selected in values: if selected in values:
return selected return selected
@ -57,7 +59,7 @@ def get_from_list(
return None return None
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any: def get_from_map(args: Any, key: str, values: Dict[str, TElem], default: TElem) -> TElem:
selected = args.get(key, default) selected = args.get(key, default)
if selected in values: if selected in values:
return values[selected] return values[selected]
@ -65,7 +67,7 @@ def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> A
return values[default] return values[default]
def get_not_empty(args: Any, key: str, default: Any) -> Any: def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
val = args.get(key, default) val = args.get(key, default)
if val is None or len(val) == 0: if val is None or len(val) == 0:

View File

@ -14,6 +14,7 @@ exclude = [
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = [ module = [
"arpeggio",
"basicsr.archs.rrdbnet_arch", "basicsr.archs.rrdbnet_arch",
"basicsr.utils.download_util", "basicsr.utils.download_util",
"basicsr.utils", "basicsr.utils",
@ -23,19 +24,45 @@ module = [
"codeformer.facelib.utils.misc", "codeformer.facelib.utils.misc",
"codeformer.facelib.utils", "codeformer.facelib.utils",
"codeformer.facelib", "codeformer.facelib",
"compel",
"controlnet_aux",
"cv2",
"diffusers", "diffusers",
"diffusers.configuration_utils",
"diffusers.loaders",
"diffusers.models.attention_processor",
"diffusers.models.autoencoder_kl",
"diffusers.models.cross_attention",
"diffusers.models.embeddings",
"diffusers.models.modeling_utils",
"diffusers.models.unet_2d_blocks",
"diffusers.models.vae",
"diffusers.utils",
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion", "diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
"diffusers.pipelines.onnx_utils",
"diffusers.pipelines.paint_by_example", "diffusers.pipelines.paint_by_example",
"diffusers.pipelines.stable_diffusion", "diffusers.pipelines.stable_diffusion",
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
"diffusers.pipeline_utils", "diffusers.pipeline_utils",
"diffusers.schedulers",
"diffusers.utils.logging", "diffusers.utils.logging",
"facexlib.utils", "facexlib.utils",
"facexlib", "facexlib",
"gfpgan", "gfpgan",
"gi.repository",
"huggingface_hub",
"huggingface_hub.file_download",
"huggingface_hub.utils.tqdm",
"mediapipe",
"onnxruntime", "onnxruntime",
"onnxruntime.transformers.float16",
"piexif",
"piexif.helper",
"realesrgan", "realesrgan",
"realesrgan.archs.srvgg_arch", "realesrgan.archs.srvgg_arch",
"safetensors", "safetensors",
"transformers" "timm.models.layers",
"transformers",
"win10toast"
] ]
ignore_missing_imports = true ignore_missing_imports = true