feat(api): write model hashes to image exif
This commit is contained in:
parent
003a350a6c
commit
062b1c47aa
|
@ -183,7 +183,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
model_formats = ["onnx", "pth", "ckpt", "safetensors"]
|
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
||||||
|
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
|
||||||
|
|
||||||
|
|
||||||
def source_format(model: Dict) -> Optional[str]:
|
def source_format(model: Dict) -> Optional[str]:
|
||||||
|
@ -192,7 +193,7 @@ def source_format(model: Dict) -> Optional[str]:
|
||||||
|
|
||||||
if "source" in model:
|
if "source" in model:
|
||||||
_name, ext = path.splitext(model["source"])
|
_name, ext = path.splitext(model["source"])
|
||||||
if ext in model_formats:
|
if ext in MODEL_FORMATS:
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
@ -231,7 +232,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
else:
|
else:
|
||||||
logger.debug("searching for tensors with known extensions")
|
logger.debug("searching for tensors with known extensions")
|
||||||
for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
|
for next_extension in RESOLVE_FORMATS:
|
||||||
next_name = f"{name}.{next_extension}"
|
next_name = f"{name}.{next_extension}"
|
||||||
if path.exists(next_name):
|
if path.exists(next_name):
|
||||||
checkpoint = load_tensor(next_name, map_location=map_location)
|
checkpoint = load_tensor(next_name, map_location=map_location)
|
||||||
|
@ -275,6 +276,16 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_tensor(name: str) -> Optional[str]:
|
||||||
|
logger.debug("searching for tensors with known extensions: %s", name)
|
||||||
|
for next_extension in RESOLVE_FORMATS:
|
||||||
|
next_name = f"{name}.{next_extension}"
|
||||||
|
if path.exists(next_name):
|
||||||
|
return next_name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def onnx_export(
|
def onnx_export(
|
||||||
model,
|
model,
|
||||||
model_args: tuple,
|
model_args: tuple,
|
||||||
|
|
|
@ -10,12 +10,30 @@ from piexif import ExifIFD, ImageIFD, dump
|
||||||
from piexif.helper import UserComment
|
from piexif.helper import UserComment
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
|
from onnx_web.convert.utils import resolve_tensor
|
||||||
|
from onnx_web.server.load import get_extra_hashes
|
||||||
|
|
||||||
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
|
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
|
||||||
from .server import ServerContext
|
from .server import ServerContext
|
||||||
from .utils import base_join
|
from .utils import base_join
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
HASH_BUFFER_SIZE = 2**22 # 4MB
|
||||||
|
|
||||||
|
|
||||||
|
def hash_file(name: str):
|
||||||
|
sha = sha256()
|
||||||
|
with open(name, "rb") as f:
|
||||||
|
while True:
|
||||||
|
data = f.read(HASH_BUFFER_SIZE)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
sha.update(data)
|
||||||
|
|
||||||
|
return sha.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def hash_value(sha, param: Optional[Param]):
|
def hash_value(sha, param: Optional[Param]):
|
||||||
if param is None:
|
if param is None:
|
||||||
|
@ -68,24 +86,57 @@ def json_params(
|
||||||
|
|
||||||
|
|
||||||
def str_params(
|
def str_params(
|
||||||
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
inversions: List[Tuple[str, float]] = None,
|
inversions: List[Tuple[str, float]] = None,
|
||||||
loras: List[Tuple[str, float]] = None,
|
loras: List[Tuple[str, float]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
lora_hashes = (
|
model_hash = get_extra_hashes().get(params.model, "unknown")
|
||||||
",".join([f"{name}: TODO" for name, weight in loras])
|
model_name = path.basename(path.normpath(params.model))
|
||||||
if loras is not None
|
hash_map = {
|
||||||
else ""
|
model_name: model_hash,
|
||||||
|
}
|
||||||
|
|
||||||
|
inversion_hashes = ""
|
||||||
|
if inversions is not None:
|
||||||
|
inversion_pairs = [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||||
|
).upper(),
|
||||||
)
|
)
|
||||||
|
for name, _weight in inversions
|
||||||
|
]
|
||||||
|
inversion_hashes = ",".join(
|
||||||
|
[f"{name}: {hash}" for name, hash in inversion_pairs]
|
||||||
|
)
|
||||||
|
hash_map.update(dict(inversion_pairs))
|
||||||
|
|
||||||
|
lora_hashes = ""
|
||||||
|
if loras is not None:
|
||||||
|
lora_pairs = [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||||
|
).upper(),
|
||||||
|
)
|
||||||
|
for name, _weight in loras
|
||||||
|
]
|
||||||
|
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
|
||||||
|
hash_map.update(dict(lora_pairs))
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"{params.input_prompt}.\nNegative prompt: {params.input_negative_prompt}.\n"
|
f"{params.input_prompt}\nNegative prompt: {params.input_negative_prompt}\n"
|
||||||
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
|
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
|
||||||
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
|
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
|
||||||
f"Model hash: TODO, Model: {params.model}, "
|
f"Model hash: {model_hash}, Model: {model_name}, "
|
||||||
|
f"Tool: onnx-web, Version: {server.server_version}, "
|
||||||
|
f'Inversion hashes: "{inversion_hashes}", '
|
||||||
f'Lora hashes: "{lora_hashes}", '
|
f'Lora hashes: "{lora_hashes}", '
|
||||||
f"Version: TODO, Tool: onnx-web"
|
f"Hashes: {dumps(hash_map)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,10 +208,10 @@ def save_image(
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
exif.add_text("model", "TODO: server.version")
|
exif.add_text("model", server.server_version)
|
||||||
exif.add_text(
|
exif.add_text(
|
||||||
"parameters",
|
"parameters",
|
||||||
str_params(params, size, inversions=inversions, loras=loras),
|
str_params(server, params, size, inversions=inversions, loras=loras),
|
||||||
)
|
)
|
||||||
|
|
||||||
image.save(path, format=server.image_format, pnginfo=exif)
|
image.save(path, format=server.image_format, pnginfo=exif)
|
||||||
|
@ -182,11 +233,13 @@ def save_image(
|
||||||
encoding="unicode",
|
encoding="unicode",
|
||||||
),
|
),
|
||||||
ExifIFD.UserComment: UserComment.dump(
|
ExifIFD.UserComment: UserComment.dump(
|
||||||
str_params(params, size, inversions=inversions, loras=loras),
|
str_params(
|
||||||
|
server, params, size, inversions=inversions, loras=loras
|
||||||
|
),
|
||||||
encoding="unicode",
|
encoding="unicode",
|
||||||
),
|
),
|
||||||
ImageIFD.Make: "onnx-web",
|
ImageIFD.Make: "onnx-web",
|
||||||
ImageIFD.Model: "TODO: server.version",
|
ImageIFD.Model: server.server_version,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -92,6 +92,7 @@ network_models: List[NetworkModel] = []
|
||||||
upscaling_models: List[str] = []
|
upscaling_models: List[str] = []
|
||||||
|
|
||||||
# Loaded from extra_models
|
# Loaded from extra_models
|
||||||
|
extra_hashes: Dict[str, str] = {}
|
||||||
extra_strings: Dict[str, Any] = {}
|
extra_strings: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
@ -123,6 +124,10 @@ def get_extra_strings():
|
||||||
return extra_strings
|
return extra_strings
|
||||||
|
|
||||||
|
|
||||||
|
def get_extra_hashes():
|
||||||
|
return extra_hashes
|
||||||
|
|
||||||
|
|
||||||
def get_highres_methods():
|
def get_highres_methods():
|
||||||
return highres_methods
|
return highres_methods
|
||||||
|
|
||||||
|
@ -147,6 +152,7 @@ def load_extras(server: ServerContext):
|
||||||
"""
|
"""
|
||||||
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
Load the extras file(s) and collect the relevant parts for the server: labels and strings
|
||||||
"""
|
"""
|
||||||
|
global extra_hashes
|
||||||
global extra_strings
|
global extra_strings
|
||||||
|
|
||||||
labels = {}
|
labels = {}
|
||||||
|
@ -173,8 +179,18 @@ def load_extras(server: ServerContext):
|
||||||
for model_type in ["diffusion", "correction", "upscaling", "networks"]:
|
for model_type in ["diffusion", "correction", "upscaling", "networks"]:
|
||||||
if model_type in data:
|
if model_type in data:
|
||||||
for model in data[model_type]:
|
for model in data[model_type]:
|
||||||
if "label" in model:
|
|
||||||
model_name = model["name"]
|
model_name = model["name"]
|
||||||
|
|
||||||
|
if "hash" in model:
|
||||||
|
logger.debug(
|
||||||
|
"collecting hash for model %s from %s",
|
||||||
|
model_name,
|
||||||
|
file,
|
||||||
|
)
|
||||||
|
|
||||||
|
extra_hashes[model_name] = model["hash"]
|
||||||
|
|
||||||
|
if "label" in model:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"collecting label for model %s from %s",
|
"collecting label for model %s from %s",
|
||||||
model_name,
|
model_name,
|
||||||
|
|
Loading…
Reference in New Issue