diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 2048f54f..e5904f08 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -320,9 +320,11 @@ def load_platforms(): 'device_id': i, })) else: - available_platforms.append(DeviceParams(potential, platform_providers[potential])) + available_platforms.append(DeviceParams( + potential, platform_providers[potential])) - logger.info('available acceleration platforms: %s', ', '.join([str(p) for p in available_platforms])) + logger.info('available acceleration platforms: %s', + ', '.join([str(p) for p in available_platforms])) context = ServerContext.from_environ() @@ -603,6 +605,26 @@ def chain(): upscale = UpscaleParams(kwargs.get('upscale'), params.provider) kwargs['upscale'] = upscale + stage_source_name = 'source:%s' % (stage.name) + stage_mask_name = 'mask:%s' % (stage.name) + + if stage_source_name in request.files: + logger.debug('loading source image %s for pipeline stage %s', + stage_source_name, stage.name) + source_file = request.files.get('source') + source_image = Image.open( + BytesIO(source_file.read())).convert('RGB') + source_image = source_image.thumbnail((512, 512)) + kwargs['source_image'] = source_image + + if stage_mask_name in request.files: + logger.debug('loading mask image %s for pipeline stage %s', + stage_mask_name, stage.name) + mask_file = request.files.get('source') + mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB') + mask_image = mask_image.thumbnail((512, 512)) + kwargs['mask_image'] = mask_image + pipeline.append((callback, stage, kwargs)) logger.info('running chain pipeline with %s stages', len(pipeline.stages))