apply lint, add missing file
This commit is contained in:
parent
a63669c76b
commit
7e6749e0d7
|
@ -1,2 +1,2 @@
|
||||||
from .pipeline import ChainPipeline, PipelineStage, StageParams
|
from .pipeline import ChainPipeline, PipelineStage, StageParams
|
||||||
from .stages import *
|
from .stages import * # NOQA
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .result import StageResult
|
|
||||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||||
from ..server.context import ServerContext
|
from ..server.context import ServerContext
|
||||||
from ..worker.context import WorkerContext
|
from ..worker.context import WorkerContext
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
|
|
||||||
class BaseStage:
|
class BaseStage:
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -28,4 +28,6 @@ class BlendLinearStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("blending source images using linear interpolation")
|
logger.info("blending source images using linear interpolation")
|
||||||
|
|
||||||
return StageResult(images=[Image.blend(source, stage_source, alpha) for source in sources])
|
return StageResult(
|
||||||
|
images=[Image.blend(source, stage_source, alpha) for source in sources]
|
||||||
|
)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
@ -38,4 +38,8 @@ class BlendMaskStage(BaseStage):
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
save_image(server, "last-mask.png", stage_mask)
|
||||||
save_image(server, "last-mult-mask.png", mult_mask)
|
save_image(server, "last-mult-mask.png", mult_mask)
|
||||||
|
|
||||||
return StageResult(images=[Image.composite(stage_source, source, mult_mask) for source in sources])
|
return StageResult(
|
||||||
|
images=[
|
||||||
|
Image.composite(stage_source, source, mult_mask) for source in sources
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
|
@ -74,12 +73,15 @@ class CorrectGFPGANStage(BaseStage):
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
gfpgan = self.load(server, stage, upscale, device)
|
gfpgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
outputs = [gfpgan.enhance(
|
outputs = [
|
||||||
|
gfpgan.enhance(
|
||||||
source,
|
source,
|
||||||
has_aligned=False,
|
has_aligned=False,
|
||||||
only_center_face=False,
|
only_center_face=False,
|
||||||
paste_back=True,
|
paste_back=True,
|
||||||
weight=upscale.face_strength,
|
weight=upscale.face_strength,
|
||||||
) for source in sources.as_numpy()]
|
)
|
||||||
|
for source in sources.as_numpy()
|
||||||
|
]
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .pipeline import ChainPipeline
|
|
||||||
from ..chain.blend_img2img import BlendImg2ImgStage
|
from ..chain.blend_img2img import BlendImg2ImgStage
|
||||||
from ..chain.upscale import stage_upscale_correction
|
from ..chain.upscale import stage_upscale_correction
|
||||||
from ..chain.upscale_simple import UpscaleSimpleStage
|
from ..chain.upscale_simple import UpscaleSimpleStage
|
||||||
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
|
||||||
|
from .pipeline import ChainPipeline
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from boto3 import Session
|
from boto3 import Session
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
|
@ -12,8 +12,8 @@ from ..server import ServerContext
|
||||||
from ..utils import is_debug, run_gc
|
from ..utils import is_debug, run_gc
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .tile import needs_tile, process_tile_order
|
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
from .tile import needs_tile, process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ class ChainPipeline:
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
callback: Optional[ProgressCallback],
|
callback: Optional[ProgressCallback],
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
result = self(
|
result = self(
|
||||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||||
|
@ -108,10 +108,10 @@ class ChainPipeline:
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
**pipeline_kwargs
|
**pipeline_kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
"""
|
"""
|
||||||
DEPRECATED: use `run` instead
|
DEPRECATED: use `.run()` instead
|
||||||
"""
|
"""
|
||||||
if callback is None:
|
if callback is None:
|
||||||
callback = worker.get_progress_callback()
|
callback = worker.get_progress_callback()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from PIL import Image
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
class StageResult:
|
class StageResult:
|
||||||
"""
|
"""
|
||||||
|
@ -10,6 +11,7 @@ class StageResult:
|
||||||
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
|
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
|
||||||
they are expected.
|
they are expected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arrays: Optional[List[np.ndarray]]
|
arrays: Optional[List[np.ndarray]]
|
||||||
images: Optional[List[Image.Image]]
|
images: Optional[List[Image.Image]]
|
||||||
|
|
||||||
|
@ -38,4 +40,4 @@ class StageResult:
|
||||||
if self.images is not None:
|
if self.images is not None:
|
||||||
return self.images
|
return self.images
|
||||||
|
|
||||||
return [Image.fromarray(i, "RGB") for i in self.arrays]
|
return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays]
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -132,7 +132,9 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
else:
|
else:
|
||||||
# encode and record alternative prompts outside of LPW
|
# encode and record alternative prompts outside of LPW
|
||||||
if params.is_panorama() or params.is_xl():
|
if params.is_panorama() or params.is_xl():
|
||||||
logger.debug("prompt alternatives are not supported for panorama or SDXL")
|
logger.debug(
|
||||||
|
"prompt alternatives are not supported for panorama or SDXL"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prompt_embeds = encode_prompt(
|
prompt_embeds = encode_prompt(
|
||||||
pipe, prompt_pairs, params.batch, params.do_cfg()
|
pipe, prompt_pairs, params.batch, params.do_cfg()
|
||||||
|
|
|
@ -2,8 +2,8 @@ from logging import getLogger
|
||||||
|
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .blend_denoise import BlendDenoiseStage
|
from .blend_denoise import BlendDenoiseStage
|
||||||
from .blend_img2img import BlendImg2ImgStage
|
|
||||||
from .blend_grid import BlendGridStage
|
from .blend_grid import BlendGridStage
|
||||||
|
from .blend_img2img import BlendImg2ImgStage
|
||||||
from .blend_linear import BlendLinearStage
|
from .blend_linear import BlendLinearStage
|
||||||
from .blend_mask import BlendMaskStage
|
from .blend_mask import BlendMaskStage
|
||||||
from .correct_codeformer import CorrectCodeformerStage
|
from .correct_codeformer import CorrectCodeformerStage
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..onnx import OnnxRRDBNet
|
from ..onnx import OnnxRRDBNet
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -80,12 +80,15 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
logger.trace("SwinIR input shape: %s", image.shape)
|
logger.trace("SwinIR input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
logger.trace("SwinIR output shape: %s", (
|
logger.trace(
|
||||||
|
"SwinIR output shape: %s",
|
||||||
|
(
|
||||||
image.shape[0],
|
image.shape[0],
|
||||||
image.shape[1],
|
image.shape[1],
|
||||||
image.shape[2] * scale,
|
image.shape[2] * scale,
|
||||||
image.shape[3] * scale,
|
image.shape[3] * scale,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
output = swinir(image)
|
output = swinir(image)
|
||||||
output = np.clip(np.squeeze(output, axis=0), 0, 1)
|
output = np.clip(np.squeeze(output, axis=0), 0, 1)
|
||||||
|
|
|
@ -81,7 +81,9 @@ def convert_diffusion_diffusers_xl(
|
||||||
output=dest_path,
|
output=dest_path,
|
||||||
task="stable-diffusion-xl",
|
task="stable-diffusion-xl",
|
||||||
device=device,
|
device=device,
|
||||||
fp16=conversion.has_optimization("torch-fp16"), # optimum's fp16 mode only works on CUDA or ROCm
|
fp16=conversion.has_optimization(
|
||||||
|
"torch-fp16"
|
||||||
|
), # optimum's fp16 mode only works on CUDA or ROCm
|
||||||
framework="pt",
|
framework="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -574,10 +574,9 @@ def optimize_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
if (
|
if server.has_optimization(
|
||||||
server.has_optimization("diffusers-attention-slicing")
|
"diffusers-attention-slicing"
|
||||||
or server.has_optimization("diffusers-attention-slicing-auto")
|
) or server.has_optimization("diffusers-attention-slicing-auto"):
|
||||||
):
|
|
||||||
logger.debug("enabling auto attention slicing on SD pipeline")
|
logger.debug("enabling auto attention slicing on SD pipeline")
|
||||||
try:
|
try:
|
||||||
pipe.enable_attention_slicing(slice_size="auto")
|
pipe.enable_attention_slicing(slice_size="auto")
|
||||||
|
|
|
@ -13,6 +13,7 @@ from ..chain import (
|
||||||
UpscaleOutpaintStage,
|
UpscaleOutpaintStage,
|
||||||
)
|
)
|
||||||
from ..chain.highres import stage_highres
|
from ..chain.highres import stage_highres
|
||||||
|
from ..chain.result import StageResult
|
||||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
from ..image import expand_image
|
from ..image import expand_image
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
|
|
|
@ -208,6 +208,7 @@ class ImageParams:
|
||||||
unet_overlap: float
|
unet_overlap: float
|
||||||
vae_tile: int
|
vae_tile: int
|
||||||
vae_overlap: float
|
vae_overlap: float
|
||||||
|
denoise: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -230,6 +231,7 @@ class ImageParams:
|
||||||
unet_tile: int = 512,
|
unet_tile: int = 512,
|
||||||
vae_overlap: float = 0.25,
|
vae_overlap: float = 0.25,
|
||||||
vae_tile: int = 512,
|
vae_tile: int = 512,
|
||||||
|
denoise: int = 3,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
|
@ -250,6 +252,7 @@ class ImageParams:
|
||||||
self.unet_tile = unet_tile
|
self.unet_tile = unet_tile
|
||||||
self.vae_overlap = vae_overlap
|
self.vae_overlap = vae_overlap
|
||||||
self.vae_tile = vae_tile
|
self.vae_tile = vae_tile
|
||||||
|
self.denoise = denoise
|
||||||
|
|
||||||
def do_cfg(self):
|
def do_cfg(self):
|
||||||
return self.cfg > 1.0
|
return self.cfg > 1.0
|
||||||
|
@ -320,6 +323,7 @@ class ImageParams:
|
||||||
"unet_tile": self.unet_tile,
|
"unet_tile": self.unet_tile,
|
||||||
"vae_overlap": self.vae_overlap,
|
"vae_overlap": self.vae_overlap,
|
||||||
"vae_tile": self.vae_tile,
|
"vae_tile": self.vae_tile,
|
||||||
|
"denoise": self.denoise,
|
||||||
}
|
}
|
||||||
|
|
||||||
def with_args(self, **kwargs):
|
def with_args(self, **kwargs):
|
||||||
|
@ -343,6 +347,7 @@ class ImageParams:
|
||||||
kwargs.get("unet_tile", self.unet_tile),
|
kwargs.get("unet_tile", self.unet_tile),
|
||||||
kwargs.get("vae_overlap", self.vae_overlap),
|
kwargs.get("vae_overlap", self.vae_overlap),
|
||||||
kwargs.get("vae_tile", self.vae_tile),
|
kwargs.get("vae_tile", self.vae_tile),
|
||||||
|
kwargs.get("denoise", self.denoise),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -138,7 +138,9 @@ def apply_patch_basicsr(server: ServerContext):
|
||||||
import basicsr.utils.download_util
|
import basicsr.utils.download_util
|
||||||
|
|
||||||
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||||
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
|
basicsr.utils.download_util.load_file_from_url = partial(
|
||||||
|
patch_cache_path, server
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info("unable to import basicsr utils for patching")
|
logger.info("unable to import basicsr utils for patching")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
@ -151,7 +153,9 @@ def apply_patch_codeformer(server: ServerContext):
|
||||||
import codeformer.facelib.utils.misc
|
import codeformer.facelib.utils.misc
|
||||||
|
|
||||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||||
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
|
codeformer.facelib.utils.misc.load_file_from_url = partial(
|
||||||
|
patch_cache_path, server
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info("unable to import codeformer utils for patching")
|
logger.info("unable to import codeformer utils for patching")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
from importlib import import_module
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
from onnx_web.chain.stages import add_stage
|
||||||
|
from onnx_web.diffusers.load import add_pipeline
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginExports:
|
||||||
|
pipelines: Dict[str, Any]
|
||||||
|
stages: Dict[str, Any]
|
||||||
|
|
||||||
|
def __init__(self, pipelines=None, stages=None) -> None:
|
||||||
|
self.pipelines = pipelines or {}
|
||||||
|
self.stages = stages or {}
|
||||||
|
|
||||||
|
|
||||||
|
PluginModule = Callable[[ServerContext], PluginExports]
|
||||||
|
|
||||||
|
|
||||||
|
def load_plugins(server: ServerContext) -> PluginExports:
|
||||||
|
combined_exports = PluginExports()
|
||||||
|
|
||||||
|
for plugin in server.plugins:
|
||||||
|
logger.info("loading plugin module: %s", plugin)
|
||||||
|
try:
|
||||||
|
module: PluginModule = import_module(plugin)
|
||||||
|
exports = module(server)
|
||||||
|
|
||||||
|
for name, pipeline in exports.pipelines.items():
|
||||||
|
if name in combined_exports.pipelines:
|
||||||
|
logger.warning(
|
||||||
|
"multiple plugins exported a pipeline named %s", name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
combined_exports.pipelines[name] = pipeline
|
||||||
|
|
||||||
|
for name, stage in exports.stages.items():
|
||||||
|
if name in combined_exports.stages:
|
||||||
|
logger.warning("multiple plugins exported a stage named %s", name)
|
||||||
|
else:
|
||||||
|
combined_exports.stages[name] = stage
|
||||||
|
except Exception:
|
||||||
|
logger.exception("error importing plugin")
|
||||||
|
|
||||||
|
return combined_exports
|
||||||
|
|
||||||
|
|
||||||
|
def register_plugins(exports: PluginExports) -> bool:
|
||||||
|
success = True
|
||||||
|
|
||||||
|
for name, pipeline in exports.pipelines.items():
|
||||||
|
success = success and add_pipeline(name, pipeline)
|
||||||
|
|
||||||
|
for name, stage in exports.stages.items():
|
||||||
|
success = success and add_stage(name, stage)
|
||||||
|
|
||||||
|
return success
|
|
@ -1,14 +1,22 @@
|
||||||
import unittest
|
import unittest
|
||||||
from multiprocessing import Queue, Value
|
from multiprocessing import Queue, Value
|
||||||
from os import getpid
|
from os import getpid
|
||||||
from onnx_web.errors import RetryException
|
|
||||||
|
|
||||||
|
from onnx_web.errors import RetryException
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.worker.command import JobCommand
|
from onnx_web.worker.command import JobCommand
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
from onnx_web.worker.worker import EXIT_ERROR, EXIT_INTERRUPT, EXIT_MEMORY, EXIT_REPLACED, MEMORY_ERRORS, worker_main
|
from onnx_web.worker.worker import (
|
||||||
|
EXIT_ERROR,
|
||||||
|
EXIT_INTERRUPT,
|
||||||
|
EXIT_MEMORY,
|
||||||
|
EXIT_REPLACED,
|
||||||
|
MEMORY_ERRORS,
|
||||||
|
worker_main,
|
||||||
|
)
|
||||||
from tests.helpers import test_device
|
from tests.helpers import test_device
|
||||||
|
|
||||||
|
|
||||||
def main_memory(_worker):
|
def main_memory(_worker):
|
||||||
raise Exception(MEMORY_ERRORS[0])
|
raise Exception(MEMORY_ERRORS[0])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue