diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 2f9557b6..b145d91c 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -56,10 +56,16 @@ class ImageMetadata: self.loras = loras self.models = models - def get_model_hash(self, model: Optional[str] = None) -> Tuple[str, str]: + def get_model_hash( + self, server: ServerContext, model: Optional[str] = None + ) -> Tuple[str, str]: model_name = path.basename(path.normpath(model or self.params.model)) logger.debug("getting model hash for %s", model_name) + if model_name in server.hash_cache: + logger.debug("using cached model hash for %s", model_name) + return (model_name, server.hash_cache[model_name]) + model_hash = get_extra_hashes().get(model_name, None) if model_hash is None: model_hash_path = path.join(self.params.model, "hash.txt") @@ -67,10 +73,30 @@ class ImageMetadata: with open(model_hash_path, "r") as f: model_hash = f.readline().rstrip(",. \n\t\r") - return (model_name, model_hash or "unknown") + model_hash = model_hash or "unknown" + server.hash_cache[model_name] = model_hash - def to_exif(self, server, output: List[str]) -> str: - model_name, model_hash = self.get_model_hash() + return (model_name, model_hash) + + def get_network_hash( + self, server: ServerContext, network_name: str, network_type: str + ) -> Tuple[str, str]: + # run this again just in case the file path changes + network_path = resolve_tensor( + path.join(server.model_path, network_type, network_name) + ) + + if network_path in server.hash_cache: + logger.debug("using cached network hash for %s", network_path) + return (network_name, server.hash_cache[network_path]) + + network_hash = hash_file(network_path).upper() + server.hash_cache[network_path] = network_hash + + return (network_name, network_hash) + + def to_exif(self, server: ServerContext, output: List[str]) -> str: + model_name, model_hash = self.get_model_hash(server) hash_map = { model_name: model_hash, } @@ -80,9 +106,7 @@ class ImageMetadata: inversion_pairs = [ ( name, - hash_file( - resolve_tensor(path.join(server.model_path, "inversion", name)) - ).upper(), + self.get_network_hash(server, name, "inversion")[1], ) for name, _weight in self.inversions ] @@ -96,9 +120,7 @@ class ImageMetadata: lora_pairs = [ ( name, - hash_file( - resolve_tensor(path.join(server.model_path, "lora", name)) - ).upper(), + self.get_network_hash(server, name, "lora")[1], ) for name, _weight in self.loras ] @@ -127,7 +149,7 @@ class ImageMetadata: } # fix up some fields - model_name, model_hash = self.get_model_hash(self.params.model) + model_name, model_hash = self.get_model_hash(server, self.params.model) json["params"]["model"] = model_name json["models"].append( { @@ -155,18 +177,14 @@ class ImageMetadata: if self.inversions is not None: for name, weight in self.inversions: - hash = hash_file( - resolve_tensor(path.join(server.model_path, "inversion", name)) - ).upper() + hash = self.get_network_hash(server, name, "inversion")[1] json["inversions"].append( {"name": name, "weight": weight, "hash": hash} ) if self.loras is not None: for name, weight in self.loras: - hash = hash_file( - resolve_tensor(path.join(server.model_path, "lora", name)) - ).upper() + hash = self.get_network_hash(server, name, "lora")[1] json["loras"].append({"name": name, "weight": weight, "hash": hash}) if self.models is not None: diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index dfdc2a07..bbad066a 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -1,7 +1,7 @@ from logging import getLogger from os import environ, path from secrets import token_urlsafe -from typing import List, Optional +from typing import Dict, List, Optional import torch @@ -44,6 +44,7 @@ class ServerContext: plugins: List[str] debug: bool thumbnail_size: int + hash_cache: Dict[str, str] def __init__( self, @@ -70,6 +71,7 @@ class ServerContext: plugins: Optional[List[str]] = None, debug: bool = False, thumbnail_size: Optional[int] = DEFAULT_THUMBNAIL_SIZE, + hash_cache: Optional[Dict[str, str]] = None, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -94,6 +96,7 @@ class ServerContext: self.plugins = plugins or [] self.debug = debug self.thumbnail_size = thumbnail_size + self.hash_cache = hash_cache or {} self.cache = ModelCache(self.cache_limit)