feat(api): load source and mask images for chain pipelines (#88)
This commit is contained in:
parent
4e5ad54471
commit
1de2a51db5
|
@ -320,9 +320,11 @@ def load_platforms():
|
||||||
'device_id': i,
|
'device_id': i,
|
||||||
}))
|
}))
|
||||||
else:
|
else:
|
||||||
available_platforms.append(DeviceParams(potential, platform_providers[potential]))
|
available_platforms.append(DeviceParams(
|
||||||
|
potential, platform_providers[potential]))
|
||||||
|
|
||||||
logger.info('available acceleration platforms: %s', ', '.join([str(p) for p in available_platforms]))
|
logger.info('available acceleration platforms: %s',
|
||||||
|
', '.join([str(p) for p in available_platforms]))
|
||||||
|
|
||||||
|
|
||||||
context = ServerContext.from_environ()
|
context = ServerContext.from_environ()
|
||||||
|
@ -603,6 +605,26 @@ def chain():
|
||||||
upscale = UpscaleParams(kwargs.get('upscale'), params.provider)
|
upscale = UpscaleParams(kwargs.get('upscale'), params.provider)
|
||||||
kwargs['upscale'] = upscale
|
kwargs['upscale'] = upscale
|
||||||
|
|
||||||
|
stage_source_name = 'source:%s' % (stage.name)
|
||||||
|
stage_mask_name = 'mask:%s' % (stage.name)
|
||||||
|
|
||||||
|
if stage_source_name in request.files:
|
||||||
|
logger.debug('loading source image %s for pipeline stage %s',
|
||||||
|
stage_source_name, stage.name)
|
||||||
|
source_file = request.files.get('source')
|
||||||
|
source_image = Image.open(
|
||||||
|
BytesIO(source_file.read())).convert('RGB')
|
||||||
|
source_image = source_image.thumbnail((512, 512))
|
||||||
|
kwargs['source_image'] = source_image
|
||||||
|
|
||||||
|
if stage_mask_name in request.files:
|
||||||
|
logger.debug('loading mask image %s for pipeline stage %s',
|
||||||
|
stage_mask_name, stage.name)
|
||||||
|
mask_file = request.files.get('source')
|
||||||
|
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
||||||
|
mask_image = mask_image.thumbnail((512, 512))
|
||||||
|
kwargs['mask_image'] = mask_image
|
||||||
|
|
||||||
pipeline.append((callback, stage, kwargs))
|
pipeline.append((callback, stage, kwargs))
|
||||||
|
|
||||||
logger.info('running chain pipeline with %s stages', len(pipeline.stages))
|
logger.info('running chain pipeline with %s stages', len(pipeline.stages))
|
||||||
|
|
Loading…
Reference in New Issue