1
0
Fork 0

feat(api): load source and mask images for chain pipelines (#88)

This commit is contained in:
Sean Sube 2023-02-04 14:08:43 -06:00
parent 4e5ad54471
commit 1de2a51db5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 24 additions and 2 deletions

View File

@ -320,9 +320,11 @@ def load_platforms():
'device_id': i,
}))
else:
available_platforms.append(DeviceParams(potential, platform_providers[potential]))
available_platforms.append(DeviceParams(
potential, platform_providers[potential]))
logger.info('available acceleration platforms: %s', ', '.join([str(p) for p in available_platforms]))
logger.info('available acceleration platforms: %s',
', '.join([str(p) for p in available_platforms]))
context = ServerContext.from_environ()
@ -603,6 +605,26 @@ def chain():
upscale = UpscaleParams(kwargs.get('upscale'), params.provider)
kwargs['upscale'] = upscale
stage_source_name = 'source:%s' % (stage.name)
stage_mask_name = 'mask:%s' % (stage.name)
if stage_source_name in request.files:
logger.debug('loading source image %s for pipeline stage %s',
stage_source_name, stage.name)
source_file = request.files.get('source')
source_image = Image.open(
BytesIO(source_file.read())).convert('RGB')
source_image = source_image.thumbnail((512, 512))
kwargs['source_image'] = source_image
if stage_mask_name in request.files:
logger.debug('loading mask image %s for pipeline stage %s',
stage_mask_name, stage.name)
mask_file = request.files.get('source')
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
mask_image = mask_image.thumbnail((512, 512))
kwargs['mask_image'] = mask_image
pipeline.append((callback, stage, kwargs))
logger.info('running chain pipeline with %s stages', len(pipeline.stages))