1
0
Fork 0

feat(api): parse chain pipeline from request

This commit is contained in:
Sean Sube 2023-01-28 22:31:34 -06:00
parent c7a6ec45d8
commit 151ebff237
13 changed files with 169 additions and 34 deletions

View File

@ -27,16 +27,19 @@ def blend_img2img(
source_image: Image.Image,
*,
strength: float,
prompt: str = None,
**kwargs,
) -> Image.Image:
logger.info('generating image using img2img', params.prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
prompt = prompt or params.prompt
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
prompt,
generator=rng,
guidance_scale=params.cfg,
image=source_image,

View File

@ -46,6 +46,7 @@ def blend_inpaint(
fill_color: str = 'white',
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
**kwargs,
) -> Image.Image:
logger.info('upscaling image by expanding borders', expand)

View File

@ -60,6 +60,7 @@ def correct_gfpgan(
*,
upscale: UpscaleParams,
upsampler: Optional[RealESRGANer] = None,
**kwargs,
) -> Image.Image:
if upscale.correction_model is None:
logger.warn('no face model given, skipping')

View File

@ -21,6 +21,7 @@ def persist_disk(
source_image: Image.Image,
*,
output: str,
**kwargs,
) -> Image.Image:
dest = base_join(ctx.output_path, output)
source_image.save(dest)

View File

@ -26,6 +26,7 @@ def persist_s3(
bucket: str,
endpoint_url: str = None,
profile_name: str = None,
**kwargs,
) -> Image.Image:
session = Session(profile_name=profile_name)
s3 = session.client('s3', endpoint_url=endpoint_url)

View File

@ -29,8 +29,11 @@ def source_txt2img(
source_image: Image.Image,
*,
size: Size,
prompt: str = None,
**kwargs,
) -> Image.Image:
logger.info('generating image using txt2img, %s steps', params.steps)
prompt = prompt or params.prompt
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt)
if source_image is not None:
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
@ -42,7 +45,7 @@ def source_txt2img(
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
prompt,
height=size.height,
width=size.width,
generator=rng,

View File

@ -46,6 +46,7 @@ def upscale_outpaint(
fill_color: str = 'white',
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
**kwargs,
) -> Image.Image:
logger.info('upscaling image by expanding borders: %s', expand)

View File

@ -77,6 +77,7 @@ def upscale_resrgan(
source_image: Image.Image,
*,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
logger.info('upscaling image with Real ESRGAN', upscale.scale)

View File

@ -67,6 +67,7 @@ def upscale_stable_diffusion(
source: Image.Image,
*,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
logger.info('upscaling with Stable Diffusion')

View File

@ -18,6 +18,7 @@ from flask_cors import CORS
from flask_executor import Executor
from glob import glob
from io import BytesIO
from jsonschema import validate
from logging import getLogger
from PIL import Image
from onnxruntime import get_available_providers
@ -26,10 +27,12 @@ from typing import Tuple
from .chain import (
blend_img2img,
blend_inpaint,
correct_gfpgan,
source_txt2img,
persist_disk,
persist_s3,
source_txt2img,
upscale_outpaint,
upscale_resrgan,
upscale_stable_diffusion,
@ -120,10 +123,15 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen,
}
chain_stages = {
'correction-gfpgan': correct_gfpgan,
'upscaling-outpaint': upscale_outpaint,
'upscaling-resrgan': upscale_resrgan,
'upscaling-stable-diffusion': upscale_stable_diffusion,
'blend-img2img': blend_img2img,
'blend-inpaint': blend_inpaint,
'correct-gfpgan': correct_gfpgan,
'persist-disk': persist_disk,
'persist-s3': persist_s3,
'source-txt2img': source_txt2img,
'upscale-outpaint': upscale_outpaint,
'upscale-resrgan': upscale_resrgan,
'upscale-stable-diffusion': upscale_stable_diffusion,
}
# Available ORT providers
@ -547,37 +555,38 @@ def upscale():
@app.route('/api/chain', methods=['POST'])
def chain():
data = request.json
with open('./schema.yaml', 'r') as f:
schema = yaml.safe_load(f.read())
logger.info('validating chain request: %s against %s', data, schema)
validate(data, schema)
# get defaults from the regular parameters
params, size = pipeline_from_request()
output = make_output_name('chain', params, size)
# parse body as json, list of stages
example = ChainPipeline(stages=[
(source_txt2img, StageParams(), {
'size': size,
}),
(upscale_outpaint, StageParams(), {
'expand': Border.even(SizeChart.half),
}),
(persist_disk, StageParams(tile_size=SizeChart.hd8k), {
'output': output,
}),
(upscale_stable_diffusion, StageParams(tile_size=SizeChart.mini,outscale=4), {
'upscale': UpscaleParams('stable-diffusion-x4-upscaler', params.provider, scale=4, outscale=4)
}),
(persist_disk, StageParams(tile_size=SizeChart.hd8k), {
'output': output,
}),
(persist_s3, StageParams(tile_size=SizeChart.hd8k), {
'bucket': 'storage-stable-diffusion',
'endpoint_url': 'http://scylla.home.holdmyran.ch:8000',
'output': output,
'profile_name': 'ceph',
}),
])
pipeline = ChainPipeline()
for stage_data in data.get('stages', []):
logger.info('request stage: %s', stage_data)
callback = chain_stages[stage_data.get('type')]
kwargs = stage_data.get('params', {})
stage = StageParams(
stage_data.get('name', callback.__name__),
tile_size=int(kwargs.get('tile_size', SizeChart.auto)),
outscale=int(kwargs.get('outscale', 1)),
)
# TODO: create Border from border
# TODO: create Upscale from upscale
pipeline.append((callback, stage, kwargs))
# build and run chain pipeline
executor.submit_stored(output, example, context,
params, Image.new('RGB', (1, 1)))
fake_source = Image.new('RGB', (1, 1))
executor.submit_stored(output, pipeline, context,
params, fake_source, output=output, size=size)
return jsonify({
'output': output,

View File

@ -18,4 +18,5 @@ boto3
flask
flask-cors
flask_executor
jsonschema
pyyaml

69
api/schema.yaml Normal file
View File

@ -0,0 +1,69 @@
$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
$schema: https://json-schema.org/draft/2020-12/schema
$defs:
border_params:
type: object
properties:
bottom:
type: number
left:
type: number
right:
type: number
top:
type: number
image_params:
type: object
required: [prompt]
properties:
prompt:
type: string
upscale_params:
type: object
required: [outscale, scale]
properties:
outscale:
type: number
scale:
type: number
request_stage:
type: object
required: [name, type, params]
properties:
name:
type: string
type:
type: string
params:
type: object
properties:
args:
type: object
additionalProperties: False
patternProperties:
"^[-_A-Za-z]+$":
oneOf:
- type: number
- type: string
border:
$ref: "#/$defs/border_params"
image:
$ref: "#/$defs/image_params"
upscale:
$ref: "#/$defs/upscale_params"
request_chain:
type: array
items:
$ref: "#/$defs/request_stage"
type: object
additionalProperties: False
required: [stages]
properties:
stages:
$ref: "#/$defs/request_chain"

View File

@ -0,0 +1,43 @@
{
"stages": [
{
"name": "start",
"type": "source-txt2img",
"params": {
"prompt": "a magical wizard"
}
},
{
"name": "refine",
"type": "blend-img2img",
"params": {
"prompt": "a magical wizard in a robe fighting a dragon"
}
},
{
"name": "upscale",
"type": "upscale-stable-diffusion",
"params": {
"model": "stable-diffusion-x4-upscaler",
"prompt": "a magical wizard in a robe fighting a dragon",
"scale": 4,
"outscale": 4,
"tile_size": 128
}
},
{
"name": "save-local",
"type": "persist-disk",
"params": {}
},
{
"name": "save-ceph",
"type": "persist-s3",
"params": {
"bucket": "storage-stable-diffusion",
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
"profile_name": "ceph"
}
}
]
}