lint and test fixes
This commit is contained in:
parent
b398d65624
commit
9b4ae0916b
|
@ -18,7 +18,7 @@ class BaseStage:
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
_sources: StageResult,
|
_sources: StageResult,
|
||||||
*args,
|
*,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ class BlendGridStage(BaseStage):
|
||||||
# rows: Optional[List[str]] = None,
|
# rows: Optional[List[str]] = None,
|
||||||
# columns: Optional[List[str]] = None,
|
# columns: Optional[List[str]] = None,
|
||||||
# title: Optional[str] = None,
|
# title: Optional[str] = None,
|
||||||
order: Optional[int] = None,
|
order: Optional[List[int]] = None,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|
|
@ -3,7 +3,6 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..constants import LATENT_FACTOR
|
from ..constants import LATENT_FACTOR
|
||||||
from ..diffusers.load import load_pipeline
|
from ..diffusers.load import load_pipeline
|
||||||
|
@ -41,7 +40,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
latents: Optional[np.ndarray] = None,
|
latents: Optional[np.ndarray] = None,
|
||||||
prompt_index: Optional[int] = None,
|
prompt_index: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> StageResult:
|
||||||
params = params.with_args(**kwargs)
|
params = params.with_args(**kwargs)
|
||||||
size = size.with_args(**kwargs)
|
size = size.with_args(**kwargs)
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ class UpscaleHighresStage(BaseStage):
|
||||||
stage: StageParams,
|
stage: StageParams,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
*args,
|
*,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
|
|
|
@ -103,7 +103,6 @@ def get_model_version(
|
||||||
opts["prediction_type"] = "epsilon"
|
opts["prediction_type"] = "epsilon"
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("unable to load tensor for version check")
|
logger.debug("unable to load tensor for version check")
|
||||||
pass
|
|
||||||
|
|
||||||
return (v2, opts)
|
return (v2, opts)
|
||||||
|
|
||||||
|
|
|
@ -76,7 +76,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
names = [fix_node_name(node.name) for node in nodes]
|
names = [fix_node_name(node.name) for node in nodes]
|
||||||
|
|
||||||
for key, value in keys.items():
|
for key, value in keys.items():
|
||||||
root, *rest = key.split(".")
|
root, *_rest = key.split(".")
|
||||||
logger.trace("fixing XL node name: %s -> %s", key, root)
|
logger.trace("fixing XL node name: %s -> %s", key, root)
|
||||||
|
|
||||||
simple = False
|
simple = False
|
||||||
|
|
|
@ -36,7 +36,7 @@ DEFAULT_OPSET = 14
|
||||||
class ConversionContext(ServerContext):
|
class ConversionContext(ServerContext):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: Optional[str] = None,
|
model_path: str = ".",
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
half: bool = False,
|
half: bool = False,
|
||||||
|
@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]):
|
||||||
|
|
||||||
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *_rest = model
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
|
|
@ -505,7 +505,7 @@ def load_unet(
|
||||||
|
|
||||||
|
|
||||||
def load_vae(
|
def load_vae(
|
||||||
server: ServerContext, device: DeviceParams, model: str, params: ImageParams
|
_server: ServerContext, device: DeviceParams, model: str, params: ImageParams
|
||||||
):
|
):
|
||||||
# one or more VAE models need to be loaded
|
# one or more VAE models need to be loaded
|
||||||
vae = path.join(model, "vae", ONNX_MODEL)
|
vae = path.join(model, "vae", ONNX_MODEL)
|
||||||
|
|
|
@ -28,9 +28,9 @@ class UNetWrapper(object):
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
sample: np.ndarray = None,
|
sample: Optional[np.ndarray] = None,
|
||||||
timestep: np.ndarray = None,
|
timestep: Optional[np.ndarray] = None,
|
||||||
encoder_hidden_states: np.ndarray = None,
|
encoder_hidden_states: Optional[np.ndarray] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
logger.trace(
|
logger.trace(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import List, Literal
|
from typing import List, Literal
|
||||||
|
|
||||||
NetworkType = Literal["inversion", "lora"]
|
NetworkType = Literal["control", "inversion", "lora"]
|
||||||
|
|
||||||
|
|
||||||
class NetworkModel:
|
class NetworkModel:
|
||||||
|
|
|
@ -57,7 +57,7 @@ def json_params(
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
parent: Dict = None,
|
parent: Optional[Dict] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
json = {
|
json = {
|
||||||
"input_size": size.tojson(),
|
"input_size": size.tojson(),
|
||||||
|
|
|
@ -163,8 +163,8 @@ def load_extras(server: ServerContext):
|
||||||
global extra_strings
|
global extra_strings
|
||||||
global extra_tokens
|
global extra_tokens
|
||||||
|
|
||||||
labels = {}
|
labels: Dict[str, str] = {}
|
||||||
strings = {}
|
strings: Dict[str, Any] = {}
|
||||||
|
|
||||||
extra_schema = load_config("./schemas/extras.yaml")
|
extra_schema = load_config("./schemas/extras.yaml")
|
||||||
|
|
||||||
|
@ -415,7 +415,7 @@ def load_platforms(server: ServerContext) -> None:
|
||||||
):
|
):
|
||||||
if potential == "cuda" or potential == "rocm":
|
if potential == "cuda" or potential == "rocm":
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
options = {
|
options: Dict[str, Union[int, str]] = {
|
||||||
"device_id": i,
|
"device_id": i,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,6 +71,7 @@ module = [
|
||||||
"realesrgan",
|
"realesrgan",
|
||||||
"realesrgan.archs.srvgg_arch",
|
"realesrgan.archs.srvgg_arch",
|
||||||
"safetensors",
|
"safetensors",
|
||||||
|
"scipy",
|
||||||
"timm.models.layers",
|
"timm.models.layers",
|
||||||
"transformers",
|
"transformers",
|
||||||
"win10toast"
|
"win10toast"
|
||||||
|
|
|
@ -274,6 +274,7 @@ class TestInpaintPipeline(unittest.TestCase):
|
||||||
3.0,
|
3.0,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
|
unet_tile=64,
|
||||||
),
|
),
|
||||||
Size(*source.size),
|
Size(*source.size),
|
||||||
["test-inpaint-white.png"],
|
["test-inpaint-white.png"],
|
||||||
|
@ -310,6 +311,7 @@ class TestInpaintPipeline(unittest.TestCase):
|
||||||
3.0,
|
3.0,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
|
unet_tile=64,
|
||||||
),
|
),
|
||||||
Size(*source.size),
|
Size(*source.size),
|
||||||
["test-inpaint-black.png"],
|
["test-inpaint-black.png"],
|
||||||
|
|
Loading…
Reference in New Issue