1
0
Fork 0

lint(api): clean up output file hash stuff

This commit is contained in:
Sean Sube 2023-01-15 19:47:57 -06:00
parent e43238b327
commit bdc8e277fe
3 changed files with 110 additions and 71 deletions

View File

@ -1,10 +1,33 @@
from .image import (
expand_image,
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
)
from .pipeline import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
)
from .upscale import (
gfpgan_url,
make_resrgan,
resrgan_url,
)
from .utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_map,
safer_join,
BaseParams,
Border,
OutputPath,
Point,
Size,
)

View File

@ -16,12 +16,10 @@ from diffusers import (
from flask import Flask, jsonify, request, send_from_directory, url_for
from flask_cors import CORS
from flask_executor import Executor
from hashlib import sha256
from io import BytesIO
from PIL import Image
from struct import pack
from os import environ, makedirs, path, scandir
from typing import Tuple, Union
from typing import Tuple
from .image import (
# mask filters
@ -47,16 +45,14 @@ from .utils import (
get_and_clamp_float,
get_and_clamp_int,
get_from_map,
safer_join,
make_output_path,
BaseParams,
Border,
OutputPath,
Size,
)
import json
import numpy as np
import time
# paths
bundle_path = environ.get('ONNX_WEB_BUNDLE_PATH',
@ -112,44 +108,6 @@ def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', bundle_path), filename)
def hash_value(sha, param):
if param is None:
return
elif isinstance(param, float):
sha.update(bytearray(pack('!f', param)))
elif isinstance(param, int):
sha.update(bytearray(pack('!I', param)))
elif isinstance(param, str):
sha.update(param.encode('utf-8'))
else:
print('cannot hash param: %s, %s' % (param, type(param)))
def make_output_path(mode: str, params: BaseParams, size: Size, extras: Tuple[Union[str, int, float]]) -> OutputPath:
now = int(time.time())
sha = sha256()
hash_value(mode)
hash_value(params.model)
hash_value(params.provider)
hash_value(params.scheduler)
hash_value(params.prompt)
hash_value(params.negative_prompt)
hash_value(params.cfg)
hash_value(params.steps)
hash_value(params.seed)
hash_value(size.width)
hash_value(size.height)
for param in extras:
hash_value(sha, param)
output_file = '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
output_full = safer_join(output_path, output_file)
return OutputPath(output_full, output_file)
def url_from_rule(rule) -> str:
options = {}
for arg in rule.arguments:
@ -306,6 +264,7 @@ def img2img():
params, size = pipeline_from_request()
output = make_output_path(
output_path,
'img2img',
params,
size,
@ -328,6 +287,7 @@ def txt2img():
params, size = pipeline_from_request()
output = make_output_path(
output_path,
'txt2img',
params,
size)
@ -368,6 +328,7 @@ def inpaint():
request.args, 'noise', noise_sources, 'histogram')
output = make_output_path(
output_path,
'inpaint',
params,
size,

View File

@ -1,33 +1,19 @@
from os import path
from typing import Any, Dict, Tuple
import time
from struct import pack
from typing import Any, Dict, Tuple, Union
from hashlib import sha256
Param = Union[str, int, float]
Point = Tuple[int, int]
def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value)
def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1) -> int:
return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_map(args, key: str, values: Dict[str, Any], default: Any):
selected = args.get(key, default)
if selected in values:
return values[selected]
else:
return values[default]
def safer_join(base, tail):
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
return path.join(base, safer_path)
# TODO: .path is only used in one place, can probably just be a str
class OutputPath:
'''
TODO: .path is only used in one place, can probably just be a str
'''
def __init__(self, path, file):
self.path = path
self.file = file
@ -44,7 +30,7 @@ class BaseParams:
self.steps = steps
self.seed = seed
def tojson(self) -> Dict[str, Any]:
def tojson(self) -> Dict[str, Param]:
return {
'model': self.model,
'provider': self.provider,
@ -58,7 +44,7 @@ class BaseParams:
class Border:
def __init__(self, left, right, top, bottom):
def __init__(self, left: int, right: int, top: int, bottom: int):
self.left = left
self.right = right
self.top = top
@ -66,12 +52,81 @@ class Border:
class Size:
def __init__(self, width, height):
def __init__(self, width: int, height: int):
self.width = width
self.height = height
def tojson(self) -> Dict[str, Any]:
def tojson(self) -> Dict[str, int]:
return {
'height': self.height,
'width': self.width,
}
def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float:
return min(max(float(args.get(key, default_value)), min_value), max_value)
def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, min_value=1) -> int:
return min(max(int(args.get(key, default_value)), min_value), max_value)
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any):
selected = args.get(key, default)
if selected in values:
return values[selected]
else:
return values[default]
def safer_join(base: str, tail: str) -> str:
safer_path = path.relpath(path.normpath(path.join('/', tail)), '/')
return path.join(base, safer_path)
def hash_value(sha, param: Param):
'''
TODO: include functions by name
'''
if param is None:
return
elif isinstance(param, float):
sha.update(bytearray(pack('!f', param)))
elif isinstance(param, int):
sha.update(bytearray(pack('!I', param)))
elif isinstance(param, str):
sha.update(param.encode('utf-8'))
else:
print('cannot hash param: %s, %s' % (param, type(param)))
def make_output_path(
root: str,
mode: str,
params: BaseParams,
size: Size,
extras: Union[None, Tuple[Param]] = None
) -> OutputPath:
now = int(time.time())
sha = sha256()
hash_value(mode)
hash_value(params.model)
hash_value(params.provider)
hash_value(params.scheduler.__name__)
hash_value(params.prompt)
hash_value(params.negative_prompt)
hash_value(params.cfg)
hash_value(params.steps)
hash_value(params.seed)
hash_value(size.width)
hash_value(size.height)
if extras is not None:
for param in extras:
hash_value(sha, param)
output_file = '%s_%s_%s_%s.png' % (mode, params.seed, sha.hexdigest(), now)
output_full = safer_join(root, output_file)
return OutputPath(output_full, output_file)