
feat(api): write model hashes to image exif

Sean Sube 2023-06-26 17:24:34 -05:00
parent 003a350a6c
commit 062b1c47aa
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 96 additions and 16 deletions

View File

@@ -183,7 +183,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
return model
model_formats = ["onnx", "pth", "ckpt", "safetensors"]
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
def source_format(model: Dict) -> Optional[str]:
@@ -192,7 +193,7 @@ def source_format(model: Dict) -> Optional[str]:
if "source" in model:
_name, ext = path.splitext(model["source"])
if ext in model_formats:
if ext in MODEL_FORMATS:
return ext
return None
@@ -231,7 +232,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
checkpoint = torch.load(name, map_location=map_location)
else:
logger.debug("searching for tensors with known extensions")
for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
for next_extension in RESOLVE_FORMATS:
next_name = f"{name}.{next_extension}"
if path.exists(next_name):
checkpoint = load_tensor(next_name, map_location=map_location)
@@ -275,6 +276,16 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
return checkpoint
def resolve_tensor(name: str) -> Optional[str]:
logger.debug("searching for tensors with known extensions: %s", name)
for next_extension in RESOLVE_FORMATS:
next_name = f"{name}.{next_extension}"
if path.exists(next_name):
return next_name
return None
def onnx_export(
model,
model_args: tuple,
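Note: the new resolve_tensor helper factors out the extension search that load_tensor already performed, trying each entry in RESOLVE_FORMATS in order and returning the first path that exists on disk, or None. A usage sketch follows; the model directory and LoRA name are hypothetical, but the call pattern matches how the EXIF code in the next file uses it.

# usage sketch only: the directory layout and file name here are hypothetical
from os import path

from onnx_web.convert.utils import resolve_tensor

model_path = "/models"  # stands in for server.model_path
lora_file = resolve_tensor(path.join(model_path, "lora", "arcane-style"))
# returns "/models/lora/arcane-style.safetensors" if that file exists,
# otherwise tries .ckpt, .pt, and .bin before falling back to None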

View File

@@ -10,12 +10,30 @@ from piexif import ExifIFD, ImageIFD, dump
from piexif.helper import UserComment
from PIL import Image, PngImagePlugin
from onnx_web.convert.utils import resolve_tensor
from onnx_web.server.load import get_extra_hashes
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
from .server import ServerContext
from .utils import base_join
logger = getLogger(__name__)
HASH_BUFFER_SIZE = 2**22 # 4MB
def hash_file(name: str):
sha = sha256()
with open(name, "rb") as f:
while True:
data = f.read(HASH_BUFFER_SIZE)
if not data:
break
sha.update(data)
return sha.hexdigest()
def hash_value(sha, param: Optional[Param]):
if param is None:
@@ -68,24 +86,57 @@ def json_params(
def str_params(
server: ServerContext,
params: ImageParams,
size: Size,
inversions: List[Tuple[str, float]] = None,
loras: List[Tuple[str, float]] = None,
) -> str:
lora_hashes = (
",".join([f"{name}: TODO" for name, weight in loras])
if loras is not None
else ""
model_hash = get_extra_hashes().get(params.model, "unknown")
model_name = path.basename(path.normpath(params.model))
hash_map = {
model_name: model_hash,
}
inversion_hashes = ""
if inversions is not None:
inversion_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name))
).upper(),
)
for name, _weight in inversions
]
inversion_hashes = ",".join(
[f"{name}: {hash}" for name, hash in inversion_pairs]
)
hash_map.update(dict(inversion_pairs))
lora_hashes = ""
if loras is not None:
lora_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "lora", name))
).upper(),
)
for name, _weight in loras
]
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
hash_map.update(dict(lora_pairs))
return (
f"{params.input_prompt}.\nNegative prompt: {params.input_negative_prompt}.\n"
f"{params.input_prompt}\nNegative prompt: {params.input_negative_prompt}\n"
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
f"Model hash: TODO, Model: {params.model}, "
f"Model hash: {model_hash}, Model: {model_name}, "
f"Tool: onnx-web, Version: {server.server_version}, "
f'Inversion hashes: "{inversion_hashes}", '
f'Lora hashes: "{lora_hashes}", '
f"Version: TODO, Tool: onnx-web"
f"Hashes: {dumps(hash_map)}"
)
@@ -157,10 +208,10 @@ def save_image(
)
),
)
exif.add_text("model", "TODO: server.version")
exif.add_text("model", server.server_version)
exif.add_text(
"parameters",
str_params(params, size, inversions=inversions, loras=loras),
str_params(server, params, size, inversions=inversions, loras=loras),
)
image.save(path, format=server.image_format, pnginfo=exif)
@@ -182,11 +233,13 @@ def save_image(
encoding="unicode",
),
ExifIFD.UserComment: UserComment.dump(
str_params(params, size, inversions=inversions, loras=loras),
str_params(
server, params, size, inversions=inversions, loras=loras
),
encoding="unicode",
),
ImageIFD.Make: "onnx-web",
ImageIFD.Model: "TODO: server.version",
ImageIFD.Model: server.server_version,
}
}
)
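Note: with these changes str_params fills in the hashes that were previously left as TODO. The main model hash is looked up from the extras file via get_extra_hashes (falling back to "unknown"), inversion and LoRA hashes are computed by reading the resolved tensor files in 4 MB chunks through hash_file, and the combined name-to-hash map is appended as JSON under a Hashes key. A rough illustration of the resulting parameters string for a single LoRA; every value below (prompt, scheduler, seed, version, and the hashes themselves) is invented, not output from a real run.

# illustration only: all values are placeholders
example_params = (
    "a watercolor lighthouse at dusk\n"
    "Negative prompt: blurry, low quality\n"
    "Steps: 25, Sampler: deis, CFG scale: 7.0, "
    "Seed: 1234567890, Size: 512x512, "
    "Model hash: unknown, Model: stable-diffusion-onnx-v1-5, "
    "Tool: onnx-web, Version: v0.11.0, "
    'Inversion hashes: "", '
    'Lora hashes: "arcane: 0123ABCD...", '
    'Hashes: {"stable-diffusion-onnx-v1-5": "unknown", "arcane": "0123ABCD..."}'
)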

View File

@@ -92,6 +92,7 @@ network_models: List[NetworkModel] = []
upscaling_models: List[str] = []
# Loaded from extra_models
extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {}
@@ -123,6 +124,10 @@ def get_extra_strings():
return extra_strings
def get_extra_hashes():
return extra_hashes
def get_highres_methods():
return highres_methods
@@ -147,6 +152,7 @@ def load_extras(server: ServerContext):
"""
Load the extras file(s) and collect the relevant parts for the server: labels and strings
"""
global extra_hashes
global extra_strings
labels = {}
@@ -173,8 +179,18 @@
for model_type in ["diffusion", "correction", "upscaling", "networks"]:
if model_type in data:
for model in data[model_type]:
if "label" in model:
model_name = model["name"]
if "hash" in model:
logger.debug(
"collecting hash for model %s from %s",
model_name,
file,
)
extra_hashes[model_name] = model["hash"]
if "label" in model:
logger.debug(
"collecting label for model %s from %s",
model_name,
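Note: on the server side, load_extras now records an optional hash field for each model entry alongside the existing label handling, storing it in extra_hashes keyed by model name and exposing it through get_extra_hashes for the EXIF writer above. A sketch of the data shape this loop expects; the entry below is invented, and only the name, label, and hash keys are taken from the code in this commit.

# hypothetical extras entry: the model name, source URL, label, and hash value
# are examples, not part of the commit
data = {
    "diffusion": [
        {
            "name": "diffusion-example-v15",
            "source": "https://example.com/diffusion-example-v15.safetensors",
            "label": "Example Diffusion v1.5",
            "hash": "0123456789ABCDEF0123456789ABCDEF",
        }
    ]
}

extra_hashes = {}
for model_type in ["diffusion", "correction", "upscaling", "networks"]:
    for model in data.get(model_type, []):
        if "hash" in model:
            extra_hashes[model["name"]] = model["hash"]
# extra_hashes == {"diffusion-example-v15": "0123456789ABCDEF0123456789ABCDEF"}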