1
0
Fork 0

lint and test fixes

This commit is contained in:
Sean Sube 2023-12-03 12:53:50 -06:00
parent b398d65624
commit 9b4ae0916b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
14 changed files with 20 additions and 19 deletions

View File

@ -18,7 +18,7 @@ class BaseStage:
_stage: StageParams,
_params: ImageParams,
_sources: StageResult,
*args,
*,
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Optional
from typing import List, Optional
from PIL import Image
@ -28,7 +28,7 @@ class BlendGridStage(BaseStage):
# rows: Optional[List[str]] = None,
# columns: Optional[List[str]] = None,
# title: Optional[str] = None,
order: Optional[int] = None,
order: Optional[List[int]] = None,
stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None,
**kwargs,

View File

@ -3,7 +3,6 @@ from typing import Optional, Tuple
import numpy as np
import torch
from PIL import Image
from ..constants import LATENT_FACTOR
from ..diffusers.load import load_pipeline
@ -41,7 +40,7 @@ class SourceTxt2ImgStage(BaseStage):
latents: Optional[np.ndarray] = None,
prompt_index: Optional[int] = None,
**kwargs,
) -> Image.Image:
) -> StageResult:
params = params.with_args(**kwargs)
size = size.with_args(**kwargs)

View File

@ -22,7 +22,7 @@ class UpscaleHighresStage(BaseStage):
stage: StageParams,
params: ImageParams,
sources: StageResult,
*args,
*,
highres: HighresParams,
upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None,

View File

@ -103,7 +103,6 @@ def get_model_version(
opts["prediction_type"] = "epsilon"
except Exception:
logger.debug("unable to load tensor for version check")
pass
return (v2, opts)

View File

@ -76,7 +76,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
names = [fix_node_name(node.name) for node in nodes]
for key, value in keys.items():
root, *rest = key.split(".")
root, *_rest = key.split(".")
logger.trace("fixing XL node name: %s -> %s", key, root)
simple = False

View File

@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
class ConversionContext(ServerContext):
def __init__(
self,
model_path: Optional[str] = None,
model_path: str = ".",
cache_path: Optional[str] = None,
device: Optional[str] = None,
half: bool = False,
@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model
name, source, *_rest = model
return {
"name": name,

View File

@ -505,7 +505,7 @@ def load_unet(
def load_vae(
server: ServerContext, device: DeviceParams, model: str, params: ImageParams
_server: ServerContext, device: DeviceParams, model: str, params: ImageParams
):
# one or more VAE models need to be loaded
vae = path.join(model, "vae", ONNX_MODEL)

View File

@ -28,9 +28,9 @@ class UNetWrapper(object):
def __call__(
self,
sample: np.ndarray = None,
timestep: np.ndarray = None,
encoder_hidden_states: np.ndarray = None,
sample: Optional[np.ndarray] = None,
timestep: Optional[np.ndarray] = None,
encoder_hidden_states: Optional[np.ndarray] = None,
**kwargs,
):
logger.trace(

View File

@ -1,6 +1,6 @@
from typing import List, Literal
NetworkType = Literal["inversion", "lora"]
NetworkType = Literal["control", "inversion", "lora"]
class NetworkModel:

View File

@ -57,7 +57,7 @@ def json_params(
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
parent: Dict = None,
parent: Optional[Dict] = None,
) -> Any:
json = {
"input_size": size.tojson(),

View File

@ -163,8 +163,8 @@ def load_extras(server: ServerContext):
global extra_strings
global extra_tokens
labels = {}
strings = {}
labels: Dict[str, str] = {}
strings: Dict[str, Any] = {}
extra_schema = load_config("./schemas/extras.yaml")
@ -415,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
):
if potential == "cuda" or potential == "rocm":
for i in range(torch.cuda.device_count()):
options = {
options: Dict[str, Union[int, str]] = {
"device_id": i,
}

View File

@ -71,6 +71,7 @@ module = [
"realesrgan",
"realesrgan.archs.srvgg_arch",
"safetensors",
"scipy",
"timm.models.layers",
"transformers",
"win10toast"

View File

@ -274,6 +274,7 @@ class TestInpaintPipeline(unittest.TestCase):
3.0,
1,
1,
unet_tile=64,
),
Size(*source.size),
["test-inpaint-white.png"],
@ -310,6 +311,7 @@ class TestInpaintPipeline(unittest.TestCase):
3.0,
1,
1,
unet_tile=64,
),
Size(*source.size),
["test-inpaint-black.png"],