lint(api): add seed to output filename, add types
This commit is contained in:
parent
63758b0e21
commit
778cf6e7d1
|
@ -16,6 +16,8 @@ from diffusers import (
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
OnnxStableDiffusionImg2ImgPipeline,
|
OnnxStableDiffusionImg2ImgPipeline,
|
||||||
OnnxStableDiffusionInpaintPipeline,
|
OnnxStableDiffusionInpaintPipeline,
|
||||||
|
# types
|
||||||
|
DiffusionPipeline,
|
||||||
)
|
)
|
||||||
from flask import Flask, jsonify, request, send_from_directory, url_for
|
from flask import Flask, jsonify, request, send_from_directory, url_for
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
@ -79,15 +81,15 @@ pipeline_schedulers = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0):
|
def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
|
||||||
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
return min(max(float(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1):
|
def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1) -> int:
|
||||||
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
return min(max(int(args.get(key, default_value)), min_value), max_value)
|
||||||
|
|
||||||
|
|
||||||
def get_from_map(args, key: str, values, default):
|
def get_from_map(args, key: str, values: dict[str, Any], default: Any):
|
||||||
selected = args.get(key, default)
|
selected = args.get(key, default)
|
||||||
if selected in values:
|
if selected in values:
|
||||||
return values[selected]
|
return values[selected]
|
||||||
|
@ -109,7 +111,7 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
|
||||||
return image_latents
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
def load_pipeline(pipeline, model: str, provider: str, scheduler):
|
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_scheduler
|
global last_pipeline_scheduler
|
||||||
global last_pipeline_options
|
global last_pipeline_options
|
||||||
|
@ -146,7 +148,7 @@ def json_with_cors(data, origin='*'):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def make_output_path(type: str, params: Tuple[Union[str, int, float]]):
|
def make_output_path(type: str, seed: int, params: Tuple[Union[str, int, float]]):
|
||||||
sha = sha256()
|
sha = sha256()
|
||||||
sha.update(type.encode('utf-8'))
|
sha.update(type.encode('utf-8'))
|
||||||
for param in params:
|
for param in params:
|
||||||
|
@ -159,7 +161,7 @@ def make_output_path(type: str, params: Tuple[Union[str, int, float]]):
|
||||||
else:
|
else:
|
||||||
print('cannot hash param: %s, %s' % (param, type(param)))
|
print('cannot hash param: %s, %s' % (param, type(param)))
|
||||||
|
|
||||||
output_file = '%s_%s.png' % (type, sha.hexdigest())
|
output_file = '%s_%s_%s.png' % (type, seed, sha.hexdigest())
|
||||||
output_full = safer_join(output_path, output_file)
|
output_full = safer_join(output_path, output_file)
|
||||||
|
|
||||||
return (output_file, output_full)
|
return (output_file, output_full)
|
||||||
|
@ -238,7 +240,7 @@ def list_schedulers():
|
||||||
return json_with_cors(list(pipeline_schedulers.keys()))
|
return json_with_cors(list(pipeline_schedulers.keys()))
|
||||||
|
|
||||||
|
|
||||||
def pipeline_from_request(pipeline):
|
def pipeline_from_request(pipeline: DiffusionPipeline):
|
||||||
user = request.remote_addr
|
user = request.remote_addr
|
||||||
|
|
||||||
# pipeline stuff
|
# pipeline stuff
|
||||||
|
@ -250,12 +252,13 @@ def pipeline_from_request(pipeline):
|
||||||
|
|
||||||
# image params
|
# image params
|
||||||
prompt = request.args.get('prompt', default_prompt)
|
prompt = request.args.get('prompt', default_prompt)
|
||||||
negative_prompt = request.args.get('negative', None);
|
negative_prompt = request.args.get('negative', None)
|
||||||
|
|
||||||
cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0)
|
cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, config_params.get('cfg').get('max'), 0)
|
||||||
steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps)
|
steps = get_and_clamp_int(request.args, 'steps', default_steps, config_params.get('steps').get('max'))
|
||||||
height = get_and_clamp_int(request.args, 'height', default_height, max_height)
|
height = get_and_clamp_int(
|
||||||
width = get_and_clamp_int(request.args, 'width', default_width, max_width)
|
request.args, 'height', default_height, config_params.get('height').get('max'))
|
||||||
|
width = get_and_clamp_int(request.args, 'width', default_width, config_params.get('width').get('max'))
|
||||||
|
|
||||||
seed = int(request.args.get('seed', -1))
|
seed = int(request.args.get('seed', -1))
|
||||||
if seed == -1:
|
if seed == -1:
|
||||||
|
@ -405,5 +408,5 @@ def inpaint():
|
||||||
|
|
||||||
|
|
||||||
@app.route('/output/<path:filename>')
|
@app.route('/output/<path:filename>')
|
||||||
def output(filename):
|
def output(filename: str):
|
||||||
return send_from_directory(path.join('..', output_path), filename, as_attachment=False)
|
return send_from_directory(path.join('..', output_path), filename, as_attachment=False)
|
||||||
|
|
|
@ -13,13 +13,20 @@
|
||||||
],
|
],
|
||||||
"settings": {
|
"settings": {
|
||||||
"cSpell.words": [
|
"cSpell.words": [
|
||||||
|
"astype",
|
||||||
|
"CUDA",
|
||||||
"ddim",
|
"ddim",
|
||||||
"ddpm",
|
"ddpm",
|
||||||
"directml",
|
"directml",
|
||||||
"ftfy",
|
"ftfy",
|
||||||
|
"Heun",
|
||||||
"huggingface",
|
"huggingface",
|
||||||
"Inpaint",
|
"Inpaint",
|
||||||
|
"jsonify",
|
||||||
|
"Karras",
|
||||||
|
"KDPM",
|
||||||
"Multistep",
|
"Multistep",
|
||||||
|
"ndarray",
|
||||||
"numpy",
|
"numpy",
|
||||||
"Onnx",
|
"Onnx",
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
|
@ -27,7 +34,9 @@
|
||||||
"pretrained",
|
"pretrained",
|
||||||
"protobuf",
|
"protobuf",
|
||||||
"runwayml",
|
"runwayml",
|
||||||
|
"scandir",
|
||||||
"scipy",
|
"scipy",
|
||||||
|
"Singlestep",
|
||||||
"spacy",
|
"spacy",
|
||||||
"spinalcase",
|
"spinalcase",
|
||||||
"stringcase",
|
"stringcase",
|
||||||
|
|
Loading…
Reference in New Issue