diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 84034fb8..02ad6c3a 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -18,7 +18,7 @@ class BaseStage: _stage: StageParams, _params: ImageParams, _sources: StageResult, - *args, + *, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 5ca17151..34e4f535 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Optional +from typing import List, Optional from PIL import Image @@ -28,7 +28,7 @@ class BlendGridStage(BaseStage): # rows: Optional[List[str]] = None, # columns: Optional[List[str]] = None, # title: Optional[str] = None, - order: Optional[int] = None, + order: Optional[List[int]] = None, stage_source: Optional[Image.Image] = None, callback: Optional[ProgressCallback] = None, **kwargs, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 468ceecd..571e58ad 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -3,7 +3,6 @@ from typing import Optional, Tuple import numpy as np import torch -from PIL import Image from ..constants import LATENT_FACTOR from ..diffusers.load import load_pipeline @@ -41,7 +40,7 @@ class SourceTxt2ImgStage(BaseStage): latents: Optional[np.ndarray] = None, prompt_index: Optional[int] = None, **kwargs, - ) -> Image.Image: + ) -> StageResult: params = params.with_args(**kwargs) size = size.with_args(**kwargs) diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index bd7f826a..32f891a6 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -22,7 +22,7 @@ class UpscaleHighresStage(BaseStage): stage: StageParams, params: ImageParams, sources: StageResult, - *args, + *, highres: HighresParams, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 45762ffe..a8ecbbf7 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -103,7 +103,6 @@ def get_model_version( opts["prediction_type"] = "epsilon" except Exception: logger.debug("unable to load tensor for version check") - pass return (v2, opts) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index f115942f..c4ed7fa8 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -76,7 +76,7 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] names = [fix_node_name(node.name) for node in nodes] for key, value in keys.items(): - root, *rest = key.split(".") + root, *_rest = key.split(".") logger.trace("fixing XL node name: %s -> %s", key, root) simple = False diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 61b6547f..b1babdd5 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -36,7 +36,7 @@ DEFAULT_OPSET = 14 class ConversionContext(ServerContext): def __init__( self, - model_path: Optional[str] = None, + model_path: str = ".", cache_path: Optional[str] = None, device: Optional[str] = None, half: bool = False, @@ -120,7 +120,7 @@ def download_progress(urls: List[Tuple[str, str]]): def tuple_to_source(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): - name, source, *rest = model + name, source, *_rest = model return { "name": name, diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 5088dc60..8d006f76 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -505,7 +505,7 @@ def load_unet( def load_vae( - server: ServerContext, device: DeviceParams, model: str, params: ImageParams + _server: ServerContext, device: DeviceParams, model: str, params: ImageParams ): # one or more VAE models need to be loaded vae = path.join(model, "vae", ONNX_MODEL) diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index 6e15597f..81065d97 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -28,9 +28,9 @@ class UNetWrapper(object): def __call__( self, - sample: np.ndarray = None, - timestep: np.ndarray = None, - encoder_hidden_states: np.ndarray = None, + sample: Optional[np.ndarray] = None, + timestep: Optional[np.ndarray] = None, + encoder_hidden_states: Optional[np.ndarray] = None, **kwargs, ): logger.trace( diff --git a/api/onnx_web/models/meta.py b/api/onnx_web/models/meta.py index fd8b1297..6aaa4e52 100644 --- a/api/onnx_web/models/meta.py +++ b/api/onnx_web/models/meta.py @@ -1,6 +1,6 @@ from typing import List, Literal -NetworkType = Literal["inversion", "lora"] +NetworkType = Literal["control", "inversion", "lora"] class NetworkModel: diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index d64f79a0..ec76ce3d 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -57,7 +57,7 @@ def json_params( upscale: Optional[UpscaleParams] = None, border: Optional[Border] = None, highres: Optional[HighresParams] = None, - parent: Dict = None, + parent: Optional[Dict] = None, ) -> Any: json = { "input_size": size.tojson(), diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 0444cc82..6bf1de2d 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -163,8 +163,8 @@ def load_extras(server: ServerContext): global extra_strings global extra_tokens - labels = {} - strings = {} + labels: Dict[str, str] = {} + strings: Dict[str, Any] = {} extra_schema = load_config("./schemas/extras.yaml") @@ -415,7 +415,7 @@ def load_platforms(server: ServerContext) -> None: ): if potential == "cuda" or potential == "rocm": for i in range(torch.cuda.device_count()): - options = { + options: Dict[str, Union[int, str]] = { "device_id": i, } diff --git a/api/pyproject.toml b/api/pyproject.toml index bb47f16d..5d69e906 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -71,6 +71,7 @@ module = [ "realesrgan", "realesrgan.archs.srvgg_arch", "safetensors", + "scipy", "timm.models.layers", "transformers", "win10toast" diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index 261566e0..26578f3e 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -274,6 +274,7 @@ class TestInpaintPipeline(unittest.TestCase): 3.0, 1, 1, + unet_tile=64, ), Size(*source.size), ["test-inpaint-white.png"], @@ -310,6 +311,7 @@ class TestInpaintPipeline(unittest.TestCase): 3.0, 1, 1, + unet_tile=64, ), Size(*source.size), ["test-inpaint-black.png"],