feat(api): parse chain pipeline from request
This commit is contained in:
parent
c7a6ec45d8
commit
151ebff237
|
@ -27,16 +27,19 @@ def blend_img2img(
|
|||
source_image: Image.Image,
|
||||
*,
|
||||
strength: float,
|
||||
prompt: str = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('generating image using img2img', params.prompt)
|
||||
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
|
||||
prompt = prompt or params.prompt
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
prompt,
|
||||
generator=rng,
|
||||
guidance_scale=params.cfg,
|
||||
image=source_image,
|
||||
|
|
|
@ -46,6 +46,7 @@ def blend_inpaint(
|
|||
fill_color: str = 'white',
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('upscaling image by expanding borders', expand)
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ def correct_gfpgan(
|
|||
*,
|
||||
upscale: UpscaleParams,
|
||||
upsampler: Optional[RealESRGANer] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
if upscale.correction_model is None:
|
||||
logger.warn('no face model given, skipping')
|
||||
|
|
|
@ -21,6 +21,7 @@ def persist_disk(
|
|||
source_image: Image.Image,
|
||||
*,
|
||||
output: str,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
dest = base_join(ctx.output_path, output)
|
||||
source_image.save(dest)
|
||||
|
|
|
@ -26,6 +26,7 @@ def persist_s3(
|
|||
bucket: str,
|
||||
endpoint_url: str = None,
|
||||
profile_name: str = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
session = Session(profile_name=profile_name)
|
||||
s3 = session.client('s3', endpoint_url=endpoint_url)
|
||||
|
|
|
@ -29,8 +29,11 @@ def source_txt2img(
|
|||
source_image: Image.Image,
|
||||
*,
|
||||
size: Size,
|
||||
prompt: str = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('generating image using txt2img, %s steps', params.steps)
|
||||
prompt = prompt or params.prompt
|
||||
logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt)
|
||||
|
||||
if source_image is not None:
|
||||
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
|
||||
|
@ -42,7 +45,7 @@ def source_txt2img(
|
|||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
prompt,
|
||||
height=size.height,
|
||||
width=size.width,
|
||||
generator=rng,
|
||||
|
|
|
@ -46,6 +46,7 @@ def upscale_outpaint(
|
|||
fill_color: str = 'white',
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('upscaling image by expanding borders: %s', expand)
|
||||
|
||||
|
|
|
@ -77,6 +77,7 @@ def upscale_resrgan(
|
|||
source_image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('upscaling image with Real ESRGAN', upscale.scale)
|
||||
|
||||
|
|
|
@ -67,6 +67,7 @@ def upscale_stable_diffusion(
|
|||
source: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info('upscaling with Stable Diffusion')
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from flask_cors import CORS
|
|||
from flask_executor import Executor
|
||||
from glob import glob
|
||||
from io import BytesIO
|
||||
from jsonschema import validate
|
||||
from logging import getLogger
|
||||
from PIL import Image
|
||||
from onnxruntime import get_available_providers
|
||||
|
@ -26,10 +27,12 @@ from typing import Tuple
|
|||
|
||||
|
||||
from .chain import (
|
||||
blend_img2img,
|
||||
blend_inpaint,
|
||||
correct_gfpgan,
|
||||
source_txt2img,
|
||||
persist_disk,
|
||||
persist_s3,
|
||||
source_txt2img,
|
||||
upscale_outpaint,
|
||||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
|
@ -120,10 +123,15 @@ mask_filters = {
|
|||
'gaussian-screen': mask_filter_gaussian_screen,
|
||||
}
|
||||
chain_stages = {
|
||||
'correction-gfpgan': correct_gfpgan,
|
||||
'upscaling-outpaint': upscale_outpaint,
|
||||
'upscaling-resrgan': upscale_resrgan,
|
||||
'upscaling-stable-diffusion': upscale_stable_diffusion,
|
||||
'blend-img2img': blend_img2img,
|
||||
'blend-inpaint': blend_inpaint,
|
||||
'correct-gfpgan': correct_gfpgan,
|
||||
'persist-disk': persist_disk,
|
||||
'persist-s3': persist_s3,
|
||||
'source-txt2img': source_txt2img,
|
||||
'upscale-outpaint': upscale_outpaint,
|
||||
'upscale-resrgan': upscale_resrgan,
|
||||
'upscale-stable-diffusion': upscale_stable_diffusion,
|
||||
}
|
||||
|
||||
# Available ORT providers
|
||||
|
@ -547,37 +555,38 @@ def upscale():
|
|||
|
||||
@app.route('/api/chain', methods=['POST'])
|
||||
def chain():
|
||||
data = request.json
|
||||
|
||||
with open('./schema.yaml', 'r') as f:
|
||||
schema = yaml.safe_load(f.read())
|
||||
|
||||
logger.info('validating chain request: %s against %s', data, schema)
|
||||
validate(data, schema)
|
||||
|
||||
# get defaults from the regular parameters
|
||||
params, size = pipeline_from_request()
|
||||
output = make_output_name('chain', params, size)
|
||||
|
||||
# parse body as json, list of stages
|
||||
example = ChainPipeline(stages=[
|
||||
(source_txt2img, StageParams(), {
|
||||
'size': size,
|
||||
}),
|
||||
(upscale_outpaint, StageParams(), {
|
||||
'expand': Border.even(SizeChart.half),
|
||||
}),
|
||||
(persist_disk, StageParams(tile_size=SizeChart.hd8k), {
|
||||
'output': output,
|
||||
}),
|
||||
(upscale_stable_diffusion, StageParams(tile_size=SizeChart.mini,outscale=4), {
|
||||
'upscale': UpscaleParams('stable-diffusion-x4-upscaler', params.provider, scale=4, outscale=4)
|
||||
}),
|
||||
(persist_disk, StageParams(tile_size=SizeChart.hd8k), {
|
||||
'output': output,
|
||||
}),
|
||||
(persist_s3, StageParams(tile_size=SizeChart.hd8k), {
|
||||
'bucket': 'storage-stable-diffusion',
|
||||
'endpoint_url': 'http://scylla.home.holdmyran.ch:8000',
|
||||
'output': output,
|
||||
'profile_name': 'ceph',
|
||||
}),
|
||||
])
|
||||
pipeline = ChainPipeline()
|
||||
for stage_data in data.get('stages', []):
|
||||
logger.info('request stage: %s', stage_data)
|
||||
|
||||
callback = chain_stages[stage_data.get('type')]
|
||||
kwargs = stage_data.get('params', {})
|
||||
|
||||
stage = StageParams(
|
||||
stage_data.get('name', callback.__name__),
|
||||
tile_size=int(kwargs.get('tile_size', SizeChart.auto)),
|
||||
outscale=int(kwargs.get('outscale', 1)),
|
||||
)
|
||||
# TODO: create Border from border
|
||||
# TODO: create Upscale from upscale
|
||||
pipeline.append((callback, stage, kwargs))
|
||||
|
||||
# build and run chain pipeline
|
||||
executor.submit_stored(output, example, context,
|
||||
params, Image.new('RGB', (1, 1)))
|
||||
fake_source = Image.new('RGB', (1, 1))
|
||||
executor.submit_stored(output, pipeline, context,
|
||||
params, fake_source, output=output, size=size)
|
||||
|
||||
return jsonify({
|
||||
'output': output,
|
||||
|
|
|
@ -18,4 +18,5 @@ boto3
|
|||
flask
|
||||
flask-cors
|
||||
flask_executor
|
||||
jsonschema
|
||||
pyyaml
|
|
@ -0,0 +1,69 @@
|
|||
$id: https://github.com/ssube/onnx-web/blob/main/api/schema.yaml
|
||||
$schema: https://json-schema.org/draft/2020-12/schema
|
||||
|
||||
$defs:
|
||||
border_params:
|
||||
type: object
|
||||
properties:
|
||||
bottom:
|
||||
type: number
|
||||
left:
|
||||
type: number
|
||||
right:
|
||||
type: number
|
||||
top:
|
||||
type: number
|
||||
|
||||
image_params:
|
||||
type: object
|
||||
required: [prompt]
|
||||
properties:
|
||||
prompt:
|
||||
type: string
|
||||
|
||||
upscale_params:
|
||||
type: object
|
||||
required: [outscale, scale]
|
||||
properties:
|
||||
outscale:
|
||||
type: number
|
||||
scale:
|
||||
type: number
|
||||
|
||||
request_stage:
|
||||
type: object
|
||||
required: [name, type, params]
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
type:
|
||||
type: string
|
||||
params:
|
||||
type: object
|
||||
properties:
|
||||
args:
|
||||
type: object
|
||||
additionalProperties: False
|
||||
patternProperties:
|
||||
"^[-_A-Za-z]+$":
|
||||
oneOf:
|
||||
- type: number
|
||||
- type: string
|
||||
border:
|
||||
$ref: "#/$defs/border_params"
|
||||
image:
|
||||
$ref: "#/$defs/image_params"
|
||||
upscale:
|
||||
$ref: "#/$defs/upscale_params"
|
||||
|
||||
request_chain:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/$defs/request_stage"
|
||||
|
||||
type: object
|
||||
additionalProperties: False
|
||||
required: [stages]
|
||||
properties:
|
||||
stages:
|
||||
$ref: "#/$defs/request_chain"
|
|
@ -0,0 +1,43 @@
|
|||
{
|
||||
"stages": [
|
||||
{
|
||||
"name": "start",
|
||||
"type": "source-txt2img",
|
||||
"params": {
|
||||
"prompt": "a magical wizard"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "refine",
|
||||
"type": "blend-img2img",
|
||||
"params": {
|
||||
"prompt": "a magical wizard in a robe fighting a dragon"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "upscale",
|
||||
"type": "upscale-stable-diffusion",
|
||||
"params": {
|
||||
"model": "stable-diffusion-x4-upscaler",
|
||||
"prompt": "a magical wizard in a robe fighting a dragon",
|
||||
"scale": 4,
|
||||
"outscale": 4,
|
||||
"tile_size": 128
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "save-local",
|
||||
"type": "persist-disk",
|
||||
"params": {}
|
||||
},
|
||||
{
|
||||
"name": "save-ceph",
|
||||
"type": "persist-s3",
|
||||
"params": {
|
||||
"bucket": "storage-stable-diffusion",
|
||||
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
|
||||
"profile_name": "ceph"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
Loading…
Reference in New Issue