special downloading for embeds
This commit is contained in:
parent
7496613f4e
commit
20b719fdff
|
@ -202,7 +202,7 @@ base_models: Models = {
|
|||
}
|
||||
|
||||
|
||||
def convert_source_model(conversion: ConversionContext, model):
|
||||
def convert_model_source(conversion: ConversionContext, model):
|
||||
model_format = source_format(model)
|
||||
name = model["name"]
|
||||
source = model["source"]
|
||||
|
@ -215,9 +215,10 @@ def convert_source_model(conversion: ConversionContext, model):
|
|||
logger.info("finished downloading source: %s -> %s", source, dest)
|
||||
|
||||
|
||||
def convert_network_model(conversion: ConversionContext, network):
|
||||
network_format = source_format(network)
|
||||
def convert_model_network(conversion: ConversionContext, network):
|
||||
format = source_format(network)
|
||||
name = network["name"]
|
||||
model = network["model"]
|
||||
network_type = network["type"]
|
||||
source = network["source"]
|
||||
|
||||
|
@ -226,7 +227,7 @@ def convert_network_model(conversion: ConversionContext, network):
|
|||
conversion,
|
||||
name,
|
||||
source,
|
||||
format=network_format,
|
||||
format=format,
|
||||
)
|
||||
|
||||
convert_diffusion_control(
|
||||
|
@ -241,13 +242,14 @@ def convert_network_model(conversion: ConversionContext, network):
|
|||
name,
|
||||
source,
|
||||
dest=path.join(conversion.model_path, network_type),
|
||||
format=network_format,
|
||||
format=format,
|
||||
embeds=(network_type == "inversion" and model == "concept"),
|
||||
)
|
||||
|
||||
logger.info("finished downloading network: %s -> %s", source, dest)
|
||||
|
||||
|
||||
def convert_diffusion_model(conversion: ConversionContext, model):
|
||||
def convert_model_diffusion(conversion: ConversionContext, model):
|
||||
# fix up entries with missing prefixes
|
||||
name = fix_diffusion_name(model["name"])
|
||||
if name != model["name"]:
|
||||
|
@ -372,7 +374,7 @@ def convert_diffusion_model(conversion: ConversionContext, model):
|
|||
)
|
||||
|
||||
|
||||
def convert_upscaling_model(conversion: ConversionContext, model):
|
||||
def convert_model_upscaling(conversion: ConversionContext, model):
|
||||
model_format = source_format(model)
|
||||
name = model["name"]
|
||||
|
||||
|
@ -389,7 +391,7 @@ def convert_upscaling_model(conversion: ConversionContext, model):
|
|||
raise ValueError(name)
|
||||
|
||||
|
||||
def convert_correction_model(conversion: ConversionContext, model):
|
||||
def convert_model_correction(conversion: ConversionContext, model):
|
||||
model_format = source_format(model)
|
||||
name = model["name"]
|
||||
source = fetch_model(conversion, name, model["source"], format=model_format)
|
||||
|
@ -413,7 +415,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
logger.info("skipping source: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_source_model(model)
|
||||
convert_model_source(model)
|
||||
except Exception:
|
||||
logger.exception("error fetching source %s", name)
|
||||
model_errors.append(name)
|
||||
|
@ -426,7 +428,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
logger.info("skipping network: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_network_model(conversion, model)
|
||||
convert_model_network(conversion, model)
|
||||
except Exception:
|
||||
logger.exception("error fetching network %s", name)
|
||||
model_errors.append(name)
|
||||
|
@ -440,7 +442,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
logger.info("skipping model: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_diffusion_model(conversion, model)
|
||||
convert_model_diffusion(conversion, model)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting diffusion model %s",
|
||||
|
@ -457,7 +459,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
logger.info("skipping model: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_upscaling_model(conversion, model)
|
||||
convert_model_upscaling(conversion, model)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting upscaling model %s",
|
||||
|
@ -474,7 +476,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
logger.info("skipping model: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_correction_model(conversion, model)
|
||||
convert_model_correction(conversion, model)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting correction model %s",
|
||||
|
|
|
@ -11,5 +11,6 @@ class BaseClient:
|
|||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -31,6 +31,7 @@ class CivitaiClient(BaseClient):
|
|||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
TODO: download with auth token
|
||||
|
|
|
@ -24,6 +24,7 @@ class FileClient(BaseClient):
|
|||
uri: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
parts = urlparse(uri)
|
||||
logger.info("loading model from: %s", parts.path)
|
||||
|
|
|
@ -30,6 +30,7 @@ class HttpClient(BaseClient):
|
|||
uri: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
cache_paths = build_cache_paths(
|
||||
conversion,
|
||||
|
@ -46,6 +47,7 @@ class HttpClient(BaseClient):
|
|||
source = remove_prefix(uri, HttpClient.protocol)
|
||||
logger.info("downloading model from: %s", source)
|
||||
elif uri.startswith(HttpClient.insecure_protocol):
|
||||
source = remove_prefix(uri, HttpClient.insecure_protocol)
|
||||
logger.warning("downloading model from insecure source: %s", source)
|
||||
|
||||
return download_progress(source, cache_paths[0])
|
||||
|
|
|
@ -1,14 +1,10 @@
|
|||
from logging import getLogger
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
|
||||
from ..utils import (
|
||||
ConversionContext,
|
||||
build_cache_paths,
|
||||
get_first_exists,
|
||||
remove_prefix,
|
||||
)
|
||||
from ..utils import ConversionContext, remove_prefix
|
||||
from .base import BaseClient
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -18,11 +14,9 @@ class HuggingfaceClient(BaseClient):
|
|||
name = "huggingface"
|
||||
protocol = "huggingface://"
|
||||
|
||||
download: Any
|
||||
token: Optional[str]
|
||||
|
||||
def __init__(self, token: Optional[str] = None, download=hf_hub_download):
|
||||
self.download = download
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
self.token = token
|
||||
|
||||
def download(
|
||||
|
@ -32,39 +26,22 @@ class HuggingfaceClient(BaseClient):
|
|||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
embeds: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
TODO: download with auth
|
||||
TODO: set fetch and filename
|
||||
if network_type == "inversion" and network_model == "concept":
|
||||
"""
|
||||
hf_hub_fetch = True
|
||||
hf_hub_filename = "learned_embeds.bin"
|
||||
|
||||
cache_paths = build_cache_paths(
|
||||
conversion,
|
||||
name,
|
||||
client=HuggingfaceClient.name,
|
||||
format=format,
|
||||
dest=dest,
|
||||
)
|
||||
cached = get_first_exists(cache_paths)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
source = remove_prefix(source, HuggingfaceClient.protocol)
|
||||
logger.info("downloading model from Huggingface Hub: %s", source)
|
||||
|
||||
if hf_hub_fetch:
|
||||
return (
|
||||
hf_hub_download(
|
||||
repo_id=source,
|
||||
filename=hf_hub_filename,
|
||||
cache_dir=cache_paths[0],
|
||||
force_filename=f"{name}.bin",
|
||||
),
|
||||
False,
|
||||
if embeds:
|
||||
return hf_hub_download(
|
||||
repo_id=source,
|
||||
filename="learned_embeds.bin",
|
||||
cache_dir=conversion.cache_path,
|
||||
force_filename=f"{name}.bin",
|
||||
token=self.token,
|
||||
)
|
||||
else:
|
||||
# TODO: download pretrained because load doesn't call from_pretrained anymore
|
||||
return source
|
||||
return snapshot_download(
|
||||
repo_id=source,
|
||||
token=self.token,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue