From 5b4c370a1b6b3e809be34c14ab43978918973b23 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sat, 18 Feb 2023 15:44:39 -0600
Subject: [PATCH] feat(api): enable ONNX optimizations through env

---
 api/onnx_web/params.py | 34 ++++++++++++++++++++++++++++++----
 api/onnx_web/serve.py  | 17 +++++++++++++++--
 2 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index 40eb2257..fdd7ab9a 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -1,7 +1,10 @@
 from enum import IntEnum
-from typing import Any, Dict, Literal, Optional, Tuple, Union
+from logging import getLogger
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
-from onnxruntime import SessionOptions
+from onnxruntime import GraphOptimizationLevel, SessionOptions
+
+logger = getLogger(__name__)
 
 
 class SizeChart(IntEnum):
@@ -75,11 +78,16 @@ class Size:
 
 class DeviceParams:
     def __init__(
-        self, device: str, provider: str, options: Optional[dict] = None
+        self,
+        device: str,
+        provider: str,
+        options: Optional[dict] = None,
+        optimizations: Optional[List[str]] = None,
     ) -> None:
         self.device = device
         self.provider = provider
         self.options = options
+        self.optimizations = optimizations or []
 
     def __str__(self) -> str:
         return "%s - %s (%s)" % (self.device, self.provider, self.options)
@@ -91,7 +99,25 @@ class DeviceParams:
         return (self.provider, self.options)
 
     def sess_options(self) -> SessionOptions:
-        return SessionOptions()
+        sess = SessionOptions()
+
+        if "onnx-low-memory" in self.optimizations:
+            logger.debug("enabling ONNX low-memory optimizations")
+            sess.enable_cpu_mem_arena = False
+            sess.enable_mem_pattern = False
+            sess.enable_mem_reuse = False
+
+        if "onnx-optimization-disable" in self.optimizations:
+            sess.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
+        elif "onnx-optimization-basic" in self.optimizations:
+            sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
+        elif "onnx-optimization-all" in self.optimizations:
+            sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        if "onnx-deterministic-compute" in self.optimizations:
+            sess.use_deterministic_compute = True
+
+        return sess
 
     def torch_str(self) -> str:
         if self.device.startswith("cuda"):
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index 4818c64c..503be4f0 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -349,16 +349,29 @@ def load_platforms(context: ServerContext) -> None:
                     {
                         "device_id": i,
                     },
+                    context.optimizations,
                 )
             )
         else:
             available_platforms.append(
-                DeviceParams(potential, platform_providers[potential])
+                DeviceParams(
+                    potential,
+                    platform_providers[potential],
+                    None,
+                    context.optimizations,
+                )
             )
 
     if context.any_platform:
         # the platform should be ignored when the job is scheduled, but set to CPU just in case
-        available_platforms.append(DeviceParams("any", platform_providers["cpu"]))
+        available_platforms.append(
+            DeviceParams(
+                "any",
+                platform_providers["cpu"],
+                None,
+                context.optimizations,
+            )
+        )
 
     # make sure CPU is last on the list
     def any_first_cpu_last(a: DeviceParams, b: DeviceParams):