feat(api): enable ONNX optimizations through env
This commit is contained in:
parent
0d2211ff25
commit
5b4c370a1b
|
@ -1,7 +1,10 @@
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Dict, Literal, Optional, Tuple, Union
|
from logging import getLogger
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from onnxruntime import SessionOptions
|
from onnxruntime import GraphOptimizationLevel, SessionOptions
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
|
@ -75,11 +78,16 @@ class Size:
|
||||||
|
|
||||||
class DeviceParams:
|
class DeviceParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: str, provider: str, options: Optional[dict] = None
|
self,
|
||||||
|
device: str,
|
||||||
|
provider: str,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
optimizations: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.options = options
|
self.options = options
|
||||||
|
self.optimizations = optimizations
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
return "%s - %s (%s)" % (self.device, self.provider, self.options)
|
||||||
|
@ -91,7 +99,23 @@ class DeviceParams:
|
||||||
return (self.provider, self.options)
|
return (self.provider, self.options)
|
||||||
|
|
||||||
def sess_options(self) -> SessionOptions:
    """
    Build the ONNX runtime session options for this device.

    Starts from a fresh ``SessionOptions`` and applies the flags found in
    ``self.optimizations``:

    - ``onnx-low-memory``: set ``enable_cpu_mem_arena``,
      ``enable_mem_pattern`` and ``enable_mem_reuse`` to ``False``.
    - ``onnx-optimization-disable`` / ``onnx-optimization-basic`` /
      ``onnx-optimization-all``: set the graph optimization level; if
      several are present, the first match in that order is applied.
    - ``onnx-deterministic-compute``: set ``use_deterministic_compute``.

    Returns:
        The configured ``SessionOptions``.
    """
    # optimizations may be None (the constructor's default); treat that as
    # "no flags" instead of raising TypeError on the membership tests
    optimizations = self.optimizations or ()
    sess = SessionOptions()

    if "onnx-low-memory" in optimizations:
        logger.debug("enabling ONNX low-memory optimizations")
        sess.enable_cpu_mem_arena = False
        sess.enable_mem_pattern = False
        sess.enable_mem_reuse = False

    if "onnx-optimization-disable" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
    elif "onnx-optimization-basic" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
    elif "onnx-optimization-all" in optimizations:
        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    if "onnx-deterministic-compute" in optimizations:
        sess.use_deterministic_compute = True

    # previously the method fell off the end and returned None, so callers
    # never received the options configured above
    return sess
|
||||||
|
|
||||||
def torch_str(self) -> str:
|
def torch_str(self) -> str:
|
||||||
if self.device.startswith("cuda"):
|
if self.device.startswith("cuda"):
|
||||||
|
|
|
@ -349,16 +349,29 @@ def load_platforms(context: ServerContext) -> None:
|
||||||
{
|
{
|
||||||
"device_id": i,
|
"device_id": i,
|
||||||
},
|
},
|
||||||
|
context.optimizations,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
available_platforms.append(
|
available_platforms.append(
|
||||||
DeviceParams(potential, platform_providers[potential])
|
DeviceParams(
|
||||||
|
potential,
|
||||||
|
platform_providers[potential],
|
||||||
|
None,
|
||||||
|
context.optimizations,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if context.any_platform:
|
if context.any_platform:
|
||||||
# the platform should be ignored when the job is scheduled, but set to CPU just in case
|
# the platform should be ignored when the job is scheduled, but set to CPU just in case
|
||||||
available_platforms.append(DeviceParams("any", platform_providers["cpu"]))
|
available_platforms.append(
|
||||||
|
DeviceParams(
|
||||||
|
"any",
|
||||||
|
platform_providers["cpu"],
|
||||||
|
None,
|
||||||
|
context.optimizations,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# make sure CPU is last on the list
|
# make sure CPU is last on the list
|
||||||
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
|
def any_first_cpu_last(a: DeviceParams, b: DeviceParams):
|
||||||
|
|
Loading…
Reference in New Issue