From 948bec6d0f030da53a85807ec4b1cb220199f151 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 12 Apr 2023 08:51:16 -0500 Subject: [PATCH] load controlnets from network models list --- api/onnx_web/diffusers/load.py | 12 +++++++----- api/onnx_web/server/params.py | 3 ++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index bf5e456b..be6edf11 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -280,11 +280,13 @@ def load_pipeline( ) if control is not None: - components["controlnet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model( - control, - provider=device.ort_provider(), - sess_options=device.sess_options(), - )) + components["controlnet"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + path.join(server.model_path, "control", f"{control}.onnx"), + provider=device.ort_provider(), + sess_options=device.sess_options(), + ) + ) pipe = pipeline.from_pretrained( model, diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index e730a459..526dfe49 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -20,6 +20,7 @@ from .load import ( get_config_value, get_correction_models, get_highres_methods, + get_network_models, get_upscaling_models, ) from .utils import get_model_path @@ -42,13 +43,13 @@ def pipeline_from_request( device = platform # pipeline stuff - control = get_not_empty(request.args, "control", get_config_value("control")) lpw = get_not_empty(request.args, "lpw", "false") == "true" model = get_not_empty(request.args, "model", get_config_value("model")) model_path = get_model_path(server, model) scheduler = get_from_list( request.args, "scheduler", list(pipeline_schedulers.keys()) ) + control = get_from_list(request.args, "control", get_network_models()) if scheduler is None: scheduler = get_config_value("scheduler")