lint and type fixes
This commit is contained in:
parent
05b1ef611f
commit
572db45bcb
|
@ -39,20 +39,26 @@ class StageResult:
|
|||
def __len__(self) -> int:
|
||||
if self.arrays is not None:
|
||||
return len(self.arrays)
|
||||
else:
|
||||
elif self.images is not None:
|
||||
return len(self.images)
|
||||
else:
|
||||
raise ValueError("empty stage result")
|
||||
|
||||
def as_numpy(self) -> List[np.ndarray]:
|
||||
if self.arrays is not None:
|
||||
return self.arrays
|
||||
|
||||
return [np.array(i) for i in self.images]
|
||||
elif self.images is not None:
|
||||
return [np.array(i) for i in self.images]
|
||||
else:
|
||||
raise ValueError("empty stage result")
|
||||
|
||||
def as_image(self) -> List[Image.Image]:
|
||||
if self.images is not None:
|
||||
return self.images
|
||||
|
||||
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
|
||||
elif self.arrays is not None:
|
||||
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
|
||||
else:
|
||||
raise ValueError("empty stage result")
|
||||
|
||||
|
||||
def shape_mode(arr: np.ndarray) -> str:
|
||||
|
|
|
@ -6,7 +6,14 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||
from ..params import (
|
||||
DeviceParams,
|
||||
ImageParams,
|
||||
Size,
|
||||
SizeChart,
|
||||
StageParams,
|
||||
UpscaleParams,
|
||||
)
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
@ -17,7 +24,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class UpscaleBSRGANStage(BaseStage):
|
||||
max_tile = 64
|
||||
max_tile = SizeChart.micro
|
||||
|
||||
def load(
|
||||
self,
|
||||
|
|
|
@ -6,7 +6,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
|
||||
from ..models.onnx import OnnxModel
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from ..server import ModelTypes, ServerContext
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
|
@ -17,7 +17,7 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class UpscaleSwinIRStage(BaseStage):
|
||||
max_tile = 64
|
||||
max_tile = SizeChart.micro
|
||||
|
||||
def load(
|
||||
self,
|
||||
|
|
|
@ -43,7 +43,6 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
# inpaint constants
|
||||
NUM_UNET_INPUT_CHANNELS = 9
|
||||
NUM_LATENT_CHANNELS = 4
|
||||
|
||||
DEFAULT_WINDOW = 32
|
||||
DEFAULT_STRIDE = 8
|
||||
|
@ -1215,7 +1214,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
num_channels_latents = NUM_LATENT_CHANNELS
|
||||
num_channels_latents = LATENT_CHANNELS
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
|
|
|
@ -88,8 +88,8 @@ def expand_prompt(
|
|||
negative_prompt: Optional[str] = None,
|
||||
prompt_embeds: Optional[np.ndarray] = None,
|
||||
negative_prompt_embeds: Optional[np.ndarray] = None,
|
||||
skip_clip_states: Optional[int] = 0,
|
||||
) -> "np.NDArray":
|
||||
skip_clip_states: int = 0,
|
||||
) -> np.ndarray:
|
||||
# self provides:
|
||||
# tokenizer: CLIPTokenizer
|
||||
# encoder: OnnxRuntimeModel
|
||||
|
@ -144,6 +144,7 @@ def expand_prompt(
|
|||
|
||||
last_state, _pooled_output, *hidden_states = text_result
|
||||
if skip_clip_states > 0:
|
||||
# TODO: why is this normalized?
|
||||
layer_norm = torch.nn.LayerNorm(last_state.shape[2])
|
||||
norm_state = layer_norm(
|
||||
torch.from_numpy(
|
||||
|
|
|
@ -14,6 +14,7 @@ Point = Tuple[int, int]
|
|||
|
||||
|
||||
class SizeChart(IntEnum):
|
||||
micro = 64
|
||||
mini = 128 # small tile for very expensive models
|
||||
half = 256 # half tile for outpainting
|
||||
auto = 512 # auto tile size
|
||||
|
@ -201,7 +202,7 @@ class ImageParams:
|
|||
batch: int
|
||||
control: Optional[NetworkModel]
|
||||
input_prompt: str
|
||||
input_negative_prompt: str
|
||||
input_negative_prompt: Optional[str]
|
||||
loopback: int
|
||||
tiled_vae: bool
|
||||
unet_tile: int
|
||||
|
@ -257,7 +258,7 @@ class ImageParams:
|
|||
def do_cfg(self):
|
||||
return self.cfg > 1.0
|
||||
|
||||
def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
|
||||
def get_valid_pipeline(self, group: str, pipeline: Optional[str] = None) -> str:
|
||||
pipeline = pipeline or self.pipeline
|
||||
|
||||
# if the correct pipeline was already requested, simply use that
|
||||
|
|
|
@ -20,7 +20,7 @@ SAFE_CHARS = "._-"
|
|||
|
||||
def split_list(val: str) -> List[str]:
|
||||
parts = [part.strip() for part in val.split(",")]
|
||||
return [part for part in parts if len(part.strip()) > 0]
|
||||
return [part for part in parts if len(part) > 0]
|
||||
|
||||
|
||||
def base_join(base: str, tail: str) -> str:
|
||||
|
@ -75,13 +75,13 @@ def get_from_list(
|
|||
|
||||
|
||||
def get_from_map(
|
||||
args: Any, key: str, values: Dict[str, TElem], default: TElem
|
||||
args: Any, key: str, values: Dict[str, TElem], defaultKey: str
|
||||
) -> TElem:
|
||||
selected = args.get(key, default)
|
||||
selected = args.get(key, defaultKey)
|
||||
if selected in values:
|
||||
return values[selected]
|
||||
else:
|
||||
return values[default]
|
||||
return values[defaultKey]
|
||||
|
||||
|
||||
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
|
||||
|
@ -209,6 +209,8 @@ def load_config(file: str) -> Dict:
|
|||
return load_yaml(file)
|
||||
elif ext in [".json"]:
|
||||
return load_json(file)
|
||||
else:
|
||||
raise ValueError("unknown config file extension")
|
||||
|
||||
|
||||
def load_config_str(raw: str) -> Dict:
|
||||
|
|
|
@ -9,12 +9,14 @@ skip_glob = ["*/lpw.py"]
|
|||
[tool.mypy]
|
||||
# ignore_missing_imports = true
|
||||
exclude = [
|
||||
"onnx_web.diffusers.lpw_stable_diffusion_onnx"
|
||||
"onnx_web.diffusers.pipelines.controlnet",
|
||||
"onnx_web.diffusers.pipelines.lpw",
|
||||
"onnx_web.diffusers.pipelines.pix2pix"
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"arpeggio",
|
||||
"arpeggio",
|
||||
"basicsr.archs.rrdbnet_arch",
|
||||
"basicsr.utils.download_util",
|
||||
"basicsr.utils",
|
||||
|
@ -27,8 +29,10 @@ module = [
|
|||
"compel",
|
||||
"controlnet_aux",
|
||||
"cv2",
|
||||
"debugpy",
|
||||
"diffusers",
|
||||
"diffusers.configuration_utils",
|
||||
"diffusers.image_processor",
|
||||
"diffusers.loaders",
|
||||
"diffusers.models.attention_processor",
|
||||
"diffusers.models.autoencoder_kl",
|
||||
|
@ -44,6 +48,7 @@ module = [
|
|||
"diffusers.pipelines.pipeline_utils",
|
||||
"diffusers.pipelines.stable_diffusion",
|
||||
"diffusers.pipelines.stable_diffusion.convert_from_ckpt",
|
||||
"diffusers.pipelines.stable_diffusion_xl",
|
||||
"diffusers.schedulers",
|
||||
"diffusers.utils.logging",
|
||||
"facexlib.utils",
|
||||
|
@ -56,6 +61,11 @@ module = [
|
|||
"mediapipe",
|
||||
"onnxruntime",
|
||||
"onnxruntime.transformers.float16",
|
||||
"optimum.exporters.onnx",
|
||||
"optimum.onnxruntime",
|
||||
"optimum.onnxruntime.modeling_diffusion",
|
||||
"optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img",
|
||||
"optimum.pipelines.diffusers.pipeline_utils",
|
||||
"piexif",
|
||||
"piexif.helper",
|
||||
"realesrgan",
|
||||
|
|
Loading…
Reference in New Issue