
apply lint, add missing file

Sean Sube 2023-11-18 18:13:13 -06:00
parent a63669c76b
commit 7e6749e0d7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
32 changed files with 176 additions and 84 deletions

View File

@ -1,2 +1,2 @@
from .pipeline import ChainPipeline, PipelineStage, StageParams
from .stages import *
from .stages import * # NOQA

View File

@ -1,11 +1,11 @@
from typing import List, Optional
from typing import Optional
from PIL import Image
from .result import StageResult
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server.context import ServerContext
from ..worker.context import WorkerContext
from .result import StageResult
class BaseStage:

View File

@ -1,8 +1,7 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
import cv2
import numpy as np
from PIL import Image
from ..params import ImageParams, SizeChart, StageParams

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
import numpy as np
import torch

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
@ -28,4 +28,6 @@ class BlendLinearStage(BaseStage):
) -> StageResult:
logger.info("blending source images using linear interpolation")
return StageResult(images=[Image.blend(source, stage_source, alpha) for source in sources])
return StageResult(
images=[Image.blend(source, stage_source, alpha) for source in sources]
)

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
@ -38,4 +38,8 @@ class BlendMaskStage(BaseStage):
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)
return StageResult(images=[Image.composite(stage_source, source, mult_mask) for source in sources])
return StageResult(
images=[
Image.composite(stage_source, source, mult_mask) for source in sources
]
)

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image

View File

@ -1,8 +1,7 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
@ -74,12 +73,15 @@ class CorrectGFPGANStage(BaseStage):
device = worker.get_device()
gfpgan = self.load(server, stage, upscale, device)
outputs = [gfpgan.enhance(
outputs = [
gfpgan.enhance(
source,
has_aligned=False,
only_center_face=False,
paste_back=True,
weight=upscale.face_strength,
) for source in sources.as_numpy()]
)
for source in sources.as_numpy()
]
return StageResult(images=outputs)

View File

@ -1,11 +1,11 @@
from logging import getLogger
from typing import Optional
from .pipeline import ChainPipeline
from ..chain.blend_img2img import BlendImg2ImgStage
from ..chain.upscale import stage_upscale_correction
from ..chain.upscale_simple import UpscaleSimpleStage
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
from .pipeline import ChainPipeline
logger = getLogger(__name__)

View File

@ -1,6 +1,6 @@
from io import BytesIO
from logging import getLogger
from typing import List, Optional
from typing import Optional
from boto3 import Session
from PIL import Image

View File

@ -12,8 +12,8 @@ from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .tile import needs_tile, process_tile_order
from .result import StageResult
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
@ -76,7 +76,7 @@ class ChainPipeline:
params: ImageParams,
sources: StageResult,
callback: Optional[ProgressCallback],
**kwargs
**kwargs,
) -> StageResult:
result = self(
worker, server, params, sources=sources, callback=callback, **kwargs
@ -108,10 +108,10 @@ class ChainPipeline:
params: ImageParams,
sources: StageResult,
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs
**pipeline_kwargs,
) -> StageResult:
"""
DEPRECATED: use `run` instead
DEPRECATED: use `.run()` instead
"""
if callback is None:
callback = worker.get_progress_callback()

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image

View File

@ -1,5 +1,4 @@
from logging import getLogger
from typing import List
from PIL import Image

View File

@ -1,41 +1,43 @@
from PIL import Image
from typing import List, Optional
import numpy as np
from PIL import Image
class StageResult:
"""
Chain pipeline stage result.
Can contain PIL images or numpy arrays, with helpers to convert between them.
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
they are expected.
"""
arrays: Optional[List[np.ndarray]]
images: Optional[List[Image.Image]]
"""
Chain pipeline stage result.
Can contain PIL images or numpy arrays, with helpers to convert between them.
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
they are expected.
"""
def __init__(self, arrays = None, images = None) -> None:
if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result")
elif arrays is None and images is None:
raise ValueError("stages must return results")
arrays: Optional[List[np.ndarray]]
images: Optional[List[Image.Image]]
self.arrays = arrays
self.images = images
def __init__(self, arrays=None, images=None) -> None:
if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result")
elif arrays is None and images is None:
raise ValueError("stages must return results")
def __len__(self) -> int:
if self.arrays is not None:
return len(self.arrays)
else:
return len(self.images)
self.arrays = arrays
self.images = images
def as_numpy(self) -> List[np.ndarray]:
if self.arrays is not None:
return self.arrays
def __len__(self) -> int:
if self.arrays is not None:
return len(self.arrays)
else:
return len(self.images)
return [np.array(i) for i in self.images]
def as_numpy(self) -> List[np.ndarray]:
if self.arrays is not None:
return self.arrays
def as_image(self) -> List[Image.Image]:
if self.images is not None:
return self.images
return [np.array(i) for i in self.images]
return [Image.fromarray(i, "RGB") for i in self.arrays]
def as_image(self) -> List[Image.Image]:
if self.images is not None:
return self.images
return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays]
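
A minimal usage sketch for the reformatted result type above, not part of the commit: it assumes the module path onnx_web.chain.result seen in the other imports in this changeset, and builds a throwaway 64x64 image purely for illustration.

from PIL import Image

from onnx_web.chain.result import StageResult

# a stage that produces PIL images hands back an image-typed result
result = StageResult(images=[Image.new("RGB", (64, 64))])

# there is no __iter__, so callers must pick the format they want explicitly
arrays = result.as_numpy()   # list of np.ndarray, each with shape (64, 64, 3)
images = result.as_image()   # the original PIL images, unchanged

print(len(result), arrays[0].shape, images[0].size)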

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Callable, List, Optional
from typing import Callable, Optional
from PIL import Image

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import numpy as np
import torch
@ -132,7 +132,9 @@ class SourceTxt2ImgStage(BaseStage):
else:
# encode and record alternative prompts outside of LPW
if params.is_panorama() or params.is_xl():
logger.debug("prompt alternatives are not supported for panorama or SDXL")
logger.debug(
"prompt alternatives are not supported for panorama or SDXL"
)
else:
prompt_embeds = encode_prompt(
pipe, prompt_pairs, params.batch, params.do_cfg()

View File

@ -2,8 +2,8 @@ from logging import getLogger
from .base import BaseStage
from .blend_denoise import BlendDenoiseStage
from .blend_img2img import BlendImg2ImgStage
from .blend_grid import BlendGridStage
from .blend_img2img import BlendImg2ImgStage
from .blend_linear import BlendLinearStage
from .blend_mask import BlendMaskStage
from .correct_codeformer import CorrectCodeformerStage
@ -54,11 +54,11 @@ CHAIN_STAGES = {
def add_stage(name: str, stage: BaseStage) -> bool:
global CHAIN_STAGES
global CHAIN_STAGES
if name in CHAIN_STAGES:
logger.warning("cannot replace stage: %s", name)
return False
else:
CHAIN_STAGES[name] = stage
return True
if name in CHAIN_STAGES:
logger.warning("cannot replace stage: %s", name)
return False
else:
CHAIN_STAGES[name] = stage
return True

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image
@ -45,4 +45,4 @@ class UpscaleHighresStage(BaseStage):
for source in sources
]
return StageResult(images=outputs)
return StageResult(images=outputs)

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Callable, List, Optional, Tuple
from typing import Callable, Optional, Tuple
import numpy as np
import torch

View File

@ -1,8 +1,7 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image
from ..onnx import OnnxRRDBNet

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional
from PIL import Image

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import torch
from PIL import Image

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import List, Optional
from typing import Optional
import numpy as np
from PIL import Image
@ -80,12 +80,15 @@ class UpscaleSwinIRStage(BaseStage):
logger.trace("SwinIR input shape: %s", image.shape)
scale = upscale.outscale
logger.trace("SwinIR output shape: %s", (
logger.trace(
"SwinIR output shape: %s",
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
))
),
)
output = swinir(image)
output = np.clip(np.squeeze(output, axis=0), 0, 1)

View File

@ -81,7 +81,9 @@ def convert_diffusion_diffusers_xl(
output=dest_path,
task="stable-diffusion-xl",
device=device,
fp16=conversion.has_optimization("torch-fp16"), # optimum's fp16 mode only works on CUDA or ROCm
fp16=conversion.has_optimization(
"torch-fp16"
), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt",
)

View File

@ -574,10 +574,9 @@ def optimize_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
) -> None:
if (
server.has_optimization("diffusers-attention-slicing")
or server.has_optimization("diffusers-attention-slicing-auto")
):
if server.has_optimization(
"diffusers-attention-slicing"
) or server.has_optimization("diffusers-attention-slicing-auto"):
logger.debug("enabling auto attention slicing on SD pipeline")
try:
pipe.enable_attention_slicing(slice_size="auto")

View File

@ -13,6 +13,7 @@ from ..chain import (
UpscaleOutpaintStage,
)
from ..chain.highres import stage_highres
from ..chain.result import StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image

View File

@ -208,6 +208,7 @@ class ImageParams:
unet_overlap: float
vae_tile: int
vae_overlap: float
denoise: int
def __init__(
self,
@ -230,6 +231,7 @@ class ImageParams:
unet_tile: int = 512,
vae_overlap: float = 0.25,
vae_tile: int = 512,
denoise: int = 3,
) -> None:
self.model = model
self.pipeline = pipeline
@ -250,6 +252,7 @@ class ImageParams:
self.unet_tile = unet_tile
self.vae_overlap = vae_overlap
self.vae_tile = vae_tile
self.denoise = denoise
def do_cfg(self):
return self.cfg > 1.0
@ -320,6 +323,7 @@ class ImageParams:
"unet_tile": self.unet_tile,
"vae_overlap": self.vae_overlap,
"vae_tile": self.vae_tile,
"denoise": self.denoise,
}
def with_args(self, **kwargs):
@ -343,6 +347,7 @@ class ImageParams:
kwargs.get("unet_tile", self.unet_tile),
kwargs.get("vae_overlap", self.vae_overlap),
kwargs.get("vae_tile", self.vae_tile),
kwargs.get("denoise", self.denoise),
)

View File

@ -138,7 +138,9 @@ def apply_patch_basicsr(server: ServerContext):
import basicsr.utils.download_util
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
basicsr.utils.download_util.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError:
logger.info("unable to import basicsr utils for patching")
except AttributeError:
@ -151,7 +153,9 @@ def apply_patch_codeformer(server: ServerContext):
import codeformer.facelib.utils.misc
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
codeformer.facelib.utils.misc.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError:
logger.info("unable to import codeformer utils for patching")
except AttributeError:

View File

@ -0,0 +1,61 @@
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict
from onnx_web.chain.stages import add_stage
from onnx_web.diffusers.load import add_pipeline
from onnx_web.server.context import ServerContext
logger = getLogger(__name__)
class PluginExports:
pipelines: Dict[str, Any]
stages: Dict[str, Any]
def __init__(self, pipelines=None, stages=None) -> None:
self.pipelines = pipelines or {}
self.stages = stages or {}
PluginModule = Callable[[ServerContext], PluginExports]
def load_plugins(server: ServerContext) -> PluginExports:
combined_exports = PluginExports()
for plugin in server.plugins:
logger.info("loading plugin module: %s", plugin)
try:
module: PluginModule = import_module(plugin)
exports = module(server)
for name, pipeline in exports.pipelines.items():
if name in combined_exports.pipelines:
logger.warning(
"multiple plugins exported a pipeline named %s", name
)
else:
combined_exports.pipelines[name] = pipeline
for name, stage in exports.stages.items():
if name in combined_exports.stages:
logger.warning("multiple plugins exported a stage named %s", name)
else:
combined_exports.stages[name] = stage
except Exception:
logger.exception("error importing plugin")
return combined_exports
def register_plugins(exports: PluginExports) -> bool:
success = True
for name, pipeline in exports.pipelines.items():
success = success and add_pipeline(name, pipeline)
for name, stage in exports.stages.items():
success = success and add_stage(name, stage)
return success
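
As a hedged illustration of the new plugin module above, a plugin entry point matching the PluginModule alias (Callable[[ServerContext], PluginExports]) might look like the sketch below. The stage class, the stage name, and the import path for PluginExports are assumptions: this view does not show the new file's path, and the loader's exact way of locating the callable in the imported module may differ from what is sketched here.

from onnx_web.chain.base import BaseStage
from onnx_web.server.context import ServerContext

# assumed import path for the new module; the diff view omits the file name
from onnx_web.server.plugin import PluginExports


class BlendNoopStage(BaseStage):
    # illustrative placeholder; a real stage would implement the usual
    # run() signature and return a StageResult
    pass


def noop_plugin(server: ServerContext) -> PluginExports:
    # matches PluginModule = Callable[[ServerContext], PluginExports]
    return PluginExports(stages={"blend-noop": BlendNoopStage})

On the server side, load_plugins collects one PluginExports per configured plugin and register_plugins feeds the merged exports into add_stage and add_pipeline, refusing to replace any name that is already registered.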

View File

@ -1,14 +1,22 @@
import unittest
from multiprocessing import Queue, Value
from os import getpid
from onnx_web.errors import RetryException
from onnx_web.errors import RetryException
from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobCommand
from onnx_web.worker.context import WorkerContext
from onnx_web.worker.worker import EXIT_ERROR, EXIT_INTERRUPT, EXIT_MEMORY, EXIT_REPLACED, MEMORY_ERRORS, worker_main
from onnx_web.worker.worker import (
EXIT_ERROR,
EXIT_INTERRUPT,
EXIT_MEMORY,
EXIT_REPLACED,
MEMORY_ERRORS,
worker_main,
)
from tests.helpers import test_device
def main_memory(_worker):
raise Exception(MEMORY_ERRORS[0])