1
0
Fork 0

fix(api): pass device ID in provider params

This commit is contained in:
Sean Sube 2023-02-04 21:52:45 -06:00
parent 9d1f9412f6
commit 37dd8927bf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 8 additions and 11 deletions

View File

@ -45,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
# use ONNX acceleration, if available
if params.format == 'onnx':
model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)

View File

@ -44,10 +44,10 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: De
if upscale.format == 'onnx':
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
else:
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)
last_pipeline_instance = pipeline
last_pipeline_params = cache_params

View File

@ -65,15 +65,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
sess_options=device.options,
provider_options=device.options,
subfolder='scheduler',
)
pipe = pipeline.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
safety_checker=None,
scheduler=scheduler,
sess_options=device.options,
)
if device is not None and hasattr(pipe, 'to'):
@ -88,7 +88,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
sess_options=device.options,
provider_options=device.options,
subfolder='scheduler',
)

View File

@ -48,14 +48,11 @@ class OnnxNet():
server: ServerContext,
model: str,
provider: str = 'DmlExecutionProvider',
sess_options: Optional[dict] = None,
provider_options: Optional[dict] = None,
) -> None:
'''
TODO: get platform provider from request params
'''
model_path = path.join(server.model_path, model)
self.session = InferenceSession(
model_path, providers=[provider], sess_options=sess_options)
model_path, providers=[provider], provider_options=provider_options)
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name