fix(api): only hash networks once per worker lifetime
This commit is contained in:
parent
6b0b2e41a6
commit
7d782842ed
|
@ -56,10 +56,16 @@ class ImageMetadata:
|
|||
self.loras = loras
|
||||
self.models = models
|
||||
|
||||
def get_model_hash(self, model: Optional[str] = None) -> Tuple[str, str]:
|
||||
def get_model_hash(
|
||||
self, server: ServerContext, model: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
model_name = path.basename(path.normpath(model or self.params.model))
|
||||
logger.debug("getting model hash for %s", model_name)
|
||||
|
||||
if model_name in server.hash_cache:
|
||||
logger.debug("using cached model hash for %s", model_name)
|
||||
return (model_name, server.hash_cache[model_name])
|
||||
|
||||
model_hash = get_extra_hashes().get(model_name, None)
|
||||
if model_hash is None:
|
||||
model_hash_path = path.join(self.params.model, "hash.txt")
|
||||
|
@ -67,10 +73,30 @@ class ImageMetadata:
|
|||
with open(model_hash_path, "r") as f:
|
||||
model_hash = f.readline().rstrip(",. \n\t\r")
|
||||
|
||||
return (model_name, model_hash or "unknown")
|
||||
model_hash = model_hash or "unknown"
|
||||
server.hash_cache[model_name] = model_hash
|
||||
|
||||
def to_exif(self, server, output: List[str]) -> str:
|
||||
model_name, model_hash = self.get_model_hash()
|
||||
return (model_name, model_hash)
|
||||
|
||||
def get_network_hash(
|
||||
self, server: ServerContext, network_name: str, network_type: str
|
||||
) -> Tuple[str, str]:
|
||||
# run this again just in case the file path changes
|
||||
network_path = resolve_tensor(
|
||||
path.join(server.model_path, network_type, network_name)
|
||||
)
|
||||
|
||||
if network_path in server.hash_cache:
|
||||
logger.debug("using cached network hash for %s", network_path)
|
||||
return (network_name, server.hash_cache[network_path])
|
||||
|
||||
network_hash = hash_file(network_path).upper()
|
||||
server.hash_cache[network_path] = network_hash
|
||||
|
||||
return (network_name, network_hash)
|
||||
|
||||
def to_exif(self, server: ServerContext, output: List[str]) -> str:
|
||||
model_name, model_hash = self.get_model_hash(server)
|
||||
hash_map = {
|
||||
model_name: model_hash,
|
||||
}
|
||||
|
@ -80,9 +106,7 @@ class ImageMetadata:
|
|||
inversion_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper(),
|
||||
self.get_network_hash(server, name, "inversion")[1],
|
||||
)
|
||||
for name, _weight in self.inversions
|
||||
]
|
||||
|
@ -96,9 +120,7 @@ class ImageMetadata:
|
|||
lora_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper(),
|
||||
self.get_network_hash(server, name, "lora")[1],
|
||||
)
|
||||
for name, _weight in self.loras
|
||||
]
|
||||
|
@ -127,7 +149,7 @@ class ImageMetadata:
|
|||
}
|
||||
|
||||
# fix up some fields
|
||||
model_name, model_hash = self.get_model_hash(self.params.model)
|
||||
model_name, model_hash = self.get_model_hash(server, self.params.model)
|
||||
json["params"]["model"] = model_name
|
||||
json["models"].append(
|
||||
{
|
||||
|
@ -155,18 +177,14 @@ class ImageMetadata:
|
|||
|
||||
if self.inversions is not None:
|
||||
for name, weight in self.inversions:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper()
|
||||
hash = self.get_network_hash(server, name, "inversion")[1]
|
||||
json["inversions"].append(
|
||||
{"name": name, "weight": weight, "hash": hash}
|
||||
)
|
||||
|
||||
if self.loras is not None:
|
||||
for name, weight in self.loras:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper()
|
||||
hash = self.get_network_hash(server, name, "lora")[1]
|
||||
json["loras"].append({"name": name, "weight": weight, "hash": hash})
|
||||
|
||||
if self.models is not None:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from logging import getLogger
|
||||
from os import environ, path
|
||||
from secrets import token_urlsafe
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -44,6 +44,7 @@ class ServerContext:
|
|||
plugins: List[str]
|
||||
debug: bool
|
||||
thumbnail_size: int
|
||||
hash_cache: Dict[str, str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -70,6 +71,7 @@ class ServerContext:
|
|||
plugins: Optional[List[str]] = None,
|
||||
debug: bool = False,
|
||||
thumbnail_size: Optional[int] = DEFAULT_THUMBNAIL_SIZE,
|
||||
hash_cache: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -94,6 +96,7 @@ class ServerContext:
|
|||
self.plugins = plugins or []
|
||||
self.debug = debug
|
||||
self.thumbnail_size = thumbnail_size
|
||||
self.hash_cache = hash_cache or {}
|
||||
|
||||
self.cache = ModelCache(self.cache_limit)
|
||||
|
||||
|
|
Loading…
Reference in New Issue