1
0
Fork 0

fix(api): pass both device and session options to ORT (#38)

This commit is contained in:
Sean Sube 2023-02-14 18:57:50 -06:00
parent feb4603171
commit 8a2a9174ba
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 25 additions and 12 deletions

View File

@ -52,8 +52,8 @@ def load_resrgan(
model = OnnxNet( model = OnnxNet(
server, server,
model_file, model_file,
provider=device.provider, provider=device.ort_provider(),
sess_options=device.options, sess_options=device.sess_options(),
) )
elif params.format == "pth": elif params.format == "pth":
model = RRDBNet( model = RRDBNet(

View File

@ -34,7 +34,9 @@ def load_stable_diffusion(
device.provider, device.provider,
) )
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained( pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider, sess_options=device.options model_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
) )
else: else:
logger.debug( logger.debug(
@ -43,7 +45,8 @@ def load_stable_diffusion(
device.provider, device.provider,
) )
pipe = StableDiffusionUpscalePipeline.from_pretrained( pipe = StableDiffusionUpscalePipeline.from_pretrained(
model_path, provider=device.provider model_path,
provider=device.provider,
) )
server.cache.set("diffusion", cache_key, pipe) server.cache.set("diffusion", cache_key, pipe)

View File

@ -109,8 +109,8 @@ def load_pipeline(
logger.debug("loading new diffusion scheduler") logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained( scheduler = scheduler_type.from_pretrained(
model, model,
provider=device.provider, provider=device.ort_provider(),
sess_options=device.options, sess_options=device.sess_options(),
subfolder="scheduler", subfolder="scheduler",
) )
@ -134,15 +134,15 @@ def load_pipeline(
logger.debug("loading new diffusion pipeline from %s", model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler_type.from_pretrained( scheduler = scheduler_type.from_pretrained(
model, model,
provider=device.provider, provider=device.ort_provider(),
sess_options=device.options, sess_options=device.sess_options(),
subfolder="scheduler", subfolder="scheduler",
) )
pipe = pipeline.from_pretrained( pipe = pipeline.from_pretrained(
model, model,
custom_pipeline=custom_pipeline, custom_pipeline=custom_pipeline,
provider=device.provider, provider=device.ort_provider(),
sess_options=device.options, sess_options=device.sess_options(),
revision="onnx", revision="onnx",
safety_checker=None, safety_checker=None,
scheduler=scheduler, scheduler=scheduler,

View File

@ -3,7 +3,7 @@ from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
from onnxruntime import InferenceSession from onnxruntime import InferenceSession, SessionOptions
from ..utils import ServerContext from ..utils import ServerContext
@ -47,7 +47,7 @@ class OnnxNet:
server: ServerContext, server: ServerContext,
model: str, model: str,
provider: str = "DmlExecutionProvider", provider: str = "DmlExecutionProvider",
sess_options: Optional[dict] = None, sess_options: Optional[SessionOptions] = None,
) -> None: ) -> None:
model_path = path.join(server.model_path, model) model_path = path.join(server.model_path, model)
self.session = InferenceSession( self.session = InferenceSession(

View File

@ -1,4 +1,5 @@
from enum import IntEnum from enum import IntEnum
from onnxruntime import SessionOptions
from typing import Any, Dict, Literal, Optional, Tuple, Union from typing import Any, Dict, Literal, Optional, Tuple, Union
@ -79,6 +80,15 @@ class DeviceParams:
def __str__(self) -> str: def __str__(self) -> str:
return "%s - %s (%s)" % (self.device, self.provider, self.options) return "%s - %s (%s)" % (self.device, self.provider, self.options)
def ort_provider(self) -> Tuple[str, Any]:
if self.options is None:
return self.provider
else:
return (self.provider, self.options)
def sess_options(self) -> SessionOptions:
return SessionOptions()
def torch_device(self) -> str: def torch_device(self) -> str:
if self.device.startswith("cuda"): if self.device.startswith("cuda"):
return self.device return self.device