feat(api): load source and mask images for chain pipelines (#88)
This commit is contained in:
parent
4e5ad54471
commit
1de2a51db5
|
@ -320,9 +320,11 @@ def load_platforms():
|
|||
'device_id': i,
|
||||
}))
|
||||
else:
|
||||
available_platforms.append(DeviceParams(potential, platform_providers[potential]))
|
||||
available_platforms.append(DeviceParams(
|
||||
potential, platform_providers[potential]))
|
||||
|
||||
logger.info('available acceleration platforms: %s', ', '.join([str(p) for p in available_platforms]))
|
||||
logger.info('available acceleration platforms: %s',
|
||||
', '.join([str(p) for p in available_platforms]))
|
||||
|
||||
|
||||
context = ServerContext.from_environ()
|
||||
|
@ -603,6 +605,26 @@ def chain():
|
|||
upscale = UpscaleParams(kwargs.get('upscale'), params.provider)
|
||||
kwargs['upscale'] = upscale
|
||||
|
||||
stage_source_name = 'source:%s' % (stage.name)
|
||||
stage_mask_name = 'mask:%s' % (stage.name)
|
||||
|
||||
if stage_source_name in request.files:
|
||||
logger.debug('loading source image %s for pipeline stage %s',
|
||||
stage_source_name, stage.name)
|
||||
source_file = request.files.get('source')
|
||||
source_image = Image.open(
|
||||
BytesIO(source_file.read())).convert('RGB')
|
||||
source_image = source_image.thumbnail((512, 512))
|
||||
kwargs['source_image'] = source_image
|
||||
|
||||
if stage_mask_name in request.files:
|
||||
logger.debug('loading mask image %s for pipeline stage %s',
|
||||
stage_mask_name, stage.name)
|
||||
mask_file = request.files.get('source')
|
||||
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
||||
mask_image = mask_image.thumbnail((512, 512))
|
||||
kwargs['mask_image'] = mask_image
|
||||
|
||||
pipeline.append((callback, stage, kwargs))
|
||||
|
||||
logger.info('running chain pipeline with %s stages', len(pipeline.stages))
|
||||
|
|
Loading…
Reference in New Issue