fix(api): pass device ID in provider params
This commit is contained in:
parent
9d1f9412f6
commit
37dd8927bf
|
@ -45,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
|
||||||
|
|
||||||
# use ONNX acceleration, if available
|
# use ONNX acceleration, if available
|
||||||
if params.format == 'onnx':
|
if params.format == 'onnx':
|
||||||
model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
|
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
|
||||||
elif params.format == 'pth':
|
elif params.format == 'pth':
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||||
|
|
|
@ -44,10 +44,10 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: De
|
||||||
|
|
||||||
if upscale.format == 'onnx':
|
if upscale.format == 'onnx':
|
||||||
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
|
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
|
||||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
|
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
|
||||||
else:
|
else:
|
||||||
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
|
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
|
||||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
|
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)
|
||||||
|
|
||||||
last_pipeline_instance = pipeline
|
last_pipeline_instance = pipeline
|
||||||
last_pipeline_params = cache_params
|
last_pipeline_params = cache_params
|
||||||
|
|
|
@ -65,15 +65,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
||||||
scheduler = scheduler.from_pretrained(
|
scheduler = scheduler.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
sess_options=device.options,
|
provider_options=device.options,
|
||||||
subfolder='scheduler',
|
subfolder='scheduler',
|
||||||
)
|
)
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
|
provider_options=device.options,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
sess_options=device.options,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(pipe, 'to'):
|
if device is not None and hasattr(pipe, 'to'):
|
||||||
|
@ -88,7 +88,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
||||||
scheduler = scheduler.from_pretrained(
|
scheduler = scheduler.from_pretrained(
|
||||||
model,
|
model,
|
||||||
provider=device.provider,
|
provider=device.provider,
|
||||||
sess_options=device.options,
|
provider_options=device.options,
|
||||||
subfolder='scheduler',
|
subfolder='scheduler',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -48,14 +48,11 @@ class OnnxNet():
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
model: str,
|
model: str,
|
||||||
provider: str = 'DmlExecutionProvider',
|
provider: str = 'DmlExecutionProvider',
|
||||||
sess_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
|
||||||
TODO: get platform provider from request params
|
|
||||||
'''
|
|
||||||
model_path = path.join(server.model_path, model)
|
model_path = path.join(server.model_path, model)
|
||||||
self.session = InferenceSession(
|
self.session = InferenceSession(
|
||||||
model_path, providers=[provider], sess_options=sess_options)
|
model_path, providers=[provider], provider_options=provider_options)
|
||||||
|
|
||||||
def __call__(self, image: Any) -> Any:
|
def __call__(self, image: Any) -> Any:
|
||||||
input_name = self.session.get_inputs()[0].name
|
input_name = self.session.get_inputs()[0].name
|
||||||
|
|
Loading…
Reference in New Issue