diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 894211fa..d2d96cd7 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -90,6 +90,18 @@ pipeline_schedulers = { } +def add_pipeline(name: str, pipeline: Any) -> bool: + global available_pipelines + + if name in available_pipelines: + # TODO: decide if this should be allowed or not + logger.warning("cannot replace existing pipeline: %s", name) + return False + else: + available_pipelines[name] = pipeline + return False + + def get_available_pipelines() -> List[str]: return list(available_pipelines.keys()) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 9e47cc27..d6e4c678 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -23,6 +23,7 @@ from .server.load import ( load_platforms, load_wildcards, ) +from .server.plugin import load_plugins, register_plugins from .server.static import register_static_routes from .server.utils import check_paths from .utils import is_debug @@ -43,12 +44,23 @@ def main(): server = ServerContext.from_environ() apply_patches(server) check_paths(server) + + # register plugins + exports = load_plugins(server) + success = register_plugins(exports) + if success: + logger.info("all plugins loaded successfully") + else: + logger.warning("error loading plugins") + + # load additional resources load_extras(server) load_models(server) load_params(server) load_platforms(server) load_wildcards(server) + # debug and misc server options if is_debug(): gc.set_debug(gc.DEBUG_STATS) diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index d4118205..0a4221a2 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -40,6 +40,7 @@ class ServerContext: server_version: str worker_retries: int feature_flags: List[str] + plugins: List[str] def __init__( self,