lint(api): type fixes and hints throughout
This commit is contained in:
parent
5d13629ee8
commit
6ec7777f77
|
@ -1,4 +1,4 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
from .base import ChainPipeline, PipelineStage, StageParams
|
||||||
from .blend_img2img import BlendImg2ImgStage
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
from .blend_inpaint import BlendInpaintStage
|
from .blend_inpaint import BlendInpaintStage
|
||||||
from .blend_linear import BlendLinearStage
|
from .blend_linear import BlendLinearStage
|
||||||
|
|
|
@ -16,26 +16,6 @@ from .tile import needs_tile, process_tile_order
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StageCallback(Protocol):
|
|
||||||
"""
|
|
||||||
Definition for a stage job function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
job: WorkerContext,
|
|
||||||
server: ServerContext,
|
|
||||||
stage: StageParams,
|
|
||||||
params: ImageParams,
|
|
||||||
source: Image.Image,
|
|
||||||
**kwargs: Any
|
|
||||||
) -> Image.Image:
|
|
||||||
"""
|
|
||||||
Run this stage against a source image.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,7 +57,7 @@ class ChainPipeline:
|
||||||
"""
|
"""
|
||||||
self.stages = list(stages or [])
|
self.stages = list(stages or [])
|
||||||
|
|
||||||
def append(self, stage: PipelineStage):
|
def append(self, stage: Optional[PipelineStage]):
|
||||||
"""
|
"""
|
||||||
DEPRECATED: use `stage` instead
|
DEPRECATED: use `stage` instead
|
||||||
|
|
||||||
|
@ -100,7 +80,7 @@ class ChainPipeline:
|
||||||
"""
|
"""
|
||||||
return self(job, server, params, source=source, callback=callback, **kwargs)
|
return self(job, server, params, source=source, callback=callback, **kwargs)
|
||||||
|
|
||||||
def stage(self, callback: StageCallback, params: StageParams, **kwargs):
|
def stage(self, callback: BaseStage, params: StageParams, **kwargs):
|
||||||
self.stages.append((callback, params, kwargs))
|
self.stages.append((callback, params, kwargs))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -42,8 +42,8 @@ def stage_upscale_correction(
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
chain: Optional[ChainPipeline] = None,
|
chain: Optional[ChainPipeline] = None,
|
||||||
pre_stages: List[PipelineStage] = None,
|
pre_stages: Optional[List[PipelineStage]] = None,
|
||||||
post_stages: List[PipelineStage] = None,
|
post_stages: Optional[List[PipelineStage]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ChainPipeline:
|
) -> ChainPipeline:
|
||||||
"""
|
"""
|
||||||
|
@ -60,14 +60,14 @@ def stage_upscale_correction(
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
|
|
||||||
if pre_stages is not None:
|
if pre_stages is not None:
|
||||||
for stage, pre_params, pre_opts in pre_stages:
|
for pre_stage in pre_stages:
|
||||||
chain.append((stage, pre_params, pre_opts))
|
chain.append(pre_stage)
|
||||||
|
|
||||||
upscale_opts = {
|
upscale_opts = {
|
||||||
**kwargs,
|
**kwargs,
|
||||||
"upscale": upscale,
|
"upscale": upscale,
|
||||||
}
|
}
|
||||||
upscale_stage = None
|
upscale_stage: Optional[PipelineStage] = None
|
||||||
if upscale.scale > 1:
|
if upscale.scale > 1:
|
||||||
if "bsrgan" in upscale.upscale_model:
|
if "bsrgan" in upscale.upscale_model:
|
||||||
bsrgan_params = StageParams(
|
bsrgan_params = StageParams(
|
||||||
|
@ -94,12 +94,14 @@ def stage_upscale_correction(
|
||||||
else:
|
else:
|
||||||
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
logger.warn("unknown upscaling model: %s", upscale.upscale_model)
|
||||||
|
|
||||||
correct_stage = None
|
correct_stage: Optional[PipelineStage] = None
|
||||||
if upscale.faces:
|
if upscale.faces:
|
||||||
face_params = StageParams(
|
face_params = StageParams(
|
||||||
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
tile_size=stage.tile_size, outscale=upscale.face_outscale
|
||||||
)
|
)
|
||||||
if "codeformer" in upscale.correction_model:
|
if upscale.correction_model is None:
|
||||||
|
logger.warn("no correction model set, skipping")
|
||||||
|
elif "codeformer" in upscale.correction_model:
|
||||||
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
|
correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts)
|
||||||
elif "gfpgan" in upscale.correction_model:
|
elif "gfpgan" in upscale.correction_model:
|
||||||
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
|
correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts)
|
||||||
|
@ -120,7 +122,7 @@ def stage_upscale_correction(
|
||||||
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
|
logger.warn("unknown upscaling order: %s", upscale.upscale_order)
|
||||||
|
|
||||||
if post_stages is not None:
|
if post_stages is not None:
|
||||||
for stage, post_params, post_opts in post_stages:
|
for post_stage in post_stages:
|
||||||
chain.append((stage, post_params, post_opts))
|
chain.append(post_stage)
|
||||||
|
|
||||||
return chain
|
return chain
|
||||||
|
|
|
@ -21,7 +21,7 @@ class UpscaleHighresStage(BaseStage):
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
*,
|
*args,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
|
|
@ -268,7 +268,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
model_errors = []
|
model_errors = []
|
||||||
|
|
||||||
if args.sources and "sources" in models:
|
if args.sources and "sources" in models:
|
||||||
for model in models.get("sources"):
|
for model in models.get("sources", []):
|
||||||
model = tuple_to_source(model)
|
model = tuple_to_source(model)
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
|
|
||||||
|
@ -292,7 +292,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
model_errors.append(name)
|
model_errors.append(name)
|
||||||
|
|
||||||
if args.networks and "networks" in models:
|
if args.networks and "networks" in models:
|
||||||
for network in models.get("networks"):
|
for network in models.get("networks", []):
|
||||||
name = network["name"]
|
name = network["name"]
|
||||||
|
|
||||||
if name in args.skip:
|
if name in args.skip:
|
||||||
|
@ -316,6 +316,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
conversion,
|
conversion,
|
||||||
network,
|
network,
|
||||||
dest,
|
dest,
|
||||||
|
path.join(conversion.model_path, network_type, name),
|
||||||
)
|
)
|
||||||
if network_type == "inversion" and network_model == "concept":
|
if network_type == "inversion" and network_model == "concept":
|
||||||
dest, hf = fetch_model(
|
dest, hf = fetch_model(
|
||||||
|
@ -342,7 +343,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
model_errors.append(name)
|
model_errors.append(name)
|
||||||
|
|
||||||
if args.diffusion and "diffusion" in models:
|
if args.diffusion and "diffusion" in models:
|
||||||
for model in models.get("diffusion"):
|
for model in models.get("diffusion", []):
|
||||||
model = tuple_to_diffusion(model)
|
model = tuple_to_diffusion(model)
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
|
|
||||||
|
@ -483,7 +484,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
model_errors.append(name)
|
model_errors.append(name)
|
||||||
|
|
||||||
if args.upscaling and "upscaling" in models:
|
if args.upscaling and "upscaling" in models:
|
||||||
for model in models.get("upscaling"):
|
for model in models.get("upscaling", []):
|
||||||
model = tuple_to_upscaling(model)
|
model = tuple_to_upscaling(model)
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
|
|
||||||
|
@ -516,7 +517,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
model_errors.append(name)
|
model_errors.append(name)
|
||||||
|
|
||||||
if args.correction and "correction" in models:
|
if args.correction and "correction" in models:
|
||||||
for model in models.get("correction"):
|
for model in models.get("correction", []):
|
||||||
model = tuple_to_correction(model)
|
model = tuple_to_correction(model)
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -17,16 +17,15 @@ def convert_diffusion_control(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
source: str,
|
||||||
model_path: str,
|
|
||||||
output_path: str,
|
output_path: str,
|
||||||
opset: int,
|
attention_slicing: Optional[str] = None,
|
||||||
attention_slicing: str,
|
|
||||||
):
|
):
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
source = source or model.get("source")
|
source = source or model.get("source")
|
||||||
|
|
||||||
device = conversion.training_device
|
device = conversion.training_device
|
||||||
dtype = conversion.torch_dtype()
|
dtype = conversion.torch_dtype()
|
||||||
|
opset = conversion.opset
|
||||||
logger.debug("using Torch dtype %s for ControlNet", dtype)
|
logger.debug("using Torch dtype %s for ControlNet", dtype)
|
||||||
|
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
@ -35,7 +34,7 @@ def convert_diffusion_control(
|
||||||
logger.info("ONNX model already exists, skipping")
|
logger.info("ONNX model already exists, skipping")
|
||||||
return
|
return
|
||||||
|
|
||||||
controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=dtype)
|
controlnet = ControlNetModel.from_pretrained(source, torch_dtype=dtype)
|
||||||
if attention_slicing is not None:
|
if attention_slicing is not None:
|
||||||
logger.info("enabling attention slicing for ControlNet")
|
logger.info("enabling attention slicing for ControlNet")
|
||||||
controlnet.set_attention_slice(attention_slicing)
|
controlnet.set_attention_slice(attention_slicing)
|
||||||
|
|
|
@ -262,7 +262,7 @@ def convert_diffusion_diffusers(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
source: str,
|
||||||
format: str,
|
format: Optional[str],
|
||||||
hf: bool = False,
|
hf: bool = False,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -43,6 +43,8 @@ class PromptPhrase:
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
return other.tokens == self.tokens and other.weight == self.weight
|
return other.tokens == self.tokens and other.weight == self.weight
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class OnnxPromptVisitor(PTNodeVisitor):
|
class OnnxPromptVisitor(PTNodeVisitor):
|
||||||
def __init__(self, defaults=True, weight=0.5, **kwargs):
|
def __init__(self, defaults=True, weight=0.5, **kwargs):
|
||||||
|
|
|
@ -15,7 +15,7 @@ def parse_prompt_compel(pipeline, prompt: str) -> np.ndarray:
|
||||||
|
|
||||||
|
|
||||||
def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
|
def parse_prompt_lpw(pipeline, prompt: str, debug=False) -> np.ndarray:
|
||||||
pass
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
|
def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray:
|
||||||
|
|
|
@ -6,7 +6,7 @@ from json import JSONDecodeError
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from platform import system
|
from platform import system
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
from typing import Any, Dict, List, Optional, Sequence, Union, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
@ -43,9 +43,11 @@ def get_and_clamp_int(
|
||||||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
|
TElem = TypeVar("TElem")
|
||||||
|
|
||||||
def get_from_list(
|
def get_from_list(
|
||||||
args: Any, key: str, values: Sequence[Any], default_value: Optional[Any] = None
|
args: Any, key: str, values: Sequence[TElem], default_value: Optional[TElem] = None
|
||||||
) -> Optional[Any]:
|
) -> Optional[TElem]:
|
||||||
selected = args.get(key, default_value)
|
selected = args.get(key, default_value)
|
||||||
if selected in values:
|
if selected in values:
|
||||||
return selected
|
return selected
|
||||||
|
@ -57,7 +59,7 @@ def get_from_list(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any:
|
def get_from_map(args: Any, key: str, values: Dict[str, TElem], default: TElem) -> TElem:
|
||||||
selected = args.get(key, default)
|
selected = args.get(key, default)
|
||||||
if selected in values:
|
if selected in values:
|
||||||
return values[selected]
|
return values[selected]
|
||||||
|
@ -65,7 +67,7 @@ def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> A
|
||||||
return values[default]
|
return values[default]
|
||||||
|
|
||||||
|
|
||||||
def get_not_empty(args: Any, key: str, default: Any) -> Any:
|
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
|
||||||
val = args.get(key, default)
|
val = args.get(key, default)
|
||||||
|
|
||||||
if val is None or len(val) == 0:
|
if val is None or len(val) == 0:
|
||||||
|
|
|
@ -14,6 +14,7 @@ exclude = [
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = [
|
module = [
|
||||||
|
"arpeggio",
|
||||||
"basicsr.archs.rrdbnet_arch",
|
"basicsr.archs.rrdbnet_arch",
|
||||||
"basicsr.utils.download_util",
|
"basicsr.utils.download_util",
|
||||||
"basicsr.utils",
|
"basicsr.utils",
|
||||||
|
@ -23,19 +24,45 @@ module = [
|
||||||
"codeformer.facelib.utils.misc",
|
"codeformer.facelib.utils.misc",
|
||||||
"codeformer.facelib.utils",
|
"codeformer.facelib.utils",
|
||||||
"codeformer.facelib",
|
"codeformer.facelib",
|
||||||
|
"compel",
|
||||||
|
"controlnet_aux",
|
||||||
|
"cv2",
|
||||||
"diffusers",
|
"diffusers",
|
||||||
|
"diffusers.configuration_utils",
|
||||||
|
"diffusers.loaders",
|
||||||
|
"diffusers.models.attention_processor",
|
||||||
|
"diffusers.models.autoencoder_kl",
|
||||||
|
"diffusers.models.cross_attention",
|
||||||
|
"diffusers.models.embeddings",
|
||||||
|
"diffusers.models.modeling_utils",
|
||||||
|
"diffusers.models.unet_2d_blocks",
|
||||||
|
"diffusers.models.vae",
|
||||||
|
"diffusers.utils",
|
||||||
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
|
"diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion",
|
||||||
|
"diffusers.pipelines.onnx_utils",
|
||||||
"diffusers.pipelines.paint_by_example",
|
"diffusers.pipelines.paint_by_example",
|
||||||
"diffusers.pipelines.stable_diffusion",
|
"diffusers.pipelines.stable_diffusion",
|
||||||
|
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
|
||||||
"diffusers.pipeline_utils",
|
"diffusers.pipeline_utils",
|
||||||
|
"diffusers.schedulers",
|
||||||
"diffusers.utils.logging",
|
"diffusers.utils.logging",
|
||||||
"facexlib.utils",
|
"facexlib.utils",
|
||||||
"facexlib",
|
"facexlib",
|
||||||
"gfpgan",
|
"gfpgan",
|
||||||
|
"gi.repository",
|
||||||
|
"huggingface_hub",
|
||||||
|
"huggingface_hub.file_download",
|
||||||
|
"huggingface_hub.utils.tqdm",
|
||||||
|
"mediapipe",
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
|
"onnxruntime.transformers.float16",
|
||||||
|
"piexif",
|
||||||
|
"piexif.helper",
|
||||||
"realesrgan",
|
"realesrgan",
|
||||||
"realesrgan.archs.srvgg_arch",
|
"realesrgan.archs.srvgg_arch",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
"transformers"
|
"timm.models.layers",
|
||||||
|
"transformers",
|
||||||
|
"win10toast"
|
||||||
]
|
]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
Loading…
Reference in New Issue