1
0
Fork 0

feat(api): parse chain pipeline from request

This commit is contained in:
Sean Sube 2023-01-28 22:31:34 -06:00
parent c7a6ec45d8
commit 151ebff237
13 changed files with 169 additions and 34 deletions

View File

@ -27,16 +27,19 @@ def blend_img2img(
source_image: Image.Image, source_image: Image.Image,
*, *,
strength: float, strength: float,
prompt: str = None,
**kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('generating image using img2img', params.prompt) logger.info('generating image using img2img', params.prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler) params.model, params.provider, params.scheduler)
prompt = prompt or params.prompt
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
params.prompt, prompt,
generator=rng, generator=rng,
guidance_scale=params.cfg, guidance_scale=params.cfg,
image=source_image, image=source_image,

View File

@ -46,6 +46,7 @@ def blend_inpaint(
fill_color: str = 'white', fill_color: str = 'white',
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
**kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('upscaling image by expanding borders', expand) logger.info('upscaling image by expanding borders', expand)

View File

@ -60,6 +60,7 @@ def correct_gfpgan(
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
upsampler: Optional[RealESRGANer] = None, upsampler: Optional[RealESRGANer] = None,
**kwargs,
) -> Image.Image: ) -> Image.Image:
if upscale.correction_model is None: if upscale.correction_model is None:
logger.warn('no face model given, skipping') logger.warn('no face model given, skipping')

View File

@ -21,6 +21,7 @@ def persist_disk(
source_image: Image.Image, source_image: Image.Image,
*, *,
output: str, output: str,
**kwargs,
) -> Image.Image: ) -> Image.Image:
dest = base_join(ctx.output_path, output) dest = base_join(ctx.output_path, output)
source_image.save(dest) source_image.save(dest)

View File

@ -26,6 +26,7 @@ def persist_s3(
bucket: str, bucket: str,
endpoint_url: str = None, endpoint_url: str = None,
profile_name: str = None, profile_name: str = None,
**kwargs,
) -> Image.Image: ) -> Image.Image:
session = Session(profile_name=profile_name) session = Session(profile_name=profile_name)
s3 = session.client('s3', endpoint_url=endpoint_url) s3 = session.client('s3', endpoint_url=endpoint_url)

View File

@ -29,8 +29,11 @@ def source_txt2img(
source_image: Image.Image, source_image: Image.Image,
*, *,
size: Size, size: Size,
prompt: str = None,
**kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('generating image using txt2img, %s steps', params.steps) prompt = prompt or params.prompt
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt)
if source_image is not None: if source_image is not None:
logger.warn('a source image was passed to a txt2img stage, but will be discarded') logger.warn('a source image was passed to a txt2img stage, but will be discarded')
@ -42,7 +45,7 @@ def source_txt2img(
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
result = pipe( result = pipe(
params.prompt, prompt,
height=size.height, height=size.height,
width=size.width, width=size.width,
generator=rng, generator=rng,

View File

@ -46,6 +46,7 @@ def upscale_outpaint(
fill_color: str = 'white', fill_color: str = 'white',
mask_filter: Callable = mask_filter_none, mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram, noise_source: Callable = noise_source_histogram,
**kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('upscaling image by expanding borders: %s', expand) logger.info('upscaling image by expanding borders: %s', expand)

View File

@ -77,6 +77,7 @@ def upscale_resrgan(
source_image: Image.Image, source_image: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('upscaling image with Real ESRGAN', upscale.scale) logger.info('upscaling image with Real ESRGAN', upscale.scale)

View File

@ -67,6 +67,7 @@ def upscale_stable_diffusion(
source: Image.Image, source: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
**kwargs,
) -> Image.Image: ) -> Image.Image:
logger.info('upscaling with Stable Diffusion') logger.info('upscaling with Stable Diffusion')

View File

@ -18,6 +18,7 @@ from flask_cors import CORS
from flask_executor import Executor from flask_executor import Executor
from glob import glob from glob import glob
from io import BytesIO from io import BytesIO
from jsonschema import validate
from logging import getLogger from logging import getLogger
from PIL import Image from PIL import Image
from onnxruntime import get_available_providers from onnxruntime import get_available_providers
@ -26,10 +27,12 @@ from typing import Tuple
from .chain import ( from .chain import (
blend_img2img,
blend_inpaint,
correct_gfpgan, correct_gfpgan,
source_txt2img,
persist_disk, persist_disk,
persist_s3, persist_s3,
source_txt2img,
upscale_outpaint, upscale_outpaint,
upscale_resrgan, upscale_resrgan,
upscale_stable_diffusion, upscale_stable_diffusion,
@ -120,10 +123,15 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen, 'gaussian-screen': mask_filter_gaussian_screen,
} }
chain_stages = { chain_stages = {
'correction-gfpgan': correct_gfpgan, 'blend-img2img': blend_img2img,
'upscaling-outpaint': upscale_outpaint, 'blend-inpaint': blend_inpaint,
'upscaling-resrgan': upscale_resrgan, 'correct-gfpgan': correct_gfpgan,
'upscaling-stable-diffusion': upscale_stable_diffusion, 'persist-disk': persist_disk,
'persist-s3': persist_s3,
'source-txt2img': source_txt2img,
'upscale-outpaint': upscale_outpaint,
'upscale-resrgan': upscale_resrgan,
'upscale-stable-diffusion': upscale_stable_diffusion,
} }
# Available ORT providers # Available ORT providers
@ -547,37 +555,38 @@ def upscale():
@app.route('/api/chain', methods=['POST']) @app.route('/api/chain', methods=['POST'])
def chain(): def chain():
data = request.json
with open('./schema.yaml', 'r') as f:
schema = yaml.safe_load(f.read())
logger.info('validating chain request: %s against %s', data, schema)
validate(data, schema)
# get defaults from the regular parameters
params, size = pipeline_from_request() params, size = pipeline_from_request()
output = make_output_name('chain', params, size) output = make_output_name('chain', params, size)
# parse body as json, list of stages pipeline = ChainPipeline()
example = ChainPipeline(stages=[ for stage_data in data.get('stages', []):
(source_txt2img, StageParams(), { logger.info('request stage: %s', stage_data)
'size': size,
}), callback = chain_stages[stage_data.get('type')]
(upscale_outpaint, StageParams(), { kwargs = stage_data.get('params', {})
'expand': Border.even(SizeChart.half),
}), stage = StageParams(
(persist_disk, StageParams(tile_size=SizeChart.hd8k), { stage_data.get('name', callback.__name__),
'output': output, tile_size=int(kwargs.get('tile_size', SizeChart.auto)),
}), outscale=int(kwargs.get('outscale', 1)),
(upscale_stable_diffusion, StageParams(tile_size=SizeChart.mini,outscale=4), { )
'upscale': UpscaleParams('stable-diffusion-x4-upscaler', params.provider, scale=4, outscale=4) # TODO: create Border from border
}), # TODO: create Upscale from upscale
(persist_disk, StageParams(tile_size=SizeChart.hd8k), { pipeline.append((callback, stage, kwargs))
'output': output,
}),
(persist_s3, StageParams(tile_size=SizeChart.hd8k), {
'bucket': 'storage-stable-diffusion',
'endpoint_url': 'http://scylla.home.holdmyran.ch:8000',
'output': output,
'profile_name': 'ceph',
}),
])
# build and run chain pipeline # build and run chain pipeline
executor.submit_stored(output, example, context, fake_source = Image.new('RGB', (1, 1))
params, Image.new('RGB', (1, 1))) executor.submit_stored(output, pipeline, context,
params, fake_source, output=output, size=size)
return jsonify({ return jsonify({
'output': output, 'output': output,

View File

@ -18,4 +18,5 @@ boto3
flask flask
flask-cors flask-cors
flask_executor flask_executor
jsonschema
pyyaml pyyaml

69
api/schema.yaml Normal file
View File

@ -0,0 +1,69 @@
$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
$schema: https://json-schema.org/draft/2020-12/schema
$defs:
border_params:
type: object
properties:
bottom:
type: number
left:
type: number
right:
type: number
top:
type: number
image_params:
type: object
required: [prompt]
properties:
prompt:
type: string
upscale_params:
type: object
required: [outscale, scale]
properties:
outscale:
type: number
scale:
type: number
request_stage:
type: object
required: [name, type, params]
properties:
name:
type: string
type:
type: string
params:
type: object
properties:
args:
type: object
additionalProperties: False
patternProperties:
"^[-_A-Za-z]+$":
oneOf:
- type: number
- type: string
border:
$ref: "#/$defs/border_params"
image:
$ref: "#/$defs/image_params"
upscale:
$ref: "#/$defs/upscale_params"
request_chain:
type: array
items:
$ref: "#/$defs/request_stage"
type: object
additionalProperties: False
required: [stages]
properties:
stages:
$ref: "#/$defs/request_chain"

View File

@ -0,0 +1,43 @@
{
"stages": [
{
"name": "start",
"type": "source-txt2img",
"params": {
"prompt": "a magical wizard"
}
},
{
"name": "refine",
"type": "blend-img2img",
"params": {
"prompt": "a magical wizard in a robe fighting a dragon"
}
},
{
"name": "upscale",
"type": "upscale-stable-diffusion",
"params": {
"model": "stable-diffusion-x4-upscaler",
"prompt": "a magical wizard in a robe fighting a dragon",
"scale": 4,
"outscale": 4,
"tile_size": 128
}
},
{
"name": "save-local",
"type": "persist-disk",
"params": {}
},
{
"name": "save-ceph",
"type": "persist-s3",
"params": {
"bucket": "storage-stable-diffusion",
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
"profile_name": "ceph"
}
}
]
}