1
0
Fork 0

feat(api): write model hashes to image exif

This commit is contained in:
Sean Sube 2023-06-26 17:24:34 -05:00
parent 003a350a6c
commit 062b1c47aa
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 96 additions and 16 deletions

View File

@ -183,7 +183,8 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
return model return model
model_formats = ["onnx", "pth", "ckpt", "safetensors"] MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
def source_format(model: Dict) -> Optional[str]: def source_format(model: Dict) -> Optional[str]:
@ -192,7 +193,7 @@ def source_format(model: Dict) -> Optional[str]:
if "source" in model: if "source" in model:
_name, ext = path.splitext(model["source"]) _name, ext = path.splitext(model["source"])
if ext in model_formats: if ext in MODEL_FORMATS:
return ext return ext
return None return None
@ -231,7 +232,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
checkpoint = torch.load(name, map_location=map_location) checkpoint = torch.load(name, map_location=map_location)
else: else:
logger.debug("searching for tensors with known extensions") logger.debug("searching for tensors with known extensions")
for next_extension in ["safetensors", "ckpt", "pt", "bin"]: for next_extension in RESOLVE_FORMATS:
next_name = f"{name}.{next_extension}" next_name = f"{name}.{next_extension}"
if path.exists(next_name): if path.exists(next_name):
checkpoint = load_tensor(next_name, map_location=map_location) checkpoint = load_tensor(next_name, map_location=map_location)
@ -275,6 +276,16 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
return checkpoint return checkpoint
def resolve_tensor(name: str) -> Optional[str]:
logger.debug("searching for tensors with known extensions: %s", name)
for next_extension in RESOLVE_FORMATS:
next_name = f"{name}.{next_extension}"
if path.exists(next_name):
return next_name
return None
def onnx_export( def onnx_export(
model, model,
model_args: tuple, model_args: tuple,

View File

@ -10,12 +10,30 @@ from piexif import ExifIFD, ImageIFD, dump
from piexif.helper import UserComment from piexif.helper import UserComment
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from onnx_web.convert.utils import resolve_tensor
from onnx_web.server.load import get_extra_hashes
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
from .server import ServerContext from .server import ServerContext
from .utils import base_join from .utils import base_join
logger = getLogger(__name__) logger = getLogger(__name__)
HASH_BUFFER_SIZE = 2**22 # 4MB
def hash_file(name: str):
sha = sha256()
with open(name, "rb") as f:
while True:
data = f.read(HASH_BUFFER_SIZE)
if not data:
break
sha.update(data)
return sha.hexdigest()
def hash_value(sha, param: Optional[Param]): def hash_value(sha, param: Optional[Param]):
if param is None: if param is None:
@ -68,24 +86,57 @@ def json_params(
def str_params( def str_params(
server: ServerContext,
params: ImageParams, params: ImageParams,
size: Size, size: Size,
inversions: List[Tuple[str, float]] = None, inversions: List[Tuple[str, float]] = None,
loras: List[Tuple[str, float]] = None, loras: List[Tuple[str, float]] = None,
) -> str: ) -> str:
lora_hashes = ( model_hash = get_extra_hashes().get(params.model, "unknown")
",".join([f"{name}: TODO" for name, weight in loras]) model_name = path.basename(path.normpath(params.model))
if loras is not None hash_map = {
else "" model_name: model_hash,
}
inversion_hashes = ""
if inversions is not None:
inversion_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name))
).upper(),
) )
for name, _weight in inversions
]
inversion_hashes = ",".join(
[f"{name}: {hash}" for name, hash in inversion_pairs]
)
hash_map.update(dict(inversion_pairs))
lora_hashes = ""
if loras is not None:
lora_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "lora", name))
).upper(),
)
for name, _weight in loras
]
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
hash_map.update(dict(lora_pairs))
return ( return (
f"{params.input_prompt}.\nNegative prompt: {params.input_negative_prompt}.\n" f"{params.input_prompt}\nNegative prompt: {params.input_negative_prompt}\n"
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, " f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
f"Seed: {params.seed}, Size: {size.width}x{size.height}, " f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
f"Model hash: TODO, Model: {params.model}, " f"Model hash: {model_hash}, Model: {model_name}, "
f"Tool: onnx-web, Version: {server.server_version}, "
f'Inversion hashes: "{inversion_hashes}", '
f'Lora hashes: "{lora_hashes}", ' f'Lora hashes: "{lora_hashes}", '
f"Version: TODO, Tool: onnx-web" f"Hashes: {dumps(hash_map)}"
) )
@ -157,10 +208,10 @@ def save_image(
) )
), ),
) )
exif.add_text("model", "TODO: server.version") exif.add_text("model", server.server_version)
exif.add_text( exif.add_text(
"parameters", "parameters",
str_params(params, size, inversions=inversions, loras=loras), str_params(server, params, size, inversions=inversions, loras=loras),
) )
image.save(path, format=server.image_format, pnginfo=exif) image.save(path, format=server.image_format, pnginfo=exif)
@ -182,11 +233,13 @@ def save_image(
encoding="unicode", encoding="unicode",
), ),
ExifIFD.UserComment: UserComment.dump( ExifIFD.UserComment: UserComment.dump(
str_params(params, size, inversions=inversions, loras=loras), str_params(
server, params, size, inversions=inversions, loras=loras
),
encoding="unicode", encoding="unicode",
), ),
ImageIFD.Make: "onnx-web", ImageIFD.Make: "onnx-web",
ImageIFD.Model: "TODO: server.version", ImageIFD.Model: server.server_version,
} }
} }
) )

View File

@ -92,6 +92,7 @@ network_models: List[NetworkModel] = []
upscaling_models: List[str] = [] upscaling_models: List[str] = []
# Loaded from extra_models # Loaded from extra_models
extra_hashes: Dict[str, str] = {}
extra_strings: Dict[str, Any] = {} extra_strings: Dict[str, Any] = {}
@ -123,6 +124,10 @@ def get_extra_strings():
return extra_strings return extra_strings
def get_extra_hashes():
return extra_hashes
def get_highres_methods(): def get_highres_methods():
return highres_methods return highres_methods
@ -147,6 +152,7 @@ def load_extras(server: ServerContext):
""" """
Load the extras file(s) and collect the relevant parts for the server: labels and strings Load the extras file(s) and collect the relevant parts for the server: labels and strings
""" """
global extra_hashes
global extra_strings global extra_strings
labels = {} labels = {}
@ -173,8 +179,18 @@ def load_extras(server: ServerContext):
for model_type in ["diffusion", "correction", "upscaling", "networks"]: for model_type in ["diffusion", "correction", "upscaling", "networks"]:
if model_type in data: if model_type in data:
for model in data[model_type]: for model in data[model_type]:
if "label" in model:
model_name = model["name"] model_name = model["name"]
if "hash" in model:
logger.debug(
"collecting hash for model %s from %s",
model_name,
file,
)
extra_hashes[model_name] = model["hash"]
if "label" in model:
logger.debug( logger.debug(
"collecting label for model %s from %s", "collecting label for model %s from %s",
model_name, model_name,