special downloading for embeds
This commit is contained in:
parent
7496613f4e
commit
20b719fdff
|
@ -202,7 +202,7 @@ base_models: Models = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def convert_source_model(conversion: ConversionContext, model):
|
def convert_model_source(conversion: ConversionContext, model):
|
||||||
model_format = source_format(model)
|
model_format = source_format(model)
|
||||||
name = model["name"]
|
name = model["name"]
|
||||||
source = model["source"]
|
source = model["source"]
|
||||||
|
@ -215,9 +215,10 @@ def convert_source_model(conversion: ConversionContext, model):
|
||||||
logger.info("finished downloading source: %s -> %s", source, dest)
|
logger.info("finished downloading source: %s -> %s", source, dest)
|
||||||
|
|
||||||
|
|
||||||
def convert_network_model(conversion: ConversionContext, network):
|
def convert_model_network(conversion: ConversionContext, network):
|
||||||
network_format = source_format(network)
|
format = source_format(network)
|
||||||
name = network["name"]
|
name = network["name"]
|
||||||
|
model = network["model"]
|
||||||
network_type = network["type"]
|
network_type = network["type"]
|
||||||
source = network["source"]
|
source = network["source"]
|
||||||
|
|
||||||
|
@ -226,7 +227,7 @@ def convert_network_model(conversion: ConversionContext, network):
|
||||||
conversion,
|
conversion,
|
||||||
name,
|
name,
|
||||||
source,
|
source,
|
||||||
format=network_format,
|
format=format,
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_diffusion_control(
|
convert_diffusion_control(
|
||||||
|
@ -241,13 +242,14 @@ def convert_network_model(conversion: ConversionContext, network):
|
||||||
name,
|
name,
|
||||||
source,
|
source,
|
||||||
dest=path.join(conversion.model_path, network_type),
|
dest=path.join(conversion.model_path, network_type),
|
||||||
format=network_format,
|
format=format,
|
||||||
|
embeds=(network_type == "inversion" and model == "concept"),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("finished downloading network: %s -> %s", source, dest)
|
logger.info("finished downloading network: %s -> %s", source, dest)
|
||||||
|
|
||||||
|
|
||||||
def convert_diffusion_model(conversion: ConversionContext, model):
|
def convert_model_diffusion(conversion: ConversionContext, model):
|
||||||
# fix up entries with missing prefixes
|
# fix up entries with missing prefixes
|
||||||
name = fix_diffusion_name(model["name"])
|
name = fix_diffusion_name(model["name"])
|
||||||
if name != model["name"]:
|
if name != model["name"]:
|
||||||
|
@ -372,7 +374,7 @@ def convert_diffusion_model(conversion: ConversionContext, model):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_upscaling_model(conversion: ConversionContext, model):
|
def convert_model_upscaling(conversion: ConversionContext, model):
|
||||||
model_format = source_format(model)
|
model_format = source_format(model)
|
||||||
name = model["name"]
|
name = model["name"]
|
||||||
|
|
||||||
|
@ -389,7 +391,7 @@ def convert_upscaling_model(conversion: ConversionContext, model):
|
||||||
raise ValueError(name)
|
raise ValueError(name)
|
||||||
|
|
||||||
|
|
||||||
def convert_correction_model(conversion: ConversionContext, model):
|
def convert_model_correction(conversion: ConversionContext, model):
|
||||||
model_format = source_format(model)
|
model_format = source_format(model)
|
||||||
name = model["name"]
|
name = model["name"]
|
||||||
source = fetch_model(conversion, name, model["source"], format=model_format)
|
source = fetch_model(conversion, name, model["source"], format=model_format)
|
||||||
|
@ -413,7 +415,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
logger.info("skipping source: %s", name)
|
logger.info("skipping source: %s", name)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_source_model(model)
|
convert_model_source(model)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error fetching source %s", name)
|
logger.exception("error fetching source %s", name)
|
||||||
model_errors.append(name)
|
model_errors.append(name)
|
||||||
|
@ -426,7 +428,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
logger.info("skipping network: %s", name)
|
logger.info("skipping network: %s", name)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_network_model(conversion, model)
|
convert_model_network(conversion, model)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error fetching network %s", name)
|
logger.exception("error fetching network %s", name)
|
||||||
model_errors.append(name)
|
model_errors.append(name)
|
||||||
|
@ -440,7 +442,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
logger.info("skipping model: %s", name)
|
logger.info("skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_diffusion_model(conversion, model)
|
convert_model_diffusion(conversion, model)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting diffusion model %s",
|
"error converting diffusion model %s",
|
||||||
|
@ -457,7 +459,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
logger.info("skipping model: %s", name)
|
logger.info("skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_upscaling_model(conversion, model)
|
convert_model_upscaling(conversion, model)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting upscaling model %s",
|
"error converting upscaling model %s",
|
||||||
|
@ -474,7 +476,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
logger.info("skipping model: %s", name)
|
logger.info("skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_correction_model(conversion, model)
|
convert_model_correction(conversion, model)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting correction model %s",
|
"error converting correction model %s",
|
||||||
|
|
|
@ -11,5 +11,6 @@ class BaseClient:
|
||||||
source: str,
|
source: str,
|
||||||
format: Optional[str] = None,
|
format: Optional[str] = None,
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -31,6 +31,7 @@ class CivitaiClient(BaseClient):
|
||||||
source: str,
|
source: str,
|
||||||
format: Optional[str] = None,
|
format: Optional[str] = None,
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
TODO: download with auth token
|
TODO: download with auth token
|
||||||
|
|
|
@ -24,6 +24,7 @@ class FileClient(BaseClient):
|
||||||
uri: str,
|
uri: str,
|
||||||
format: Optional[str] = None,
|
format: Optional[str] = None,
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
parts = urlparse(uri)
|
parts = urlparse(uri)
|
||||||
logger.info("loading model from: %s", parts.path)
|
logger.info("loading model from: %s", parts.path)
|
||||||
|
|
|
@ -30,6 +30,7 @@ class HttpClient(BaseClient):
|
||||||
uri: str,
|
uri: str,
|
||||||
format: Optional[str] = None,
|
format: Optional[str] = None,
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
cache_paths = build_cache_paths(
|
cache_paths = build_cache_paths(
|
||||||
conversion,
|
conversion,
|
||||||
|
@ -46,6 +47,7 @@ class HttpClient(BaseClient):
|
||||||
source = remove_prefix(uri, HttpClient.protocol)
|
source = remove_prefix(uri, HttpClient.protocol)
|
||||||
logger.info("downloading model from: %s", source)
|
logger.info("downloading model from: %s", source)
|
||||||
elif uri.startswith(HttpClient.insecure_protocol):
|
elif uri.startswith(HttpClient.insecure_protocol):
|
||||||
|
source = remove_prefix(uri, HttpClient.insecure_protocol)
|
||||||
logger.warning("downloading model from insecure source: %s", source)
|
logger.warning("downloading model from insecure source: %s", source)
|
||||||
|
|
||||||
return download_progress(source, cache_paths[0])
|
return download_progress(source, cache_paths[0])
|
||||||
|
|
|
@ -1,14 +1,10 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import ConversionContext, remove_prefix
|
||||||
ConversionContext,
|
|
||||||
build_cache_paths,
|
|
||||||
get_first_exists,
|
|
||||||
remove_prefix,
|
|
||||||
)
|
|
||||||
from .base import BaseClient
|
from .base import BaseClient
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -18,11 +14,9 @@ class HuggingfaceClient(BaseClient):
|
||||||
name = "huggingface"
|
name = "huggingface"
|
||||||
protocol = "huggingface://"
|
protocol = "huggingface://"
|
||||||
|
|
||||||
download: Any
|
|
||||||
token: Optional[str]
|
token: Optional[str]
|
||||||
|
|
||||||
def __init__(self, token: Optional[str] = None, download=hf_hub_download):
|
def __init__(self, token: Optional[str] = None):
|
||||||
self.download = download
|
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
def download(
|
def download(
|
||||||
|
@ -32,39 +26,22 @@ class HuggingfaceClient(BaseClient):
|
||||||
source: str,
|
source: str,
|
||||||
format: Optional[str] = None,
|
format: Optional[str] = None,
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
|
embeds: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
|
||||||
TODO: download with auth
|
|
||||||
TODO: set fetch and filename
|
|
||||||
if network_type == "inversion" and network_model == "concept":
|
|
||||||
"""
|
|
||||||
hf_hub_fetch = True
|
|
||||||
hf_hub_filename = "learned_embeds.bin"
|
|
||||||
|
|
||||||
cache_paths = build_cache_paths(
|
|
||||||
conversion,
|
|
||||||
name,
|
|
||||||
client=HuggingfaceClient.name,
|
|
||||||
format=format,
|
|
||||||
dest=dest,
|
|
||||||
)
|
|
||||||
cached = get_first_exists(cache_paths)
|
|
||||||
if cached:
|
|
||||||
return cached
|
|
||||||
|
|
||||||
source = remove_prefix(source, HuggingfaceClient.protocol)
|
source = remove_prefix(source, HuggingfaceClient.protocol)
|
||||||
logger.info("downloading model from Huggingface Hub: %s", source)
|
logger.info("downloading model from Huggingface Hub: %s", source)
|
||||||
|
|
||||||
if hf_hub_fetch:
|
if embeds:
|
||||||
return (
|
return hf_hub_download(
|
||||||
hf_hub_download(
|
repo_id=source,
|
||||||
repo_id=source,
|
filename="learned_embeds.bin",
|
||||||
filename=hf_hub_filename,
|
cache_dir=conversion.cache_path,
|
||||||
cache_dir=cache_paths[0],
|
force_filename=f"{name}.bin",
|
||||||
force_filename=f"{name}.bin",
|
token=self.token,
|
||||||
),
|
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# TODO: download pretrained because load doesn't call from_pretrained anymore
|
return snapshot_download(
|
||||||
return source
|
repo_id=source,
|
||||||
|
token=self.token,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue