1
0
Fork 0

feat(api): enable ONNX optimizations through env

This commit is contained in:
Sean Sube 2023-02-18 15:44:39 -06:00
parent 0d2211ff25
commit 5b4c370a1b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 43 additions and 6 deletions

View File

@@ -1,7 +1,10 @@
from enum import IntEnum
from typing import Any, Dict, Literal, Optional, Tuple, Union
from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from onnxruntime import SessionOptions
from onnxruntime import GraphOptimizationLevel, SessionOptions
logger = getLogger(__name__)
class SizeChart(IntEnum):
@@ -75,11 +78,16 @@ class Size:
class DeviceParams:
def __init__(
self, device: str, provider: str, options: Optional[dict] = None
self,
device: str,
provider: str,
options: Optional[dict] = None,
optimizations: Optional[List[str]] = None,
) -> None:
self.device = device
self.provider = provider
self.options = options
self.optimizations = optimizations
def __str__(self) -> str:
return "%s - %s (%s)" % (self.device, self.provider, self.options)
@@ -91,7 +99,23 @@ class DeviceParams:
return (self.provider, self.options)
def sess_options(self) -> SessionOptions:
    """Build an onnxruntime SessionOptions from this device's optimization flags.

    Recognized flags: "onnx-low-memory" (disable the CPU memory arena,
    memory patterns, and memory reuse), the mutually exclusive
    "onnx-optimization-{disable,basic,all}" graph optimization levels,
    and "onnx-deterministic-compute".

    Returns the configured SessionOptions instance.
    """
    sess = SessionOptions()

    # BUG FIX: guard against optimizations being None (the declared default
    # in __init__), which would make every `in` test raise TypeError
    optimizations = self.optimizations or []

    if "onnx-low-memory" in optimizations:
        logger.debug("enabling ONNX low-memory optimizations")
        sess.enable_cpu_mem_arena = False
        sess.enable_mem_pattern = False
        sess.enable_mem_reuse = False

    # the three graph-optimization flags are mutually exclusive; disable wins
    if "onnx-optimization-disable" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
    elif "onnx-optimization-basic" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
    elif "onnx-optimization-all" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    if "onnx-deterministic-compute" in optimizations:
        sess.use_deterministic_compute = True

    # BUG FIX: the original built `sess` but fell off the end of the method,
    # returning None instead of the configured SessionOptions
    return sess
def torch_str(self) -> str:
if self.device.startswith("cuda"):

View File

@@ -349,16 +349,29 @@ def load_platforms(context: ServerContext) -> None:
{
"device_id": i,
},
context.optimizations,
)
)
else:
available_platforms.append(
DeviceParams(potential, platform_providers[potential])
DeviceParams(
potential,
platform_providers[potential],
None,
context.optimizations,
)
)
if context.any_platform:
# the platform should be ignored when the job is scheduled, but set to CPU just in case
available_platforms.append(DeviceParams("any", platform_providers["cpu"]))
available_platforms.append(
DeviceParams(
"any",
platform_providers["cpu"],
None,
context.optimizations,
)
)
# make sure CPU is last on the list
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):