
special downloading for embeds

Sean Sube 2023-12-09 19:15:28 -06:00
parent 7496613f4e
commit 20b719fdff
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 37 additions and 53 deletions


@@ -202,7 +202,7 @@ base_models: Models = {
 }
 
 
-def convert_source_model(conversion: ConversionContext, model):
+def convert_model_source(conversion: ConversionContext, model):
     model_format = source_format(model)
     name = model["name"]
     source = model["source"]
@@ -215,9 +215,10 @@ def convert_source_model(conversion: ConversionContext, model):
     logger.info("finished downloading source: %s -> %s", source, dest)
 
 
-def convert_network_model(conversion: ConversionContext, network):
-    network_format = source_format(network)
+def convert_model_network(conversion: ConversionContext, network):
+    format = source_format(network)
     name = network["name"]
+    model = network["model"]
     network_type = network["type"]
     source = network["source"]
 
@@ -226,7 +227,7 @@ def convert_network_model(conversion: ConversionContext, network):
         conversion,
         name,
         source,
-        format=network_format,
+        format=format,
     )
 
     convert_diffusion_control(
@@ -241,13 +242,14 @@ def convert_network_model(conversion: ConversionContext, network):
         name,
         source,
         dest=path.join(conversion.model_path, network_type),
-        format=network_format,
+        format=format,
+        embeds=(network_type == "inversion" and model == "concept"),
     )
 
     logger.info("finished downloading network: %s -> %s", source, dest)
 
 
-def convert_diffusion_model(conversion: ConversionContext, model):
+def convert_model_diffusion(conversion: ConversionContext, model):
     # fix up entries with missing prefixes
     name = fix_diffusion_name(model["name"])
     if name != model["name"]:
@@ -372,7 +374,7 @@ def convert_diffusion_model(conversion: ConversionContext, model):
     )
 
 
-def convert_upscaling_model(conversion: ConversionContext, model):
+def convert_model_upscaling(conversion: ConversionContext, model):
     model_format = source_format(model)
     name = model["name"]
@@ -389,7 +391,7 @@ def convert_upscaling_model(conversion: ConversionContext, model):
         raise ValueError(name)
 
 
-def convert_correction_model(conversion: ConversionContext, model):
+def convert_model_correction(conversion: ConversionContext, model):
     model_format = source_format(model)
     name = model["name"]
     source = fetch_model(conversion, name, model["source"], format=model_format)
@@ -413,7 +415,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
            logger.info("skipping source: %s", name)
        else:
            try:
-                convert_source_model(model)
+                convert_model_source(model)
            except Exception:
                logger.exception("error fetching source %s", name)
                model_errors.append(name)
@@ -426,7 +428,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
            logger.info("skipping network: %s", name)
        else:
            try:
-                convert_network_model(conversion, model)
+                convert_model_network(conversion, model)
            except Exception:
                logger.exception("error fetching network %s", name)
                model_errors.append(name)
@@ -440,7 +442,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
            logger.info("skipping model: %s", name)
        else:
            try:
-                convert_diffusion_model(conversion, model)
+                convert_model_diffusion(conversion, model)
            except Exception:
                logger.exception(
                    "error converting diffusion model %s",
@@ -457,7 +459,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
            logger.info("skipping model: %s", name)
        else:
            try:
-                convert_upscaling_model(conversion, model)
+                convert_model_upscaling(conversion, model)
            except Exception:
                logger.exception(
                    "error converting upscaling model %s",
@@ -474,7 +476,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
            logger.info("skipping model: %s", name)
        else:
            try:
-                convert_correction_model(conversion, model)
+                convert_model_correction(conversion, model)
            except Exception:
                logger.exception(
                    "error converting correction model %s",


@@ -11,5 +11,6 @@ class BaseClient:
         source: str,
         format: Optional[str] = None,
         dest: Optional[str] = None,
+        **kwargs,
     ) -> str:
         raise NotImplementedError()
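
Accepting **kwargs in the base signature is what lets a per-client option like embeds flow through the shared download interface: clients that understand it read it, the rest accept and ignore it, and callers never hit a TypeError. A sketch of the pattern (ExampleClient is hypothetical; BaseClient and ConversionContext as in this package):

    from typing import Optional

    class ExampleClient(BaseClient):
        def download(
            self,
            conversion: ConversionContext,
            name: str,
            source: str,
            format: Optional[str] = None,
            dest: Optional[str] = None,
            **kwargs,
        ) -> str:
            # read the options this client cares about; ignore the rest
            embeds = kwargs.get("embeds", False)
            return f"{name}.bin" if embeds else source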


@@ -31,6 +31,7 @@ class CivitaiClient(BaseClient):
         source: str,
         format: Optional[str] = None,
         dest: Optional[str] = None,
+        **kwargs,
     ) -> str:
         """
         TODO: download with auth token
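
One plausible shape for that TODO, assuming Civitai's download endpoint accepts the token as a query parameter (unverified here; add_token is a hypothetical helper, stdlib only):

    from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

    def add_token(url: str, token: str) -> str:
        # append ?token=... to the download URL before fetching
        parts = urlparse(url)
        query = dict(parse_qsl(parts.query))
        query["token"] = token
        return urlunparse(parts._replace(query=urlencode(query)))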


@@ -24,6 +24,7 @@ class FileClient(BaseClient):
         uri: str,
         format: Optional[str] = None,
         dest: Optional[str] = None,
+        **kwargs,
     ) -> str:
         parts = urlparse(uri)
         logger.info("loading model from: %s", parts.path)


@@ -30,6 +30,7 @@ class HttpClient(BaseClient):
         uri: str,
         format: Optional[str] = None,
         dest: Optional[str] = None,
+        **kwargs,
     ) -> str:
         cache_paths = build_cache_paths(
             conversion,
@@ -46,6 +47,7 @@ class HttpClient(BaseClient):
             source = remove_prefix(uri, HttpClient.protocol)
             logger.info("downloading model from: %s", source)
         elif uri.startswith(HttpClient.insecure_protocol):
+            source = remove_prefix(uri, HttpClient.insecure_protocol)
             logger.warning("downloading model from insecure source: %s", source)
 
         return download_progress(source, cache_paths[0])
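
The added line makes the insecure branch strip its scheme marker the same way the secure branch does, so source is normalized before download_progress runs. remove_prefix is presumably equivalent to str.removeprefix; a minimal sketch of that assumed behavior:

    def remove_prefix(value: str, prefix: str) -> str:
        # matches str.removeprefix on Python 3.9+
        return value[len(prefix):] if value.startswith(prefix) else value

    assert remove_prefix("https://example.com/model.bin", "https://") == "example.com/model.bin"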


@@ -1,14 +1,10 @@
 from logging import getLogger
-from typing import Any, Optional
+from typing import Optional
 
+from huggingface_hub import snapshot_download
 from huggingface_hub.file_download import hf_hub_download
 
-from ..utils import (
-    ConversionContext,
-    build_cache_paths,
-    get_first_exists,
-    remove_prefix,
-)
+from ..utils import ConversionContext, remove_prefix
 from .base import BaseClient
 
 logger = getLogger(__name__)
@@ -18,11 +14,9 @@ class HuggingfaceClient(BaseClient):
     name = "huggingface"
     protocol = "huggingface://"
 
-    download: Any
     token: Optional[str]
 
-    def __init__(self, token: Optional[str] = None, download=hf_hub_download):
-        self.download = download
+    def __init__(self, token: Optional[str] = None):
         self.token = token
 
     def download(
@@ -32,39 +26,22 @@ class HuggingfaceClient(BaseClient):
         source: str,
         format: Optional[str] = None,
         dest: Optional[str] = None,
+        embeds: bool = False,
+        **kwargs,
     ) -> str:
-        """
-        TODO: download with auth
-        TODO: set fetch and filename
-        if network_type == "inversion" and network_model == "concept":
-        """
-        hf_hub_fetch = True
-        hf_hub_filename = "learned_embeds.bin"
-
-        cache_paths = build_cache_paths(
-            conversion,
-            name,
-            client=HuggingfaceClient.name,
-            format=format,
-            dest=dest,
-        )
-        cached = get_first_exists(cache_paths)
-        if cached:
-            return cached
-
         source = remove_prefix(source, HuggingfaceClient.protocol)
         logger.info("downloading model from Huggingface Hub: %s", source)
 
-        if hf_hub_fetch:
-            return (
-                hf_hub_download(
-                    repo_id=source,
-                    filename=hf_hub_filename,
-                    cache_dir=cache_paths[0],
-                    force_filename=f"{name}.bin",
-                ),
-                False,
-            )
+        if embeds:
+            return hf_hub_download(
+                repo_id=source,
+                filename="learned_embeds.bin",
+                cache_dir=conversion.cache_path,
+                force_filename=f"{name}.bin",
+                token=self.token,
+            )
         else:
-            # TODO: download pretrained because load doesn't call from_pretrained anymore
-            return source
+            return snapshot_download(
+                repo_id=source,
+                token=self.token,
+            )
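
Net effect of this rewrite: the cache-path bookkeeping and the old (path, False) tuple return are gone, both branches pass the auth token, and both return a plain path string. hf_hub_download fetches just learned_embeds.bin for concept embeds, while snapshot_download mirrors the whole repo now that loading no longer calls from_pretrained on a bare repo id. A usage sketch, with made-up repos and assuming conversion is an existing ConversionContext:

    client = HuggingfaceClient(token="hf_example_token")  # illustrative token

    # Textual Inversion concept: download one file, renamed to <name>.bin
    embed_path = client.download(
        conversion,
        "example-concept",
        "huggingface://sd-concepts-library/example-concept",
        embeds=True,
    )

    # everything else: mirror the repo and return the local snapshot directory
    model_dir = client.download(
        conversion,
        "example-diffusion",
        "huggingface://example-org/example-diffusion",
    )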