1
0
Fork 0

feat(api): enable ONNX optimizations through env

This commit is contained in:
Sean Sube 2023-02-18 15:44:39 -06:00
parent 0d2211ff25
commit 5b4c370a1b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 43 additions and 6 deletions

View File

@@ -1,7 +1,10 @@
from enum import IntEnum from enum import IntEnum
from typing import Any, Dict, Literal, Optional, Tuple, Union from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from onnxruntime import SessionOptions from onnxruntime import GraphOptimizationLevel, SessionOptions
logger = getLogger(__name__)
class SizeChart(IntEnum): class SizeChart(IntEnum):
@@ -75,11 +78,16 @@ class Size:
class DeviceParams: class DeviceParams:
def __init__( def __init__(
self, device: str, provider: str, options: Optional[dict] = None self,
device: str,
provider: str,
options: Optional[dict] = None,
optimizations: Optional[List[str]] = None,
) -> None: ) -> None:
self.device = device self.device = device
self.provider = provider self.provider = provider
self.options = options self.options = options
self.optimizations = optimizations
def __str__(self) -> str: def __str__(self) -> str:
return "%s - %s (%s)" % (self.device, self.provider, self.options) return "%s - %s (%s)" % (self.device, self.provider, self.options)
@@ -91,7 +99,23 @@ class DeviceParams:
return (self.provider, self.options) return (self.provider, self.options)
def sess_options(self) -> SessionOptions: def sess_options(self) -> SessionOptions:
return SessionOptions() sess = SessionOptions()
if "onnx-low-memory" in self.optimizations:
logger.debug("enabling ONNX low-memory optimizations")
sess.enable_cpu_mem_arena = False
sess.enable_mem_pattern = False
sess.enable_mem_reuse = False
if "onnx-optimization-disable" in self.optimizations:
sess.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
elif "onnx-optimization-basic" in self.optimizations:
sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
elif "onnx-optimization-all" in self.optimizations:
sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
if "onnx-deterministic-compute" in self.optimizations:
sess.use_deterministic_compute = True
return sess
def torch_str(self) -> str: def torch_str(self) -> str:
if self.device.startswith("cuda"): if self.device.startswith("cuda"):

View File

@@ -349,16 +349,29 @@ def load_platforms(context: ServerContext) -> None:
{ {
"device_id": i, "device_id": i,
}, },
context.optimizations,
) )
) )
else: else:
available_platforms.append( available_platforms.append(
DeviceParams(potential, platform_providers[potential]) DeviceParams(
potential,
platform_providers[potential],
None,
context.optimizations,
)
) )
if context.any_platform: if context.any_platform:
# the platform should be ignored when the job is scheduled, but set to CPU just in case # the platform should be ignored when the job is scheduled, but set to CPU just in case
available_platforms.append(DeviceParams("any", platform_providers["cpu"])) available_platforms.append(
DeviceParams(
"any",
platform_providers["cpu"],
None,
context.optimizations,
)
)
# make sure CPU is last on the list # make sure CPU is last on the list
def any_first_cpu_last(a: DeviceParams, b: DeviceParams): def any_first_cpu_last(a: DeviceParams, b: DeviceParams):