fix(api): get default params from file, enforce minimum params
This commit is contained in:
parent
c09eb75ab4
commit
e8b580a5de
|
@ -19,7 +19,7 @@ from diffusers import (
|
||||||
# types
|
# types
|
||||||
DiffusionPipeline,
|
DiffusionPipeline,
|
||||||
)
|
)
|
||||||
from flask import Flask, jsonify, request, send_file, send_from_directory, url_for
|
from flask import Flask, jsonify, request, send_from_directory, url_for
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from flask_executor import Executor
|
from flask_executor import Executor
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
@ -32,22 +32,6 @@ from typing import Any, Dict, Tuple, Union
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# defaults
|
|
||||||
default_model = 'stable-diffusion-onnx-v1-5'
|
|
||||||
default_platform = 'amd'
|
|
||||||
default_scheduler = 'euler-a'
|
|
||||||
default_prompt = "a photo of an astronaut eating a hamburger"
|
|
||||||
default_cfg = 8
|
|
||||||
default_steps = 20
|
|
||||||
default_height = 512
|
|
||||||
default_width = 512
|
|
||||||
default_strength = 0.5
|
|
||||||
|
|
||||||
max_cfg = 30
|
|
||||||
max_steps = 150
|
|
||||||
max_height = 512
|
|
||||||
max_width = 512
|
|
||||||
|
|
||||||
# paths
|
# paths
|
||||||
bundle_path = environ.get('ONNX_WEB_BUNDLE_PATH',
|
bundle_path = environ.get('ONNX_WEB_BUNDLE_PATH',
|
||||||
path.join('..', 'gui', 'out'))
|
path.join('..', 'gui', 'out'))
|
||||||
|
@ -190,27 +174,41 @@ def pipeline_from_request():
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
model = get_model_path(request.args.get('model', default_model))
|
model = get_model_path(request.args.get(
|
||||||
|
'model', config_params.get('model').get('default')))
|
||||||
provider = get_from_map(request.args, 'platform',
|
provider = get_from_map(request.args, 'platform',
|
||||||
platform_providers, default_platform)
|
platform_providers, config_params.get('provider').get('default'))
|
||||||
scheduler = get_from_map(request.args, 'scheduler',
|
scheduler = get_from_map(request.args, 'scheduler',
|
||||||
pipeline_schedulers, default_scheduler)
|
pipeline_schedulers, config_params.get('scheduler').get('default'))
|
||||||
|
|
||||||
# image params
|
# image params
|
||||||
prompt = request.args.get('prompt', default_prompt)
|
prompt = request.args.get(
|
||||||
|
'prompt', config_params.get('prompt').get('default'))
|
||||||
negative_prompt = request.args.get('negativePrompt', None)
|
negative_prompt = request.args.get('negativePrompt', None)
|
||||||
|
|
||||||
if negative_prompt is not None and negative_prompt.strip() == '':
|
if negative_prompt is not None and negative_prompt.strip() == '':
|
||||||
negative_prompt = None
|
negative_prompt = None
|
||||||
|
|
||||||
cfg = get_and_clamp_float(
|
cfg = get_and_clamp_float(
|
||||||
request.args, 'cfg', default_cfg, config_params.get('cfg').get('max'), 0)
|
request.args, 'cfg',
|
||||||
|
config_params.get('cfg').get('default'),
|
||||||
|
config_params.get('cfg').get('max'),
|
||||||
|
config_params.get('cfg').get('min'))
|
||||||
steps = get_and_clamp_int(
|
steps = get_and_clamp_int(
|
||||||
request.args, 'steps', default_steps, config_params.get('steps').get('max'))
|
request.args, 'steps',
|
||||||
|
config_params.get('steps').get('default'),
|
||||||
|
config_params.get('steps').get('max'),
|
||||||
|
config_params.get('steps').get('min'))
|
||||||
height = get_and_clamp_int(
|
height = get_and_clamp_int(
|
||||||
request.args, 'height', default_height, config_params.get('height').get('max'))
|
request.args, 'height',
|
||||||
|
config_params.get('height').get('default'),
|
||||||
|
config_params.get('height').get('max'),
|
||||||
|
config_params.get('height').get('min'))
|
||||||
width = get_and_clamp_int(
|
width = get_and_clamp_int(
|
||||||
request.args, 'width', default_width, config_params.get('width').get('max'))
|
request.args, 'width',
|
||||||
|
config_params.get('width').get('default'),
|
||||||
|
config_params.get('width').get('max'),
|
||||||
|
config_params.get('width').get('min'))
|
||||||
|
|
||||||
seed = int(request.args.get('seed', -1))
|
seed = int(request.args.get('seed', -1))
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
|
@ -369,7 +367,6 @@ def list_schedulers():
|
||||||
def img2img():
|
def img2img():
|
||||||
input_file = request.files.get('source')
|
input_file = request.files.get('source')
|
||||||
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
|
input_image = Image.open(BytesIO(input_file.read())).convert('RGB')
|
||||||
input_image.thumbnail((default_width, default_height))
|
|
||||||
|
|
||||||
strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)
|
strength = get_and_clamp_float(request.args, 'strength', 0.5, 1.0)
|
||||||
|
|
||||||
|
@ -380,6 +377,7 @@ def img2img():
|
||||||
(prompt, cfg, negative_prompt, steps, strength, height, width))
|
(prompt, cfg, negative_prompt, steps, strength, height, width))
|
||||||
print("img2img output: %s" % (output_full))
|
print("img2img output: %s" % (output_full))
|
||||||
|
|
||||||
|
input_image.thumbnail((width, height))
|
||||||
executor.submit_stored(output_file, run_img2img_pipeline, model, provider,
|
executor.submit_stored(output_file, run_img2img_pipeline, model, provider,
|
||||||
scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, strength, input_image)
|
scheduler, prompt, negative_prompt, cfg, steps, seed, output_full, strength, input_image)
|
||||||
|
|
||||||
|
@ -394,8 +392,8 @@ def img2img():
|
||||||
'cfg': cfg,
|
'cfg': cfg,
|
||||||
'negativePrompt': negative_prompt,
|
'negativePrompt': negative_prompt,
|
||||||
'steps': steps,
|
'steps': steps,
|
||||||
'height': default_height,
|
'height': height,
|
||||||
'width': default_width,
|
'width': width,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -433,11 +431,9 @@ def txt2img():
|
||||||
def inpaint():
|
def inpaint():
|
||||||
source_file = request.files.get('source')
|
source_file = request.files.get('source')
|
||||||
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
|
||||||
source_image.thumbnail((default_width, default_height))
|
|
||||||
|
|
||||||
mask_file = request.files.get('mask')
|
mask_file = request.files.get('mask')
|
||||||
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
|
||||||
mask_image.thumbnail((default_width, default_height))
|
|
||||||
|
|
||||||
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
|
(model, provider, scheduler, prompt, negative_prompt, cfg, steps, height,
|
||||||
width, seed) = pipeline_from_request()
|
width, seed) = pipeline_from_request()
|
||||||
|
@ -446,6 +442,8 @@ def inpaint():
|
||||||
'inpaint', seed, (prompt, cfg, steps, height, width, seed))
|
'inpaint', seed, (prompt, cfg, steps, height, width, seed))
|
||||||
print("inpaint output: %s" % output_full)
|
print("inpaint output: %s" % output_full)
|
||||||
|
|
||||||
|
source_image.thumbnail((width, height))
|
||||||
|
mask_image.thumbnail((width, height))
|
||||||
executor.submit_stored(output_file, run_inpaint_pipeline, model, provider, scheduler, prompt, negative_prompt,
|
executor.submit_stored(output_file, run_inpaint_pipeline, model, provider, scheduler, prompt, negative_prompt,
|
||||||
cfg, steps, seed, output_full, height, width, source_image, mask_image)
|
cfg, steps, seed, output_full, height, width, source_image, mask_image)
|
||||||
|
|
||||||
|
@ -460,8 +458,8 @@ def inpaint():
|
||||||
'cfg': cfg,
|
'cfg': cfg,
|
||||||
'negativePrompt': negative_prompt,
|
'negativePrompt': negative_prompt,
|
||||||
'steps': steps,
|
'steps': steps,
|
||||||
'height': default_height,
|
'height': height,
|
||||||
'width': default_width,
|
'width': width,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue