apply lint, add missing file
This commit is contained in:
parent
a63669c76b
commit
7e6749e0d7
|
@ -1,2 +1,2 @@
|
|||
from .pipeline import ChainPipeline, PipelineStage, StageParams
|
||||
from .stages import *
|
||||
from .stages import * # NOQA
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .result import StageResult
|
||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||
from ..server.context import ServerContext
|
||||
from ..worker.context import WorkerContext
|
||||
from .result import StageResult
|
||||
|
||||
|
||||
class BaseStage:
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -28,4 +28,6 @@ class BlendLinearStage(BaseStage):
|
|||
) -> StageResult:
|
||||
logger.info("blending source images using linear interpolation")
|
||||
|
||||
return StageResult(images=[Image.blend(source, stage_source, alpha) for source in sources])
|
||||
return StageResult(
|
||||
images=[Image.blend(source, stage_source, alpha) for source in sources]
|
||||
)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -38,4 +38,8 @@ class BlendMaskStage(BaseStage):
|
|||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-mult-mask.png", mult_mask)
|
||||
|
||||
return StageResult(images=[Image.composite(stage_source, source, mult_mask) for source in sources])
|
||||
return StageResult(
|
||||
images=[
|
||||
Image.composite(stage_source, source, mult_mask) for source in sources
|
||||
]
|
||||
)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
|
@ -74,12 +73,15 @@ class CorrectGFPGANStage(BaseStage):
|
|||
device = worker.get_device()
|
||||
gfpgan = self.load(server, stage, upscale, device)
|
||||
|
||||
outputs = [gfpgan.enhance(
|
||||
outputs = [
|
||||
gfpgan.enhance(
|
||||
source,
|
||||
has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True,
|
||||
weight=upscale.face_strength,
|
||||
) for source in sources.as_numpy()]
|
||||
)
|
||||
for source in sources.as_numpy()
|
||||
]
|
||||
|
||||
return StageResult(images=outputs)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from .pipeline import ChainPipeline
|
||||
from ..chain.blend_img2img import BlendImg2ImgStage
|
||||
from ..chain.upscale import stage_upscale_correction
|
||||
from ..chain.upscale_simple import UpscaleSimpleStage
|
||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||
from .pipeline import ChainPipeline
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from boto3 import Session
|
||||
from PIL import Image
|
||||
|
|
|
@ -12,8 +12,8 @@ from ..server import ServerContext
|
|||
from ..utils import is_debug, run_gc
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .tile import needs_tile, process_tile_order
|
||||
from .result import StageResult
|
||||
from .tile import needs_tile, process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -76,7 +76,7 @@ class ChainPipeline:
|
|||
params: ImageParams,
|
||||
sources: StageResult,
|
||||
callback: Optional[ProgressCallback],
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
result = self(
|
||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||
|
@ -108,10 +108,10 @@ class ChainPipeline:
|
|||
params: ImageParams,
|
||||
sources: StageResult,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
**pipeline_kwargs
|
||||
**pipeline_kwargs,
|
||||
) -> StageResult:
|
||||
"""
|
||||
DEPRECATED: use `run` instead
|
||||
DEPRECATED: use `.run()` instead
|
||||
"""
|
||||
if callback is None:
|
||||
callback = worker.get_progress_callback()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from logging import getLogger
|
||||
from typing import List
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
|
|
@ -1,41 +1,43 @@
|
|||
from PIL import Image
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class StageResult:
|
||||
"""
|
||||
Chain pipeline stage result.
|
||||
Can contain PIL images or numpy arrays, with helpers to convert between them.
|
||||
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
|
||||
they are expected.
|
||||
"""
|
||||
arrays: Optional[List[np.ndarray]]
|
||||
images: Optional[List[Image.Image]]
|
||||
"""
|
||||
Chain pipeline stage result.
|
||||
Can contain PIL images or numpy arrays, with helpers to convert between them.
|
||||
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
|
||||
they are expected.
|
||||
"""
|
||||
|
||||
def __init__(self, arrays = None, images = None) -> None:
|
||||
if arrays is not None and images is not None:
|
||||
raise ValueError("stages must only return one type of result")
|
||||
elif arrays is None and images is None:
|
||||
raise ValueError("stages must return results")
|
||||
arrays: Optional[List[np.ndarray]]
|
||||
images: Optional[List[Image.Image]]
|
||||
|
||||
self.arrays = arrays
|
||||
self.images = images
|
||||
def __init__(self, arrays=None, images=None) -> None:
|
||||
if arrays is not None and images is not None:
|
||||
raise ValueError("stages must only return one type of result")
|
||||
elif arrays is None and images is None:
|
||||
raise ValueError("stages must return results")
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self.arrays is not None:
|
||||
return len(self.arrays)
|
||||
else:
|
||||
return len(self.images)
|
||||
self.arrays = arrays
|
||||
self.images = images
|
||||
|
||||
def as_numpy(self) -> List[np.ndarray]:
|
||||
if self.arrays is not None:
|
||||
return self.arrays
|
||||
def __len__(self) -> int:
|
||||
if self.arrays is not None:
|
||||
return len(self.arrays)
|
||||
else:
|
||||
return len(self.images)
|
||||
|
||||
return [np.array(i) for i in self.images]
|
||||
def as_numpy(self) -> List[np.ndarray]:
|
||||
if self.arrays is not None:
|
||||
return self.arrays
|
||||
|
||||
def as_image(self) -> List[Image.Image]:
|
||||
if self.images is not None:
|
||||
return self.images
|
||||
return [np.array(i) for i in self.images]
|
||||
|
||||
return [Image.fromarray(i, "RGB") for i in self.arrays]
|
||||
def as_image(self) -> List[Image.Image]:
|
||||
if self.images is not None:
|
||||
return self.images
|
||||
|
||||
return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -132,7 +132,9 @@ class SourceTxt2ImgStage(BaseStage):
|
|||
else:
|
||||
# encode and record alternative prompts outside of LPW
|
||||
if params.is_panorama() or params.is_xl():
|
||||
logger.debug("prompt alternatives are not supported for panorama or SDXL")
|
||||
logger.debug(
|
||||
"prompt alternatives are not supported for panorama or SDXL"
|
||||
)
|
||||
else:
|
||||
prompt_embeds = encode_prompt(
|
||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||
|
|
|
@ -2,8 +2,8 @@ from logging import getLogger
|
|||
|
||||
from .base import BaseStage
|
||||
from .blend_denoise import BlendDenoiseStage
|
||||
from .blend_img2img import BlendImg2ImgStage
|
||||
from .blend_grid import BlendGridStage
|
||||
from .blend_img2img import BlendImg2ImgStage
|
||||
from .blend_linear import BlendLinearStage
|
||||
from .blend_mask import BlendMaskStage
|
||||
from .correct_codeformer import CorrectCodeformerStage
|
||||
|
@ -54,11 +54,11 @@ CHAIN_STAGES = {
|
|||
|
||||
|
||||
def add_stage(name: str, stage: BaseStage) -> bool:
|
||||
global CHAIN_STAGES
|
||||
global CHAIN_STAGES
|
||||
|
||||
if name in CHAIN_STAGES:
|
||||
logger.warning("cannot replace stage: %s", name)
|
||||
return False
|
||||
else:
|
||||
CHAIN_STAGES[name] = stage
|
||||
return True
|
||||
if name in CHAIN_STAGES:
|
||||
logger.warning("cannot replace stage: %s", name)
|
||||
return False
|
||||
else:
|
||||
CHAIN_STAGES[name] = stage
|
||||
return True
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
@ -45,4 +45,4 @@ class UpscaleHighresStage(BaseStage):
|
|||
for source in sources
|
||||
]
|
||||
|
||||
return StageResult(images=outputs)
|
||||
return StageResult(images=outputs)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..onnx import OnnxRRDBNet
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -80,12 +80,15 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
logger.trace("SwinIR input shape: %s", image.shape)
|
||||
|
||||
scale = upscale.outscale
|
||||
logger.trace("SwinIR output shape: %s", (
|
||||
logger.trace(
|
||||
"SwinIR output shape: %s",
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
))
|
||||
),
|
||||
)
|
||||
|
||||
output = swinir(image)
|
||||
output = np.clip(np.squeeze(output, axis=0), 0, 1)
|
||||
|
|
|
@ -81,7 +81,9 @@ def convert_diffusion_diffusers_xl(
|
|||
output=dest_path,
|
||||
task="stable-diffusion-xl",
|
||||
device=device,
|
||||
fp16=conversion.has_optimization("torch-fp16"), # optimum's fp16 mode only works on CUDA or ROCm
|
||||
fp16=conversion.has_optimization(
|
||||
"torch-fp16"
|
||||
), # optimum's fp16 mode only works on CUDA or ROCm
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
|
|
|
@ -574,10 +574,9 @@ def optimize_pipeline(
|
|||
server: ServerContext,
|
||||
pipe: StableDiffusionPipeline,
|
||||
) -> None:
|
||||
if (
|
||||
server.has_optimization("diffusers-attention-slicing")
|
||||
or server.has_optimization("diffusers-attention-slicing-auto")
|
||||
):
|
||||
if server.has_optimization(
|
||||
"diffusers-attention-slicing"
|
||||
) or server.has_optimization("diffusers-attention-slicing-auto"):
|
||||
logger.debug("enabling auto attention slicing on SD pipeline")
|
||||
try:
|
||||
pipe.enable_attention_slicing(slice_size="auto")
|
||||
|
|
|
@ -13,6 +13,7 @@ from ..chain import (
|
|||
UpscaleOutpaintStage,
|
||||
)
|
||||
from ..chain.highres import stage_highres
|
||||
from ..chain.result import StageResult
|
||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||
from ..image import expand_image
|
||||
from ..output import save_image
|
||||
|
|
|
@ -208,6 +208,7 @@ class ImageParams:
|
|||
unet_overlap: float
|
||||
vae_tile: int
|
||||
vae_overlap: float
|
||||
denoise: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -230,6 +231,7 @@ class ImageParams:
|
|||
unet_tile: int = 512,
|
||||
vae_overlap: float = 0.25,
|
||||
vae_tile: int = 512,
|
||||
denoise: int = 3,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.pipeline = pipeline
|
||||
|
@ -250,6 +252,7 @@ class ImageParams:
|
|||
self.unet_tile = unet_tile
|
||||
self.vae_overlap = vae_overlap
|
||||
self.vae_tile = vae_tile
|
||||
self.denoise = denoise
|
||||
|
||||
def do_cfg(self):
|
||||
return self.cfg > 1.0
|
||||
|
@ -320,6 +323,7 @@ class ImageParams:
|
|||
"unet_tile": self.unet_tile,
|
||||
"vae_overlap": self.vae_overlap,
|
||||
"vae_tile": self.vae_tile,
|
||||
"denoise": self.denoise,
|
||||
}
|
||||
|
||||
def with_args(self, **kwargs):
|
||||
|
@ -343,6 +347,7 @@ class ImageParams:
|
|||
kwargs.get("unet_tile", self.unet_tile),
|
||||
kwargs.get("vae_overlap", self.vae_overlap),
|
||||
kwargs.get("vae_tile", self.vae_tile),
|
||||
kwargs.get("denoise", self.denoise),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -138,7 +138,9 @@ def apply_patch_basicsr(server: ServerContext):
|
|||
import basicsr.utils.download_util
|
||||
|
||||
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
|
||||
basicsr.utils.download_util.load_file_from_url = partial(
|
||||
patch_cache_path, server
|
||||
)
|
||||
except ImportError:
|
||||
logger.info("unable to import basicsr utils for patching")
|
||||
except AttributeError:
|
||||
|
@ -151,7 +153,9 @@ def apply_patch_codeformer(server: ServerContext):
|
|||
import codeformer.facelib.utils.misc
|
||||
|
||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
|
||||
codeformer.facelib.utils.misc.load_file_from_url = partial(
|
||||
patch_cache_path, server
|
||||
)
|
||||
except ImportError:
|
||||
logger.info("unable to import codeformer utils for patching")
|
||||
except AttributeError:
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
from importlib import import_module
|
||||
from logging import getLogger
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from onnx_web.chain.stages import add_stage
|
||||
from onnx_web.diffusers.load import add_pipeline
|
||||
from onnx_web.server.context import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class PluginExports:
|
||||
pipelines: Dict[str, Any]
|
||||
stages: Dict[str, Any]
|
||||
|
||||
def __init__(self, pipelines=None, stages=None) -> None:
|
||||
self.pipelines = pipelines or {}
|
||||
self.stages = stages or {}
|
||||
|
||||
|
||||
PluginModule = Callable[[ServerContext], PluginExports]
|
||||
|
||||
|
||||
def load_plugins(server: ServerContext) -> PluginExports:
|
||||
combined_exports = PluginExports()
|
||||
|
||||
for plugin in server.plugins:
|
||||
logger.info("loading plugin module: %s", plugin)
|
||||
try:
|
||||
module: PluginModule = import_module(plugin)
|
||||
exports = module(server)
|
||||
|
||||
for name, pipeline in exports.pipelines.items():
|
||||
if name in combined_exports.pipelines:
|
||||
logger.warning(
|
||||
"multiple plugins exported a pipeline named %s", name
|
||||
)
|
||||
else:
|
||||
combined_exports.pipelines[name] = pipeline
|
||||
|
||||
for name, stage in exports.stages.items():
|
||||
if name in combined_exports.stages:
|
||||
logger.warning("multiple plugins exported a stage named %s", name)
|
||||
else:
|
||||
combined_exports.stages[name] = stage
|
||||
except Exception:
|
||||
logger.exception("error importing plugin")
|
||||
|
||||
return combined_exports
|
||||
|
||||
|
||||
def register_plugins(exports: PluginExports) -> bool:
|
||||
success = True
|
||||
|
||||
for name, pipeline in exports.pipelines.items():
|
||||
success = success and add_pipeline(name, pipeline)
|
||||
|
||||
for name, stage in exports.stages.items():
|
||||
success = success and add_stage(name, stage)
|
||||
|
||||
return success
|
|
@ -1,14 +1,22 @@
|
|||
import unittest
|
||||
from multiprocessing import Queue, Value
|
||||
from os import getpid
|
||||
from onnx_web.errors import RetryException
|
||||
|
||||
from onnx_web.errors import RetryException
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.command import JobCommand
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from onnx_web.worker.worker import EXIT_ERROR, EXIT_INTERRUPT, EXIT_MEMORY, EXIT_REPLACED, MEMORY_ERRORS, worker_main
|
||||
from onnx_web.worker.worker import (
|
||||
EXIT_ERROR,
|
||||
EXIT_INTERRUPT,
|
||||
EXIT_MEMORY,
|
||||
EXIT_REPLACED,
|
||||
MEMORY_ERRORS,
|
||||
worker_main,
|
||||
)
|
||||
from tests.helpers import test_device
|
||||
|
||||
|
||||
def main_memory(_worker):
|
||||
raise Exception(MEMORY_ERRORS[0])
|
||||
|
||||
|
|
Loading…
Reference in New Issue