1
0
Fork 0

lint(api): add seed to output filename, add types

This commit is contained in:
Sean Sube 2023-01-09 23:26:47 -06:00
parent 63758b0e21
commit 778cf6e7d1
2 changed files with 25 additions and 13 deletions

View File

@ -16,6 +16,8 @@ from diffusers import (
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipeline,
# types
DiffusionPipeline,
) )
from flask import Flask, jsonify, request, send_from_directory, url_for from flask import Flask, jsonify, request, send_from_directory, url_for
from hashlib import sha256 from hashlib import sha256
@ -79,15 +81,15 @@ pipeline_schedulers = {
} }
def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0): def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value) return min(max(float(args.get(key, default_value)), min_value), max_value)
def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1): def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1) -> int:
return min(max(int(args.get(key, default_value)), min_value), max_value) return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_map(args, key: str, values, default): def get_from_map(args, key: str, values: dict[str, Any], default: Any):
selected = args.get(key, default) selected = args.get(key, default)
if selected in values: if selected in values:
return values[selected] return values[selected]
@ -109,7 +111,7 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray:
return image_latents return image_latents
def load_pipeline(pipeline, model: str, provider: str, scheduler): def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_scheduler global last_pipeline_scheduler
global last_pipeline_options global last_pipeline_options
@ -146,7 +148,7 @@ def json_with_cors(data, origin='*'):
return res return res
def make_output_path(type: str, params: Tuple[Union[str, int, float]]): def make_output_path(type: str, seed: int, params: Tuple[Union[str, int, float]]):
sha = sha256() sha = sha256()
sha.update(type.encode('utf-8')) sha.update(type.encode('utf-8'))
for param in params: for param in params:
@ -159,7 +161,7 @@ def make_output_path(type: str, params: Tuple[Union[str, int, float]]):
else: else:
print('cannot hash param: %s, %s' % (param, type(param))) print('cannot hash param: %s, %s' % (param, type(param)))
output_file = '%s_%s.png' % (type, sha.hexdigest()) output_file = '%s_%s_%s.png' % (type, seed, sha.hexdigest())
output_full = safer_join(output_path, output_file) output_full = safer_join(output_path, output_file)
return (output_file, output_full) return (output_file, output_full)
@ -238,7 +240,7 @@ def list_schedulers():
return json_with_cors(list(pipeline_schedulers.keys())) return json_with_cors(list(pipeline_schedulers.keys()))
def pipeline_from_request(pipeline): def pipeline_from_request(pipeline: DiffusionPipeline):
user = request.remote_addr user = request.remote_addr
# pipeline stuff # pipeline stuff
@ -250,12 +252,13 @@ def pipeline_from_request(pipeline):
# image params # image params
prompt = request.args.get('prompt', default_prompt) prompt = request.args.get('prompt', default_prompt)
negative_prompt = request.args.get('negative', None); negative_prompt = request.args.get('negative', None)
cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, max_cfg, 0) cfg = get_and_clamp_int(request.args, 'cfg', default_cfg, config_params.get('cfg').get('max'), 0)
steps = get_and_clamp_int(request.args, 'steps', default_steps, max_steps) steps = get_and_clamp_int(request.args, 'steps', default_steps, config_params.get('steps').get('max'))
height = get_and_clamp_int(request.args, 'height', default_height, max_height) height = get_and_clamp_int(
width = get_and_clamp_int(request.args, 'width', default_width, max_width) request.args, 'height', default_height, config_params.get('height').get('max'))
width = get_and_clamp_int(request.args, 'width', default_width, config_params.get('width').get('max'))
seed = int(request.args.get('seed', -1)) seed = int(request.args.get('seed', -1))
if seed == -1: if seed == -1:
@ -405,5 +408,5 @@ def inpaint():
@app.route('/output/<path:filename>') @app.route('/output/<path:filename>')
def output(filename): def output(filename: str):
return send_from_directory(path.join('..', output_path), filename, as_attachment=False) return send_from_directory(path.join('..', output_path), filename, as_attachment=False)

View File

@ -13,13 +13,20 @@
], ],
"settings": { "settings": {
"cSpell.words": [ "cSpell.words": [
"astype",
"CUDA",
"ddim", "ddim",
"ddpm", "ddpm",
"directml", "directml",
"ftfy", "ftfy",
"Heun",
"huggingface", "huggingface",
"Inpaint", "Inpaint",
"jsonify",
"Karras",
"KDPM",
"Multistep", "Multistep",
"ndarray",
"numpy", "numpy",
"Onnx", "Onnx",
"onnxruntime", "onnxruntime",
@ -27,7 +34,9 @@
"pretrained", "pretrained",
"protobuf", "protobuf",
"runwayml", "runwayml",
"scandir",
"scipy", "scipy",
"Singlestep",
"spacy", "spacy",
"spinalcase", "spinalcase",
"stringcase", "stringcase",