lint and test fixes

parent b398d65624
commit 9b4ae0916b
@@ -18,7 +18,7 @@ class BaseStage:
         _stage: StageParams,
         _params: ImageParams,
         _sources: StageResult,
-        *args,
+        *,
         stage_source: Optional[Image.Image] = None,
         **kwargs,
     ) -> StageResult:
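Note: replacing *args with a bare * makes the remaining parameters keyword-only, so stray positional arguments fail loudly instead of being silently swallowed; the same change appears in UpscaleHighresStage below. A minimal standalone sketch (names hypothetical, not from this repo):

    def run(stage, *, stage_source=None):
        # everything after the bare * must be passed by keyword
        return stage_source

    run("stage", stage_source="img")  # accepted
    # run("stage", "img")  # TypeError: takes 1 positional argument but 2 were given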
@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import Optional
+from typing import List, Optional

 from PIL import Image

@@ -28,7 +28,7 @@ class BlendGridStage(BaseStage):
         # rows: Optional[List[str]] = None,
         # columns: Optional[List[str]] = None,
         # title: Optional[str] = None,
-        order: Optional[int] = None,
+        order: Optional[List[int]] = None,
         stage_source: Optional[Image.Image] = None,
         callback: Optional[ProgressCallback] = None,
         **kwargs,
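Note: order holds a per-image ordering for the grid, so the old element-typed annotation was wrong. A hypothetical illustration (not this repo's code) of what the corrected type lets a checker verify:

    from typing import List, Optional

    def reorder(images: List[str], order: Optional[List[int]] = None) -> List[str]:
        # keep the original order when none is given
        if order is None:
            return list(images)
        # with order annotated as Optional[int], iterating it here would be a type error
        return [images[i] for i in order]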
@@ -3,7 +3,6 @@ from typing import Optional, Tuple
 import numpy as np
-import torch
 from PIL import Image

 from ..constants import LATENT_FACTOR
 from ..diffusers.load import load_pipeline
@@ -41,7 +40,7 @@ class SourceTxt2ImgStage(BaseStage):
         latents: Optional[np.ndarray] = None,
         prompt_index: Optional[int] = None,
         **kwargs,
-    ) -> Image.Image:
+    ) -> StageResult:
         params = params.with_args(**kwargs)
         size = size.with_args(**kwargs)
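Note: the stage now advertises that it returns a StageResult container rather than a bare Image.Image, matching what it actually produces. The real StageResult API is not shown in this diff; a minimal sketch of the pattern, with a hypothetical stand-in class:

    from typing import List
    from PIL import Image

    class StageResult:  # hypothetical stand-in for the real container
        def __init__(self, images: List[Image.Image]):
            self.images = images

    def run_stage() -> StageResult:
        # return the container, matching the annotation
        return StageResult([Image.new("RGB", (64, 64))])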
@@ -22,7 +22,7 @@ class UpscaleHighresStage(BaseStage):
         stage: StageParams,
         params: ImageParams,
         sources: StageResult,
-        *args,
+        *,
         highres: HighresParams,
         upscale: UpscaleParams,
         stage_source: Optional[Image.Image] = None,
@@ -103,7 +103,6 @@ def get_model_version(
         opts["prediction_type"] = "epsilon"
     except Exception:
         logger.debug("unable to load tensor for version check")
-        pass

     return (v2, opts)
@@ -76,7 +76,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
     names = [fix_node_name(node.name) for node in nodes]

     for key, value in keys.items():
-        root, *rest = key.split(".")
+        root, *_rest = key.split(".")
         logger.trace("fixing XL node name: %s -> %s", key, root)

         simple = False
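Note: the leading underscore marks unpacked values that are intentionally unused, which satisfies unused-variable lints; the same convention appears below in tuple_to_source and in parameters such as _server and _stage. In isolation:

    # a leading underscore signals "intentionally unused" to linters and readers
    key = "unet.down_blocks.0.attentions.0"
    root, *_rest = key.split(".")
    print(root)  # unet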
@@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
 class ConversionContext(ServerContext):
     def __init__(
         self,
-        model_path: Optional[str] = None,
+        model_path: str = ".",
         cache_path: Optional[str] = None,
         device: Optional[str] = None,
         half: bool = False,
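Note: giving model_path a concrete default instead of Optional[str] = None removes None guards from every consumer. A sketch of the pattern with a hypothetical class, not the real ConversionContext:

    from os import path

    class Context:
        def __init__(self, model_path: str = ".") -> None:
            # model_path is always a str, so no None check is needed
            self.model_path = path.abspath(model_path)

    print(Context().model_path)  # absolute path of the current directory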
@@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):

 def tuple_to_source(model: Union[ModelDict, LegacyModel]):
     if isinstance(model, list) or isinstance(model, tuple):
-        name, source, *rest = model
+        name, source, *_rest = model

         return {
             "name": name,
@@ -505,7 +505,7 @@ def load_unet(


 def load_vae(
-    server: ServerContext, device: DeviceParams, model: str, params: ImageParams
+    _server: ServerContext, device: DeviceParams, model: str, params: ImageParams
 ):
     # one or more VAE models need to be loaded
     vae = path.join(model, "vae", ONNX_MODEL)
@@ -28,9 +28,9 @@ class UNetWrapper(object):

     def __call__(
         self,
-        sample: np.ndarray = None,
-        timestep: np.ndarray = None,
-        encoder_hidden_states: np.ndarray = None,
+        sample: Optional[np.ndarray] = None,
+        timestep: Optional[np.ndarray] = None,
+        encoder_hidden_states: Optional[np.ndarray] = None,
         **kwargs,
     ):
         logger.trace(
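Note: np.ndarray = None is an implicit Optional, which recent mypy rejects by default (no_implicit_optional); json_params below gets the same fix for its parent argument. The pattern in isolation:

    from typing import Optional
    import numpy as np

    # flagged by mypy: the annotation says ndarray, but the default is None
    def implicit(sample: np.ndarray = None) -> None: ...

    # explicit Optional states the actual contract
    def explicit(sample: Optional[np.ndarray] = None) -> None: ...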
@@ -1,6 +1,6 @@
 from typing import List, Literal

-NetworkType = Literal["inversion", "lora"]
+NetworkType = Literal["control", "inversion", "lora"]


 class NetworkModel:
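Note: adding "control" to the Literal lets ControlNet models pass wherever a NetworkType is expected. A hedged sketch of how the Literal constrains callers (the function is illustrative only):

    from typing import Literal

    NetworkType = Literal["control", "inversion", "lora"]

    def network_folder(kind: NetworkType) -> str:
        # mypy accepts only the three listed strings for kind
        return kind + "s"

    network_folder("control")    # accepted
    # network_folder("upscale")  # rejected by the type checker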
@@ -57,7 +57,7 @@ def json_params(
     upscale: Optional[UpscaleParams] = None,
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
-    parent: Dict = None,
+    parent: Optional[Dict] = None,
 ) -> Any:
     json = {
         "input_size": size.tojson(),
@@ -163,8 +163,8 @@ def load_extras(server: ServerContext):
     global extra_strings
     global extra_tokens

-    labels = {}
-    strings = {}
+    labels: Dict[str, str] = {}
+    strings: Dict[str, Any] = {}

     extra_schema = load_config("./schemas/extras.yaml")
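Note: mypy cannot infer element types for a bare {} and asks for an annotation under strict settings; the options dict in load_platforms below gets the same treatment with a Union value type. In isolation:

    from typing import Any, Dict, Union

    # without the annotation, strict mypy reports: need type annotation for "labels"
    labels: Dict[str, str] = {}
    labels["model.name"] = "Model Name"  # keys and values are checked as str

    strings: Dict[str, Any] = {}
    options: Dict[str, Union[int, str]] = {"device_id": 0}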
@@ -415,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
     ):
         if potential == "cuda" or potential == "rocm":
             for i in range(torch.cuda.device_count()):
-                options = {
+                options: Dict[str, Union[int, str]] = {
                     "device_id": i,
                 }
@@ -71,6 +71,7 @@ module = [
     "realesrgan",
+    "realesrgan.archs.srvgg_arch",
     "safetensors",
     "scipy",
     "timm.models.layers",
     "transformers",
     "win10toast"
@@ -274,6 +274,7 @@ class TestInpaintPipeline(unittest.TestCase):
                 3.0,
                 1,
                 1,
+                unet_tile=64,
             ),
             Size(*source.size),
             ["test-inpaint-white.png"],
@@ -310,6 +311,7 @@ class TestInpaintPipeline(unittest.TestCase):
                 3.0,
                 1,
                 1,
+                unet_tile=64,
             ),
             Size(*source.size),
             ["test-inpaint-black.png"],