fix(api): pass device ID in provider params
This commit is contained in:
parent
9d1f9412f6
commit
37dd8927bf
|
@ -45,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
|
|||
|
||||
# use ONNX acceleration, if available
|
||||
if params.format == 'onnx':
|
||||
model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
|
||||
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
|
||||
elif params.format == 'pth':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||
|
|
|
@ -44,10 +44,10 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: De
|
|||
|
||||
if upscale.format == 'onnx':
|
||||
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
|
||||
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
|
||||
else:
|
||||
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)
|
||||
|
||||
last_pipeline_instance = pipeline
|
||||
last_pipeline_params = cache_params
|
||||
|
|
|
@ -65,15 +65,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
|||
scheduler = scheduler.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
sess_options=device.options,
|
||||
provider_options=device.options,
|
||||
subfolder='scheduler',
|
||||
)
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
provider_options=device.options,
|
||||
safety_checker=None,
|
||||
scheduler=scheduler,
|
||||
sess_options=device.options,
|
||||
)
|
||||
|
||||
if device is not None and hasattr(pipe, 'to'):
|
||||
|
@ -88,7 +88,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
|
|||
scheduler = scheduler.from_pretrained(
|
||||
model,
|
||||
provider=device.provider,
|
||||
sess_options=device.options,
|
||||
provider_options=device.options,
|
||||
subfolder='scheduler',
|
||||
)
|
||||
|
||||
|
|
|
@ -48,14 +48,11 @@ class OnnxNet():
|
|||
server: ServerContext,
|
||||
model: str,
|
||||
provider: str = 'DmlExecutionProvider',
|
||||
sess_options: Optional[dict] = None,
|
||||
provider_options: Optional[dict] = None,
|
||||
) -> None:
|
||||
'''
|
||||
TODO: get platform provider from request params
|
||||
'''
|
||||
model_path = path.join(server.model_path, model)
|
||||
self.session = InferenceSession(
|
||||
model_path, providers=[provider], sess_options=sess_options)
|
||||
model_path, providers=[provider], provider_options=provider_options)
|
||||
|
||||
def __call__(self, image: Any) -> Any:
|
||||
input_name = self.session.get_inputs()[0].name
|
||||
|
|
Loading…
Reference in New Issue