fix(api): only hash networks once per worker lifetime
This commit is contained in:
parent
6b0b2e41a6
commit
7d782842ed
|
@ -56,10 +56,16 @@ class ImageMetadata:
|
||||||
self.loras = loras
|
self.loras = loras
|
||||||
self.models = models
|
self.models = models
|
||||||
|
|
||||||
def get_model_hash(self, model: Optional[str] = None) -> Tuple[str, str]:
|
def get_model_hash(
|
||||||
|
self, server: ServerContext, model: Optional[str] = None
|
||||||
|
) -> Tuple[str, str]:
|
||||||
model_name = path.basename(path.normpath(model or self.params.model))
|
model_name = path.basename(path.normpath(model or self.params.model))
|
||||||
logger.debug("getting model hash for %s", model_name)
|
logger.debug("getting model hash for %s", model_name)
|
||||||
|
|
||||||
|
if model_name in server.hash_cache:
|
||||||
|
logger.debug("using cached model hash for %s", model_name)
|
||||||
|
return (model_name, server.hash_cache[model_name])
|
||||||
|
|
||||||
model_hash = get_extra_hashes().get(model_name, None)
|
model_hash = get_extra_hashes().get(model_name, None)
|
||||||
if model_hash is None:
|
if model_hash is None:
|
||||||
model_hash_path = path.join(self.params.model, "hash.txt")
|
model_hash_path = path.join(self.params.model, "hash.txt")
|
||||||
|
@ -67,10 +73,30 @@ class ImageMetadata:
|
||||||
with open(model_hash_path, "r") as f:
|
with open(model_hash_path, "r") as f:
|
||||||
model_hash = f.readline().rstrip(",. \n\t\r")
|
model_hash = f.readline().rstrip(",. \n\t\r")
|
||||||
|
|
||||||
return (model_name, model_hash or "unknown")
|
model_hash = model_hash or "unknown"
|
||||||
|
server.hash_cache[model_name] = model_hash
|
||||||
|
|
||||||
def to_exif(self, server, output: List[str]) -> str:
|
return (model_name, model_hash)
|
||||||
model_name, model_hash = self.get_model_hash()
|
|
||||||
|
def get_network_hash(
|
||||||
|
self, server: ServerContext, network_name: str, network_type: str
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
# run this again just in case the file path changes
|
||||||
|
network_path = resolve_tensor(
|
||||||
|
path.join(server.model_path, network_type, network_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
if network_path in server.hash_cache:
|
||||||
|
logger.debug("using cached network hash for %s", network_path)
|
||||||
|
return (network_name, server.hash_cache[network_path])
|
||||||
|
|
||||||
|
network_hash = hash_file(network_path).upper()
|
||||||
|
server.hash_cache[network_path] = network_hash
|
||||||
|
|
||||||
|
return (network_name, network_hash)
|
||||||
|
|
||||||
|
def to_exif(self, server: ServerContext, output: List[str]) -> str:
|
||||||
|
model_name, model_hash = self.get_model_hash(server)
|
||||||
hash_map = {
|
hash_map = {
|
||||||
model_name: model_hash,
|
model_name: model_hash,
|
||||||
}
|
}
|
||||||
|
@ -80,9 +106,7 @@ class ImageMetadata:
|
||||||
inversion_pairs = [
|
inversion_pairs = [
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
hash_file(
|
self.get_network_hash(server, name, "inversion")[1],
|
||||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
|
||||||
).upper(),
|
|
||||||
)
|
)
|
||||||
for name, _weight in self.inversions
|
for name, _weight in self.inversions
|
||||||
]
|
]
|
||||||
|
@ -96,9 +120,7 @@ class ImageMetadata:
|
||||||
lora_pairs = [
|
lora_pairs = [
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
hash_file(
|
self.get_network_hash(server, name, "lora")[1],
|
||||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
|
||||||
).upper(),
|
|
||||||
)
|
)
|
||||||
for name, _weight in self.loras
|
for name, _weight in self.loras
|
||||||
]
|
]
|
||||||
|
@ -127,7 +149,7 @@ class ImageMetadata:
|
||||||
}
|
}
|
||||||
|
|
||||||
# fix up some fields
|
# fix up some fields
|
||||||
model_name, model_hash = self.get_model_hash(self.params.model)
|
model_name, model_hash = self.get_model_hash(server, self.params.model)
|
||||||
json["params"]["model"] = model_name
|
json["params"]["model"] = model_name
|
||||||
json["models"].append(
|
json["models"].append(
|
||||||
{
|
{
|
||||||
|
@ -155,18 +177,14 @@ class ImageMetadata:
|
||||||
|
|
||||||
if self.inversions is not None:
|
if self.inversions is not None:
|
||||||
for name, weight in self.inversions:
|
for name, weight in self.inversions:
|
||||||
hash = hash_file(
|
hash = self.get_network_hash(server, name, "inversion")[1]
|
||||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
|
||||||
).upper()
|
|
||||||
json["inversions"].append(
|
json["inversions"].append(
|
||||||
{"name": name, "weight": weight, "hash": hash}
|
{"name": name, "weight": weight, "hash": hash}
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.loras is not None:
|
if self.loras is not None:
|
||||||
for name, weight in self.loras:
|
for name, weight in self.loras:
|
||||||
hash = hash_file(
|
hash = self.get_network_hash(server, name, "lora")[1]
|
||||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
|
||||||
).upper()
|
|
||||||
json["loras"].append({"name": name, "weight": weight, "hash": hash})
|
json["loras"].append({"name": name, "weight": weight, "hash": hash})
|
||||||
|
|
||||||
if self.models is not None:
|
if self.models is not None:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from secrets import token_urlsafe
|
from secrets import token_urlsafe
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ class ServerContext:
|
||||||
plugins: List[str]
|
plugins: List[str]
|
||||||
debug: bool
|
debug: bool
|
||||||
thumbnail_size: int
|
thumbnail_size: int
|
||||||
|
hash_cache: Dict[str, str]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -70,6 +71,7 @@ class ServerContext:
|
||||||
plugins: Optional[List[str]] = None,
|
plugins: Optional[List[str]] = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
thumbnail_size: Optional[int] = DEFAULT_THUMBNAIL_SIZE,
|
thumbnail_size: Optional[int] = DEFAULT_THUMBNAIL_SIZE,
|
||||||
|
hash_cache: Optional[Dict[str, str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -94,6 +96,7 @@ class ServerContext:
|
||||||
self.plugins = plugins or []
|
self.plugins = plugins or []
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.thumbnail_size = thumbnail_size
|
self.thumbnail_size = thumbnail_size
|
||||||
|
self.hash_cache = hash_cache or {}
|
||||||
|
|
||||||
self.cache = ModelCache(self.cache_limit)
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue