From 37dd8927bfc8d5e2f8e1bf5b2776e26279889014 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 4 Feb 2023 21:52:45 -0600 Subject: [PATCH] fix(api): pass device ID in provider params --- api/onnx_web/chain/upscale_resrgan.py | 2 +- api/onnx_web/chain/upscale_stable_diffusion.py | 4 ++-- api/onnx_web/diffusion/load.py | 6 +++--- api/onnx_web/onnx/onnx_net.py | 7 ++----- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 3a030dbd..885c1d0c 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -45,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams # use ONNX acceleration, if available if params.format == 'onnx': - model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options) + model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options) elif params.format == 'pth': model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=params.scale) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 86cef5fe..c9d89b86 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -44,10 +44,10 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: De if upscale.format == 'onnx': logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider) - pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options) + pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options) else: logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider) - pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options) + pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider) last_pipeline_instance = pipeline last_pipeline_params = cache_params diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index e7661a70..d45a7f6b 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -65,15 +65,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic scheduler = scheduler.from_pretrained( model, provider=device.provider, - sess_options=device.options, + provider_options=device.options, subfolder='scheduler', ) pipe = pipeline.from_pretrained( model, provider=device.provider, + provider_options=device.options, safety_checker=None, scheduler=scheduler, - sess_options=device.options, ) if device is not None and hasattr(pipe, 'to'): @@ -88,7 +88,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic scheduler = scheduler.from_pretrained( model, provider=device.provider, - sess_options=device.options, + provider_options=device.options, subfolder='scheduler', ) diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 9b0d14f9..20b99098 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -48,14 +48,11 @@ class OnnxNet(): server: ServerContext, model: str, provider: str = 'DmlExecutionProvider', - sess_options: Optional[dict] = None, + provider_options: Optional[dict] = None, ) -> None: - ''' - TODO: get platform provider from request params - ''' model_path = path.join(server.model_path, model) self.session = InferenceSession( - model_path, providers=[provider], sess_options=sess_options) + model_path, providers=[provider], provider_options=provider_options) def __call__(self, image: Any) -> Any: input_name = self.session.get_inputs()[0].name