feat(api): write model hashes to image exif
This commit is contained in:
parent
003a350a6c
commit
062b1c47aa
|
@ -183,7 +183,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
|||
return model
|
||||
|
||||
|
||||
model_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
|
||||
|
||||
|
||||
def source_format(model: Dict) -> Optional[str]:
|
||||
|
@ -192,7 +193,7 @@ def source_format(model: Dict) -> Optional[str]:
|
|||
|
||||
if "source" in model:
|
||||
_name, ext = path.splitext(model["source"])
|
||||
if ext in model_formats:
|
||||
if ext in MODEL_FORMATS:
|
||||
return ext
|
||||
|
||||
return None
|
||||
|
@ -231,7 +232,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
|||
checkpoint = torch.load(name, map_location=map_location)
|
||||
else:
|
||||
logger.debug("searching for tensors with known extensions")
|
||||
for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
|
||||
for next_extension in RESOLVE_FORMATS:
|
||||
next_name = f"{name}.{next_extension}"
|
||||
if path.exists(next_name):
|
||||
checkpoint = load_tensor(next_name, map_location=map_location)
|
||||
|
@ -275,6 +276,16 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
|||
return checkpoint
|
||||
|
||||
|
||||
def resolve_tensor(name: str) -> Optional[str]:
|
||||
logger.debug("searching for tensors with known extensions: %s", name)
|
||||
for next_extension in RESOLVE_FORMATS:
|
||||
next_name = f"{name}.{next_extension}"
|
||||
if path.exists(next_name):
|
||||
return next_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def onnx_export(
|
||||
model,
|
||||
model_args: tuple,
|
||||
|
|
|
@ -10,12 +10,30 @@ from piexif import ExifIFD, ImageIFD, dump
|
|||
from piexif.helper import UserComment
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from onnx_web.convert.utils import resolve_tensor
|
||||
from onnx_web.server.load import get_extra_hashes
|
||||
|
||||
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
|
||||
from .server import ServerContext
|
||||
from .utils import base_join
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
HASH_BUFFER_SIZE = 2**22 # 4MB
|
||||
|
||||
|
||||
def hash_file(name: str):
|
||||
sha = sha256()
|
||||
with open(name, "rb") as f:
|
||||
while True:
|
||||
data = f.read(HASH_BUFFER_SIZE)
|
||||
if not data:
|
||||
break
|
||||
|
||||
sha.update(data)
|
||||
|
||||
return sha.hexdigest()
|
||||
|
||||
|
||||
def hash_value(sha, param: Optional[Param]):
|
||||
if param is None:
|
||||
|
@ -68,24 +86,57 @@ def json_params(
|
|||
|
||||
|
||||
def str_params(
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
inversions: List[Tuple[str, float]] = None,
|
||||
loras: List[Tuple[str, float]] = None,
|
||||
) -> str:
|
||||
lora_hashes = (
|
||||
",".join([f"{name}: TODO" for name, weight in loras])
|
||||
if loras is not None
|
||||
else ""
|
||||
model_hash = get_extra_hashes().get(params.model, "unknown")
|
||||
model_name = path.basename(path.normpath(params.model))
|
||||
hash_map = {
|
||||
model_name: model_hash,
|
||||
}
|
||||
|
||||
inversion_hashes = ""
|
||||
if inversions is not None:
|
||||
inversion_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper(),
|
||||
)
|
||||
for name, _weight in inversions
|
||||
]
|
||||
inversion_hashes = ",".join(
|
||||
[f"{name}: {hash}" for name, hash in inversion_pairs]
|
||||
)
|
||||
hash_map.update(dict(inversion_pairs))
|
||||
|
||||
lora_hashes = ""
|
||||
if loras is not None:
|
||||
lora_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper(),
|
||||
)
|
||||
for name, _weight in loras
|
||||
]
|
||||
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
|
||||
hash_map.update(dict(lora_pairs))
|
||||
|
||||
return (
|
||||
f"{params.input_prompt}.\nNegative prompt: {params.input_negative_prompt}.\n"
|
||||
f"{params.input_prompt}\nNegative prompt: {params.input_negative_prompt}\n"
|
||||
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
|
||||
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
|
||||
f"Model hash: TODO, Model: {params.model}, "
|
||||
f"Model hash: {model_hash}, Model: {model_name}, "
|
||||
f"Tool: onnx-web, Version: {server.server_version}, "
|
||||
f'Inversion hashes: "{inversion_hashes}", '
|
||||
f'Lora hashes: "{lora_hashes}", '
|
||||
f"Version: TODO, Tool: onnx-web"
|
||||
f"Hashes: {dumps(hash_map)}"
|
||||
)
|
||||
|
||||
|
||||
|
@ -157,10 +208,10 @@ def save_image(
|
|||
)
|
||||
),
|
||||
)
|
||||
exif.add_text("model", "TODO: server.version")
|
||||
exif.add_text("model", server.server_version)
|
||||
exif.add_text(
|
||||
"parameters",
|
||||
str_params(params, size, inversions=inversions, loras=loras),
|
||||
str_params(server, params, size, inversions=inversions, loras=loras),
|
||||
)
|
||||
|
||||
image.save(path, format=server.image_format, pnginfo=exif)
|
||||
|
@ -182,11 +233,13 @@ def save_image(
|
|||
encoding="unicode",
|
||||
),
|
||||
ExifIFD.UserComment: UserComment.dump(
|
||||
str_params(params, size, inversions=inversions, loras=loras),
|
||||
str_params(
|
||||
server, params, size, inversions=inversions, loras=loras
|
||||
),
|
||||
encoding="unicode",
|
||||
),
|
||||
ImageIFD.Make: "onnx-web",
|
||||
ImageIFD.Model: "TODO: server.version",
|
||||
ImageIFD.Model: server.server_version,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
|
@ -92,6 +92,7 @@ network_models: List[NetworkModel] = []
|
|||
upscaling_models: List[str] = []
|
||||
|
||||
# Loaded from extra_models
|
||||
extra_hashes: Dict[str, str] = {}
|
||||
extra_strings: Dict[str, Any] = {}
|
||||
|
||||
|
||||
|
@ -123,6 +124,10 @@ def get_extra_strings():
|
|||
return extra_strings
|
||||
|
||||
|
||||
def get_extra_hashes():
|
||||
return extra_hashes
|
||||
|
||||
|
||||
def get_highres_methods():
|
||||
return highres_methods
|
||||
|
||||
|
@ -147,6 +152,7 @@ def load_extras(server: ServerContext):
|
|||
"""
|
||||
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
||||
"""
|
||||
global extra_hashes
|
||||
global extra_strings
|
||||
|
||||
labels = {}
|
||||
|
@ -173,8 +179,18 @@ def load_extras(server: ServerContext):
|
|||
for model_type in ["diffusion", "correction", "upscaling", "networks"]:
|
||||
if model_type in data:
|
||||
for model in data[model_type]:
|
||||
if "label" in model:
|
||||
model_name = model["name"]
|
||||
|
||||
if "hash" in model:
|
||||
logger.debug(
|
||||
"collecting hash for model %s from %s",
|
||||
model_name,
|
||||
file,
|
||||
)
|
||||
|
||||
extra_hashes[model_name] = model["hash"]
|
||||
|
||||
if "label" in model:
|
||||
logger.debug(
|
||||
"collecting label for model %s from %s",
|
||||
model_name,
|
||||
|
|
Loading…
Reference in New Issue