1
0
Fork 0

fix(api): only hash networks once per worker lifetime

This commit is contained in:
Sean Sube 2024-01-05 21:11:44 -06:00
parent 6b0b2e41a6
commit 7d782842ed
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 39 additions and 18 deletions

View File

@ -56,10 +56,16 @@ class ImageMetadata:
self.loras = loras
self.models = models
def get_model_hash(self, model: Optional[str] = None) -> Tuple[str, str]:
def get_model_hash(
self, server: ServerContext, model: Optional[str] = None
) -> Tuple[str, str]:
model_name = path.basename(path.normpath(model or self.params.model))
logger.debug("getting model hash for %s", model_name)
if model_name in server.hash_cache:
logger.debug("using cached model hash for %s", model_name)
return (model_name, server.hash_cache[model_name])
model_hash = get_extra_hashes().get(model_name, None)
if model_hash is None:
model_hash_path = path.join(self.params.model, "hash.txt")
@ -67,10 +73,30 @@ class ImageMetadata:
with open(model_hash_path, "r") as f:
model_hash = f.readline().rstrip(",. \n\t\r")
return (model_name, model_hash or "unknown")
model_hash = model_hash or "unknown"
server.hash_cache[model_name] = model_hash
def to_exif(self, server, output: List[str]) -> str:
model_name, model_hash = self.get_model_hash()
return (model_name, model_hash)
def get_network_hash(
self, server: ServerContext, network_name: str, network_type: str
) -> Tuple[str, str]:
# run this again just in case the file path changes
network_path = resolve_tensor(
path.join(server.model_path, network_type, network_name)
)
if network_path in server.hash_cache:
logger.debug("using cached network hash for %s", network_path)
return (network_name, server.hash_cache[network_path])
network_hash = hash_file(network_path).upper()
server.hash_cache[network_path] = network_hash
return (network_name, network_hash)
def to_exif(self, server: ServerContext, output: List[str]) -> str:
model_name, model_hash = self.get_model_hash(server)
hash_map = {
model_name: model_hash,
}
@ -80,9 +106,7 @@ class ImageMetadata:
inversion_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name))
).upper(),
self.get_network_hash(server, name, "inversion")[1],
)
for name, _weight in self.inversions
]
@ -96,9 +120,7 @@ class ImageMetadata:
lora_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "lora", name))
).upper(),
self.get_network_hash(server, name, "lora")[1],
)
for name, _weight in self.loras
]
@ -127,7 +149,7 @@ class ImageMetadata:
}
# fix up some fields
model_name, model_hash = self.get_model_hash(self.params.model)
model_name, model_hash = self.get_model_hash(server, self.params.model)
json["params"]["model"] = model_name
json["models"].append(
{
@ -155,18 +177,14 @@ class ImageMetadata:
if self.inversions is not None:
for name, weight in self.inversions:
hash = hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name))
).upper()
hash = self.get_network_hash(server, name, "inversion")[1]
json["inversions"].append(
{"name": name, "weight": weight, "hash": hash}
)
if self.loras is not None:
for name, weight in self.loras:
hash = hash_file(
resolve_tensor(path.join(server.model_path, "lora", name))
).upper()
hash = self.get_network_hash(server, name, "lora")[1]
json["loras"].append({"name": name, "weight": weight, "hash": hash})
if self.models is not None:

View File

@ -1,7 +1,7 @@
from logging import getLogger
from os import environ, path
from secrets import token_urlsafe
from typing import List, Optional
from typing import Dict, List, Optional
import torch
@ -44,6 +44,7 @@ class ServerContext:
plugins: List[str]
debug: bool
thumbnail_size: int
hash_cache: Dict[str, str]
def __init__(
self,
@ -70,6 +71,7 @@ class ServerContext:
plugins: Optional[List[str]] = None,
debug: bool = False,
thumbnail_size: Optional[int] = DEFAULT_THUMBNAIL_SIZE,
hash_cache: Optional[Dict[str, str]] = None,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@ -94,6 +96,7 @@ class ServerContext:
self.plugins = plugins or []
self.debug = debug
self.thumbnail_size = thumbnail_size
self.hash_cache = hash_cache or {}
self.cache = ModelCache(self.cache_limit)