
lint and type fixes

Sean Sube 2023-12-03 12:13:45 -06:00
parent 05b1ef611f
commit 572db45bcb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 47 additions and 21 deletions

View File

@@ -39,20 +39,26 @@ class StageResult:
def __len__(self) -> int:
if self.arrays is not None:
return len(self.arrays)
else:
elif self.images is not None:
return len(self.images)
else:
raise ValueError("empty stage result")
def as_numpy(self) -> List[np.ndarray]:
if self.arrays is not None:
return self.arrays
elif self.images is not None:
return [np.array(i) for i in self.images]
else:
raise ValueError("empty stage result")
def as_image(self) -> List[Image.Image]:
if self.images is not None:
return self.images
elif self.arrays is not None:
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
else:
raise ValueError("empty stage result")
def shape_mode(arr: np.ndarray) -> str:

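For reference, the new accessors behave as in the minimal sketch below. This is a trimmed stand-in for the real chain result class, and the shape_mode body (mapping array shape to a PIL mode) is an assumption based only on its signature, since the hunk cuts off before it:

from typing import List, Optional

import numpy as np
from PIL import Image


def shape_mode(arr: np.ndarray) -> str:
    # assumed mapping from array shape to PIL mode; the real body is not shown in this hunk
    if arr.ndim == 2:
        return "L"
    elif arr.shape[-1] == 4:
        return "RGBA"
    else:
        return "RGB"


class StageResult:
    def __init__(
        self,
        arrays: Optional[List[np.ndarray]] = None,
        images: Optional[List[Image.Image]] = None,
    ):
        self.arrays = arrays
        self.images = images

    def as_image(self) -> List[Image.Image]:
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            raise ValueError("empty stage result")


# an empty result now raises a clear error instead of failing on None
empty = StageResult()
try:
    empty.as_image()
except ValueError as err:
    print(err)  # empty stage result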
View File

@@ -6,7 +6,14 @@ import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
from ..params import (
DeviceParams,
ImageParams,
Size,
SizeChart,
StageParams,
UpscaleParams,
)
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
@@ -17,7 +24,7 @@ logger = getLogger(__name__)
class UpscaleBSRGANStage(BaseStage):
max_tile = 64
max_tile = SizeChart.micro
def load(
self,

View File

@@ -6,7 +6,7 @@ import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
@@ -17,7 +17,7 @@ logger = getLogger(__name__)
class UpscaleSwinIRStage(BaseStage):
max_tile = 64
max_tile = SizeChart.micro
def load(
self,

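Because SizeChart is an IntEnum (its values appear in the params.py hunk further down), SizeChart.micro behaves exactly like the literal 64 it replaces in both stages; a quick sketch of that equivalence, using only the enum members shown in this commit:

from enum import IntEnum


class SizeChart(IntEnum):
    micro = 64
    mini = 128


max_tile = SizeChart.micro

print(max_tile == 64)  # True: IntEnum members compare equal to their integer value
print(max_tile * 2)    # 128: arithmetic returns a plain int, so tile math is unchanged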
View File

@@ -43,7 +43,6 @@ logger = logging.get_logger(__name__)
# inpaint constants
NUM_UNET_INPUT_CHANNELS = 9
NUM_LATENT_CHANNELS = 4
DEFAULT_WINDOW = 32
DEFAULT_STRIDE = 8
@@ -1215,7 +1214,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds,
)
num_channels_latents = NUM_LATENT_CHANNELS
num_channels_latents = LATENT_CHANNELS
latents_shape = (
batch_size * num_images_per_prompt,
num_channels_latents,

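The latents_shape computation above only depends on the (renamed) latent channel constant; as a rough illustration, assuming the usual Stable Diffusion layout of 4 latent channels and an 8x VAE downscale (the constant names below are illustrative, not read from the pipeline):

# illustrative values, not taken from the pipeline source
LATENT_CHANNELS = 4  # standard Stable Diffusion latent channel count
VAE_SCALE = 8        # latent resolution is image resolution / 8

batch_size = 1
num_images_per_prompt = 2
height, width = 512, 512

latents_shape = (
    batch_size * num_images_per_prompt,
    LATENT_CHANNELS,
    height // VAE_SCALE,
    width // VAE_SCALE,
)
print(latents_shape)  # (2, 4, 64, 64)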
View File

@@ -88,8 +88,8 @@ def expand_prompt(
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[np.ndarray] = None,
negative_prompt_embeds: Optional[np.ndarray] = None,
skip_clip_states: Optional[int] = 0,
) -> "np.NDArray":
skip_clip_states: int = 0,
) -> np.ndarray:
# self provides:
# tokenizer: CLIPTokenizer
# encoder: OnnxRuntimeModel
@@ -144,6 +144,7 @@ def expand_prompt(
last_state, _pooled_output, *hidden_states = text_result
if skip_clip_states > 0:
# TODO: why is this normalized?
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
norm_state = layer_norm(
torch.from_numpy(

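The hunk above re-normalizes an earlier CLIP hidden state when skip_clip_states is set; the diff cuts off before the indexing, so the sketch below only illustrates the LayerNorm step on hypothetical tensors, with the shapes and the choice of hidden state being assumptions:

import numpy as np
import torch

# hypothetical shapes (batch, tokens, hidden); the real values come from the CLIP encoder
last_state = np.random.rand(1, 77, 768).astype(np.float32)
hidden_states = [np.random.rand(1, 77, 768).astype(np.float32) for _ in range(12)]
skip_clip_states = 2

if skip_clip_states > 0:
    # re-normalize an earlier hidden state, as in the hunk above;
    # which state gets selected is an assumption here
    layer_norm = torch.nn.LayerNorm(last_state.shape[2])
    norm_state = layer_norm(torch.from_numpy(hidden_states[-skip_clip_states]))
    prompt_embeds = norm_state.detach().numpy()
else:
    prompt_embeds = last_state

print(prompt_embeds.shape)  # (1, 77, 768)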
View File

@@ -14,6 +14,7 @@ Point = Tuple[int, int]
class SizeChart(IntEnum):
micro = 64
mini = 128 # small tile for very expensive models
half = 256 # half tile for outpainting
auto = 512 # auto tile size
@@ -201,7 +202,7 @@ class ImageParams:
batch: int
control: Optional[NetworkModel]
input_prompt: str
input_negative_prompt: str
input_negative_prompt: Optional[str]
loopback: int
tiled_vae: bool
unet_tile: int
@@ -257,7 +258,7 @@ class ImageParams:
def do_cfg(self):
return self.cfg > 1.0
def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
pipeline = pipeline or self.pipeline
# if the correct pipeline was already requested, simply use that

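The new Optional annotations match how these fields are actually used; a small sketch of the pipeline fallback, using a trimmed stand-in for ImageParams rather than the real class:

from typing import Optional


class ParamsSketch:
    # trimmed stand-in for ImageParams, keeping only the fields needed here
    def __init__(self, pipeline: str, input_negative_prompt: Optional[str] = None):
        self.pipeline = pipeline
        self.input_negative_prompt = input_negative_prompt

    def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
        # None falls back to the params' own pipeline, matching the new signature;
        # the real method also validates against the group
        pipeline = pipeline or self.pipeline
        return pipeline


params = ParamsSketch("txt2img")
print(params.get_valid_pipeline("img2img"))              # txt2img, fallback used
print(params.get_valid_pipeline("img2img", "panorama"))  # panorama
print(params.input_negative_prompt)                      # None is now a valid value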
View File

@@ -20,7 +20,7 @@ SAFE_CHARS = "._-"
def split_list(val: str) -> List[str]:
parts = [part.strip() for part in val.split(",")]
return [part for part in parts if len(part.strip()) > 0]
return [part for part in parts if len(part) > 0]
def base_join(base: str, tail: str) -> str:
@@ -75,13 +75,13 @@ def get_from_list(
def get_from_map(
args: Any, key: str, values: Dict[str, TElem], default: TElem
args: Any, key: str, values: Dict[str, TElem], defaultKey: str
) -> TElem:
selected = args.get(key, default)
selected = args.get(key, defaultKey)
if selected in values:
return values[selected]
else:
return values[default]
return values[defaultKey]
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
@@ -209,6 +209,8 @@ def load_config(file: str) -> Dict:
return load_yaml(file)
elif ext in [".json"]:
return load_json(file)
else:
raise ValueError("unknown config file extension")
def load_config_str(raw: str) -> Dict:

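With the change above, get_from_map takes the name of a fallback entry rather than a fallback value, and load_config now raises on unknown extensions; a usage sketch based directly on the new code, with a made-up tiles dict for illustration:

from typing import Any, Dict, TypeVar

TElem = TypeVar("TElem")


def get_from_map(args: Any, key: str, values: Dict[str, TElem], defaultKey: str) -> TElem:
    # taken from the hunk above: fall back to a named entry, not an arbitrary value
    selected = args.get(key, defaultKey)
    if selected in values:
        return values[selected]
    else:
        return values[defaultKey]


tiles = {"micro": 64, "mini": 128}

print(get_from_map({"tile": "mini"}, "tile", tiles, "micro"))  # 128
print(get_from_map({}, "tile", tiles, "micro"))                # 64, from defaultKey
print(get_from_map({"tile": "huge"}, "tile", tiles, "micro"))  # 64, unknown name falls back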
View File

@@ -9,7 +9,9 @@ skip_glob = ["*/lpw.py"]
[tool.mypy]
# ignore_missing_imports = true
exclude = [
"onnx_web.diffusers.lpw_stable_diffusion_onnx"
"onnx_web.diffusers.pipelines.controlnet",
"onnx_web.diffusers.pipelines.lpw",
"onnx_web.diffusers.pipelines.pix2pix"
]
[[tool.mypy.overrides]]
@@ -27,8 +29,10 @@ module = [
"compel",
"controlnet_aux",
"cv2",
"debugpy",
"diffusers",
"diffusers.configuration_utils",
"diffusers.image_processor",
"diffusers.loaders",
"diffusers.models.attention_processor",
"diffusers.models.autoencoder_kl",
@@ -44,6 +48,7 @@ module = [
"diffusers.pipelines.pipeline_utils",
"diffusers.pipelines.stable_diffusion",
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
"diffusers.pipelines.stable_diffusion_xl",
"diffusers.schedulers",
"diffusers.utils.logging",
"facexlib.utils",
@@ -56,6 +61,11 @@ module = [
"mediapipe",
"onnxruntime",
"onnxruntime.transformers.float16",
"optimum.exporters.onnx",
"optimum.onnxruntime",
"optimum.onnxruntime.modeling_diffusion",
"optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img",
"optimum.pipelines.diffusers.pipeline_utils",
"piexif",
"piexif.helper",
"realesrgan",