
fix(api): pass device ID in provider params

Sean Sube 2023-02-04 21:52:45 -06:00
parent 9d1f9412f6
commit 37dd8927bf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 8 additions and 11 deletions
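
Note: this change moves per-device settings (such as the device ID) out of sess_options, which ONNX Runtime reserves for a SessionOptions object, and into provider_options, which carries execution-provider configuration. A minimal sketch of the underlying onnxruntime call, separate from this commit (the provider name and device_id value are illustrative):

    # provider_options is a list of dicts, one per entry in providers
    from onnxruntime import InferenceSession

    session = InferenceSession(
        'model.onnx',
        providers=['DmlExecutionProvider'],
        provider_options=[{'device_id': 0}],
    )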

View File

@@ -45,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams
     # use ONNX acceleration, if available
     if params.format == 'onnx':
-        model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
+        model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
     elif params.format == 'pth':
         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                         num_block=23, num_grow_ch=32, scale=params.scale)

View File

@@ -44,10 +44,10 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: De
     if upscale.format == 'onnx':
         logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
+        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options)
     else:
         logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
+        pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider)
     last_pipeline_instance = pipeline
     last_pipeline_params = cache_params

View File

@@ -65,15 +65,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
    scheduler = scheduler.from_pretrained(
        model,
        provider=device.provider,
-       sess_options=device.options,
+       provider_options=device.options,
        subfolder='scheduler',
    )
    pipe = pipeline.from_pretrained(
        model,
        provider=device.provider,
+       provider_options=device.options,
        safety_checker=None,
        scheduler=scheduler,
-       sess_options=device.options,
    )
    if device is not None and hasattr(pipe, 'to'):
@@ -88,7 +88,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic
    scheduler = scheduler.from_pretrained(
        model,
        provider=device.provider,
-       sess_options=device.options,
+       provider_options=device.options,
        subfolder='scheduler',
    )

View File

@@ -48,14 +48,11 @@ class OnnxNet():
        server: ServerContext,
        model: str,
        provider: str = 'DmlExecutionProvider',
-       sess_options: Optional[dict] = None,
+       provider_options: Optional[dict] = None,
    ) -> None:
-       '''
-       TODO: get platform provider from request params
-       '''
        model_path = path.join(server.model_path, model)
        self.session = InferenceSession(
-           model_path, providers=[provider], sess_options=sess_options)
+           model_path, providers=[provider], provider_options=provider_options)

    def __call__(self, image: Any) -> Any:
        input_name = self.session.get_inputs()[0].name
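
For context, a hypothetical usage sketch of the updated OnnxNet wrapper above; the ServerContext setup, model file name, and option values here are assumptions, not part of this commit:

    # device selection now travels with the provider options
    server = ServerContext(model_path='/models')
    net = OnnxNet(
        server,
        'resrgan.onnx',
        provider='DmlExecutionProvider',
        provider_options={'device_id': 1},
    )
    output = net(image)  # image: input array for the ONNX model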