feat(api): parse chain pipeline from request
This commit is contained in:
parent
c7a6ec45d8
commit
151ebff237
|
@ -27,16 +27,19 @@ def blend_img2img(
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
*,
|
*,
|
||||||
strength: float,
|
strength: float,
|
||||||
|
prompt: str = None,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('generating image using img2img', params.prompt)
|
logger.info('generating image using img2img', params.prompt)
|
||||||
|
|
||||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||||
params.model, params.provider, params.scheduler)
|
params.model, params.provider, params.scheduler)
|
||||||
|
|
||||||
|
prompt = prompt or params.prompt
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
|
||||||
result = pipe(
|
result = pipe(
|
||||||
params.prompt,
|
prompt,
|
||||||
generator=rng,
|
generator=rng,
|
||||||
guidance_scale=params.cfg,
|
guidance_scale=params.cfg,
|
||||||
image=source_image,
|
image=source_image,
|
||||||
|
|
|
@ -46,6 +46,7 @@ def blend_inpaint(
|
||||||
fill_color: str = 'white',
|
fill_color: str = 'white',
|
||||||
mask_filter: Callable = mask_filter_none,
|
mask_filter: Callable = mask_filter_none,
|
||||||
noise_source: Callable = noise_source_histogram,
|
noise_source: Callable = noise_source_histogram,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('upscaling image by expanding borders', expand)
|
logger.info('upscaling image by expanding borders', expand)
|
||||||
|
|
||||||
|
|
|
@ -60,6 +60,7 @@ def correct_gfpgan(
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
upsampler: Optional[RealESRGANer] = None,
|
upsampler: Optional[RealESRGANer] = None,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
if upscale.correction_model is None:
|
if upscale.correction_model is None:
|
||||||
logger.warn('no face model given, skipping')
|
logger.warn('no face model given, skipping')
|
||||||
|
|
|
@ -21,6 +21,7 @@ def persist_disk(
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
*,
|
*,
|
||||||
output: str,
|
output: str,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
dest = base_join(ctx.output_path, output)
|
dest = base_join(ctx.output_path, output)
|
||||||
source_image.save(dest)
|
source_image.save(dest)
|
||||||
|
|
|
@ -26,6 +26,7 @@ def persist_s3(
|
||||||
bucket: str,
|
bucket: str,
|
||||||
endpoint_url: str = None,
|
endpoint_url: str = None,
|
||||||
profile_name: str = None,
|
profile_name: str = None,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
session = Session(profile_name=profile_name)
|
session = Session(profile_name=profile_name)
|
||||||
s3 = session.client('s3', endpoint_url=endpoint_url)
|
s3 = session.client('s3', endpoint_url=endpoint_url)
|
||||||
|
|
|
@ -29,8 +29,11 @@ def source_txt2img(
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
*,
|
*,
|
||||||
size: Size,
|
size: Size,
|
||||||
|
prompt: str = None,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('generating image using txt2img, %s steps', params.steps)
|
prompt = prompt or params.prompt
|
||||||
|
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt)
|
||||||
|
|
||||||
if source_image is not None:
|
if source_image is not None:
|
||||||
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
|
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
|
||||||
|
@ -42,7 +45,7 @@ def source_txt2img(
|
||||||
rng = np.random.RandomState(params.seed)
|
rng = np.random.RandomState(params.seed)
|
||||||
|
|
||||||
result = pipe(
|
result = pipe(
|
||||||
params.prompt,
|
prompt,
|
||||||
height=size.height,
|
height=size.height,
|
||||||
width=size.width,
|
width=size.width,
|
||||||
generator=rng,
|
generator=rng,
|
||||||
|
|
|
@ -46,6 +46,7 @@ def upscale_outpaint(
|
||||||
fill_color: str = 'white',
|
fill_color: str = 'white',
|
||||||
mask_filter: Callable = mask_filter_none,
|
mask_filter: Callable = mask_filter_none,
|
||||||
noise_source: Callable = noise_source_histogram,
|
noise_source: Callable = noise_source_histogram,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('upscaling image by expanding borders: %s', expand)
|
logger.info('upscaling image by expanding borders: %s', expand)
|
||||||
|
|
||||||
|
|
|
@ -77,6 +77,7 @@ def upscale_resrgan(
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('upscaling image with Real ESRGAN', upscale.scale)
|
logger.info('upscaling image with Real ESRGAN', upscale.scale)
|
||||||
|
|
||||||
|
|
|
@ -67,6 +67,7 @@ def upscale_stable_diffusion(
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
logger.info('upscaling with Stable Diffusion')
|
logger.info('upscaling with Stable Diffusion')
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ from flask_cors import CORS
|
||||||
from flask_executor import Executor
|
from flask_executor import Executor
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from jsonschema import validate
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from onnxruntime import get_available_providers
|
from onnxruntime import get_available_providers
|
||||||
|
@ -26,10 +27,12 @@ from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
from .chain import (
|
from .chain import (
|
||||||
|
blend_img2img,
|
||||||
|
blend_inpaint,
|
||||||
correct_gfpgan,
|
correct_gfpgan,
|
||||||
source_txt2img,
|
|
||||||
persist_disk,
|
persist_disk,
|
||||||
persist_s3,
|
persist_s3,
|
||||||
|
source_txt2img,
|
||||||
upscale_outpaint,
|
upscale_outpaint,
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
|
@ -120,10 +123,15 @@ mask_filters = {
|
||||||
'gaussian-screen': mask_filter_gaussian_screen,
|
'gaussian-screen': mask_filter_gaussian_screen,
|
||||||
}
|
}
|
||||||
chain_stages = {
|
chain_stages = {
|
||||||
'correction-gfpgan': correct_gfpgan,
|
'blend-img2img': blend_img2img,
|
||||||
'upscaling-outpaint': upscale_outpaint,
|
'blend-inpaint': blend_inpaint,
|
||||||
'upscaling-resrgan': upscale_resrgan,
|
'correct-gfpgan': correct_gfpgan,
|
||||||
'upscaling-stable-diffusion': upscale_stable_diffusion,
|
'persist-disk': persist_disk,
|
||||||
|
'persist-s3': persist_s3,
|
||||||
|
'source-txt2img': source_txt2img,
|
||||||
|
'upscale-outpaint': upscale_outpaint,
|
||||||
|
'upscale-resrgan': upscale_resrgan,
|
||||||
|
'upscale-stable-diffusion': upscale_stable_diffusion,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Available ORT providers
|
# Available ORT providers
|
||||||
|
@ -547,37 +555,38 @@ def upscale():
|
||||||
|
|
||||||
@app.route('/api/chain', methods=['POST'])
|
@app.route('/api/chain', methods=['POST'])
|
||||||
def chain():
|
def chain():
|
||||||
|
data = request.json
|
||||||
|
|
||||||
|
with open('./schema.yaml', 'r') as f:
|
||||||
|
schema = yaml.safe_load(f.read())
|
||||||
|
|
||||||
|
logger.info('validating chain request: %s against %s', data, schema)
|
||||||
|
validate(data, schema)
|
||||||
|
|
||||||
|
# get defaults from the regular parameters
|
||||||
params, size = pipeline_from_request()
|
params, size = pipeline_from_request()
|
||||||
output = make_output_name('chain', params, size)
|
output = make_output_name('chain', params, size)
|
||||||
|
|
||||||
# parse body as json, list of stages
|
pipeline = ChainPipeline()
|
||||||
example = ChainPipeline(stages=[
|
for stage_data in data.get('stages', []):
|
||||||
(source_txt2img, StageParams(), {
|
logger.info('request stage: %s', stage_data)
|
||||||
'size': size,
|
|
||||||
}),
|
callback = chain_stages[stage_data.get('type')]
|
||||||
(upscale_outpaint, StageParams(), {
|
kwargs = stage_data.get('params', {})
|
||||||
'expand': Border.even(SizeChart.half),
|
|
||||||
}),
|
stage = StageParams(
|
||||||
(persist_disk, StageParams(tile_size=SizeChart.hd8k), {
|
stage_data.get('name', callback.__name__),
|
||||||
'output': output,
|
tile_size=int(kwargs.get('tile_size', SizeChart.auto)),
|
||||||
}),
|
outscale=int(kwargs.get('outscale', 1)),
|
||||||
(upscale_stable_diffusion, StageParams(tile_size=SizeChart.mini,outscale=4), {
|
)
|
||||||
'upscale': UpscaleParams('stable-diffusion-x4-upscaler', params.provider, scale=4, outscale=4)
|
# TODO: create Border from border
|
||||||
}),
|
# TODO: create Upscale from upscale
|
||||||
(persist_disk, StageParams(tile_size=SizeChart.hd8k), {
|
pipeline.append((callback, stage, kwargs))
|
||||||
'output': output,
|
|
||||||
}),
|
|
||||||
(persist_s3, StageParams(tile_size=SizeChart.hd8k), {
|
|
||||||
'bucket': 'storage-stable-diffusion',
|
|
||||||
'endpoint_url': 'http://scylla.home.holdmyran.ch:8000',
|
|
||||||
'output': output,
|
|
||||||
'profile_name': 'ceph',
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
|
|
||||||
# build and run chain pipeline
|
# build and run chain pipeline
|
||||||
executor.submit_stored(output, example, context,
|
fake_source = Image.new('RGB', (1, 1))
|
||||||
params, Image.new('RGB', (1, 1)))
|
executor.submit_stored(output, pipeline, context,
|
||||||
|
params, fake_source, output=output, size=size)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'output': output,
|
'output': output,
|
||||||
|
|
|
@ -18,4 +18,5 @@ boto3
|
||||||
flask
|
flask
|
||||||
flask-cors
|
flask-cors
|
||||||
flask_executor
|
flask_executor
|
||||||
|
jsonschema
|
||||||
pyyaml
|
pyyaml
|
|
@ -0,0 +1,69 @@
|
||||||
|
$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
|
||||||
|
$schema: https://json-schema.org/draft/2020-12/schema
|
||||||
|
|
||||||
|
$defs:
|
||||||
|
border_params:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
bottom:
|
||||||
|
type: number
|
||||||
|
left:
|
||||||
|
type: number
|
||||||
|
right:
|
||||||
|
type: number
|
||||||
|
top:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
image_params:
|
||||||
|
type: object
|
||||||
|
required: [prompt]
|
||||||
|
properties:
|
||||||
|
prompt:
|
||||||
|
type: string
|
||||||
|
|
||||||
|
upscale_params:
|
||||||
|
type: object
|
||||||
|
required: [outscale, scale]
|
||||||
|
properties:
|
||||||
|
outscale:
|
||||||
|
type: number
|
||||||
|
scale:
|
||||||
|
type: number
|
||||||
|
|
||||||
|
request_stage:
|
||||||
|
type: object
|
||||||
|
required: [name, type, params]
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
params:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
args:
|
||||||
|
type: object
|
||||||
|
additionalProperties: False
|
||||||
|
patternProperties:
|
||||||
|
"^[-_A-Za-z]+$":
|
||||||
|
oneOf:
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
border:
|
||||||
|
$ref: "#/$defs/border_params"
|
||||||
|
image:
|
||||||
|
$ref: "#/$defs/image_params"
|
||||||
|
upscale:
|
||||||
|
$ref: "#/$defs/upscale_params"
|
||||||
|
|
||||||
|
request_chain:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: "#/$defs/request_stage"
|
||||||
|
|
||||||
|
type: object
|
||||||
|
additionalProperties: False
|
||||||
|
required: [stages]
|
||||||
|
properties:
|
||||||
|
stages:
|
||||||
|
$ref: "#/$defs/request_chain"
|
|
@ -0,0 +1,43 @@
|
||||||
|
{
|
||||||
|
"stages": [
|
||||||
|
{
|
||||||
|
"name": "start",
|
||||||
|
"type": "source-txt2img",
|
||||||
|
"params": {
|
||||||
|
"prompt": "a magical wizard"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "refine",
|
||||||
|
"type": "blend-img2img",
|
||||||
|
"params": {
|
||||||
|
"prompt": "a magical wizard in a robe fighting a dragon"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "upscale",
|
||||||
|
"type": "upscale-stable-diffusion",
|
||||||
|
"params": {
|
||||||
|
"model": "stable-diffusion-x4-upscaler",
|
||||||
|
"prompt": "a magical wizard in a robe fighting a dragon",
|
||||||
|
"scale": 4,
|
||||||
|
"outscale": 4,
|
||||||
|
"tile_size": 128
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "save-local",
|
||||||
|
"type": "persist-disk",
|
||||||
|
"params": {}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "save-ceph",
|
||||||
|
"type": "persist-s3",
|
||||||
|
"params": {
|
||||||
|
"bucket": "storage-stable-diffusion",
|
||||||
|
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
|
||||||
|
"profile_name": "ceph"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
Loading…
Reference in New Issue