1
0
Fork 0

lint and test fixes

This commit is contained in:
Sean Sube 2023-12-03 12:53:50 -06:00
parent b398d65624
commit 9b4ae0916b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
14 changed files with 20 additions and 19 deletions

View File

@ -18,7 +18,7 @@ class BaseStage:
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
_sources: StageResult, _sources: StageResult,
*args, *,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:

View File

@ -1,5 +1,5 @@
from logging import getLogger from logging import getLogger
from typing import Optional from typing import List, Optional
from PIL import Image from PIL import Image
@ -28,7 +28,7 @@ class BlendGridStage(BaseStage):
# rows: Optional[List[str]] = None, # rows: Optional[List[str]] = None,
# columns: Optional[List[str]] = None, # columns: Optional[List[str]] = None,
# title: Optional[str] = None, # title: Optional[str] = None,
order: Optional[int] = None, order: Optional[List[int]] = None,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**kwargs, **kwargs,

View File

@ -3,7 +3,6 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image
from ..constants import LATENT_FACTOR from ..constants import LATENT_FACTOR
from ..diffusers.load import load_pipeline from ..diffusers.load import load_pipeline
@ -41,7 +40,7 @@ class SourceTxt2ImgStage(BaseStage):
latents: Optional[np.ndarray] = None, latents: Optional[np.ndarray] = None,
prompt_index: Optional[int] = None, prompt_index: Optional[int] = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> StageResult:
params = params.with_args(**kwargs) params = params.with_args(**kwargs)
size = size.with_args(**kwargs) size = size.with_args(**kwargs)

View File

@ -22,7 +22,7 @@ class UpscaleHighresStage(BaseStage):
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
sources: StageResult, sources: StageResult,
*args, *,
highres: HighresParams, highres: HighresParams,
upscale: UpscaleParams, upscale: UpscaleParams,
stage_source: Optional[Image.Image] = None, stage_source: Optional[Image.Image] = None,

View File

@ -103,7 +103,6 @@ def get_model_version(
opts["prediction_type"] = "epsilon" opts["prediction_type"] = "epsilon"
except Exception: except Exception:
logger.debug("unable to load tensor for version check") logger.debug("unable to load tensor for version check")
pass
return (v2, opts) return (v2, opts)

View File

@ -76,7 +76,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
names = [fix_node_name(node.name) for node in nodes] names = [fix_node_name(node.name) for node in nodes]
for key, value in keys.items(): for key, value in keys.items():
root, *rest = key.split(".") root, *_rest = key.split(".")
logger.trace("fixing XL node name: %s -> %s", key, root) logger.trace("fixing XL node name: %s -> %s", key, root)
simple = False simple = False

View File

@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
class ConversionContext(ServerContext): class ConversionContext(ServerContext):
def __init__( def __init__(
self, self,
model_path: Optional[str] = None, model_path: str = ".",
cache_path: Optional[str] = None, cache_path: Optional[str] = None,
device: Optional[str] = None, device: Optional[str] = None,
half: bool = False, half: bool = False,
@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):
def tuple_to_source(model: Union[ModelDict, LegacyModel]): def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *_rest = model
return { return {
"name": name, "name": name,

View File

@ -505,7 +505,7 @@ def load_unet(
def load_vae( def load_vae(
server: ServerContext, device: DeviceParams, model: str, params: ImageParams _server: ServerContext, device: DeviceParams, model: str, params: ImageParams
): ):
# one or more VAE models need to be loaded # one or more VAE models need to be loaded
vae = path.join(model, "vae", ONNX_MODEL) vae = path.join(model, "vae", ONNX_MODEL)

View File

@ -28,9 +28,9 @@ class UNetWrapper(object):
def __call__( def __call__(
self, self,
sample: np.ndarray = None, sample: Optional[np.ndarray] = None,
timestep: np.ndarray = None, timestep: Optional[np.ndarray] = None,
encoder_hidden_states: np.ndarray = None, encoder_hidden_states: Optional[np.ndarray] = None,
**kwargs, **kwargs,
): ):
logger.trace( logger.trace(

View File

@ -1,6 +1,6 @@
from typing import List, Literal from typing import List, Literal
NetworkType = Literal["inversion", "lora"] NetworkType = Literal["control", "inversion", "lora"]
class NetworkModel: class NetworkModel:

View File

@ -57,7 +57,7 @@ def json_params(
upscale: Optional[UpscaleParams] = None, upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None, border: Optional[Border] = None,
highres: Optional[HighresParams] = None, highres: Optional[HighresParams] = None,
parent: Dict = None, parent: Optional[Dict] = None,
) -> Any: ) -> Any:
json = { json = {
"input_size": size.tojson(), "input_size": size.tojson(),

View File

@ -163,8 +163,8 @@ def load_extras(server: ServerContext):
global extra_strings global extra_strings
global extra_tokens global extra_tokens
labels = {} labels: Dict[str, str] = {}
strings = {} strings: Dict[str, Any] = {}
extra_schema = load_config("./schemas/extras.yaml") extra_schema = load_config("./schemas/extras.yaml")
@ -415,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
): ):
if potential == "cuda" or potential == "rocm": if potential == "cuda" or potential == "rocm":
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
options = { options: Dict[str, Union[int, str]] = {
"device_id": i, "device_id": i,
} }

View File

@ -71,6 +71,7 @@ module = [
"realesrgan", "realesrgan",
"realesrgan.archs.srvgg_arch", "realesrgan.archs.srvgg_arch",
"safetensors", "safetensors",
"scipy",
"timm.models.layers", "timm.models.layers",
"transformers", "transformers",
"win10toast" "win10toast"

View File

@ -274,6 +274,7 @@ class TestInpaintPipeline(unittest.TestCase):
3.0, 3.0,
1, 1,
1, 1,
unet_tile=64,
), ),
Size(*source.size), Size(*source.size),
["test-inpaint-white.png"], ["test-inpaint-white.png"],
@ -310,6 +311,7 @@ class TestInpaintPipeline(unittest.TestCase):
3.0, 3.0,
1, 1,
1, 1,
unet_tile=64,
), ),
Size(*source.size), Size(*source.size),
["test-inpaint-black.png"], ["test-inpaint-black.png"],