1
0
Fork 0

fix(api): add HF hub download to fetch logic for Inversion concepts

This commit is contained in:
Sean Sube 2023-03-19 20:32:21 -05:00
parent 0732058aa8
commit ae3bcf3b8b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 24 additions and 7 deletions

View File

@ -144,8 +144,11 @@ def fetch_model(
source: str,
dest: Optional[str] = None,
format: Optional[str] = None,
hf_hub_fetch: bool = False,
hf_hub_filename: Optional[str] = None,
) -> str:
cache_path = path.join(dest or ctx.cache_path, name)
cache_path = dest or ctx.cache_path
cache_name = path.join(cache_path, name)
# add an extension if possible, some of the conversion code checks for it
if format is None:
@ -153,11 +156,11 @@ def fetch_model(
ext = path.basename(url.path)
_filename, ext = path.splitext(ext)
if ext is not None:
cache_name = cache_path + ext
cache_name = cache_name + ext
else:
cache_name = cache_path
cache_name = cache_name
else:
cache_name = f"{cache_path}.{format}"
cache_name = f"{cache_name}.{format}"
if path.exists(cache_name):
logger.debug("model already exists in cache, skipping fetch")
@ -176,6 +179,14 @@ def fetch_model(
hub_source = remove_prefix(source, model_source_huggingface)
logger.info("downloading model from Huggingface Hub: %s", hub_source)
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
if hf_hub_fetch:
return hf_hub_download(
repo_id=source,
filename=hf_hub_filename,
cache_dir=cache_path,
force_filename=f"{name}.bin",
)
else:
return hub_source
elif source.startswith("https://"):
logger.info("downloading model from: %s", source)
@ -223,8 +234,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
try:
if network_type == "inversion" and network_model == "concept":
dest = hf_hub_download(
repo_id=source, filename="learned_embeds.bin"
dest = fetch_model(
ctx,
name,
source,
dest=path.join(ctx.model_path, network_type),
format=network_format,
hf_hub_fetch=True,
hf_hub_filename="learned_embeds.bin",
)
else:
dest = fetch_model(