fix(api): add HF hub download to fetch logic for Inversion concepts
This commit is contained in:
parent
0732058aa8
commit
ae3bcf3b8b
|
@ -144,8 +144,11 @@ def fetch_model(
|
|||
source: str,
|
||||
dest: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
hf_hub_fetch: bool = False,
|
||||
hf_hub_filename: Optional[str] = None,
|
||||
) -> str:
|
||||
cache_path = path.join(dest or ctx.cache_path, name)
|
||||
cache_path = dest or ctx.cache_path
|
||||
cache_name = path.join(cache_path, name)
|
||||
|
||||
# add an extension if possible, some of the conversion code checks for it
|
||||
if format is None:
|
||||
|
@ -153,11 +156,11 @@ def fetch_model(
|
|||
ext = path.basename(url.path)
|
||||
_filename, ext = path.splitext(ext)
|
||||
if ext is not None:
|
||||
cache_name = cache_path + ext
|
||||
cache_name = cache_name + ext
|
||||
else:
|
||||
cache_name = cache_path
|
||||
cache_name = cache_name
|
||||
else:
|
||||
cache_name = f"{cache_path}.{format}"
|
||||
cache_name = f"{cache_name}.{format}"
|
||||
|
||||
if path.exists(cache_name):
|
||||
logger.debug("model already exists in cache, skipping fetch")
|
||||
|
@ -176,7 +179,15 @@ def fetch_model(
|
|||
hub_source = remove_prefix(source, model_source_huggingface)
|
||||
logger.info("downloading model from Huggingface Hub: %s", hub_source)
|
||||
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
|
||||
return hub_source
|
||||
if hf_hub_fetch:
|
||||
return hf_hub_download(
|
||||
repo_id=source,
|
||||
filename=hf_hub_filename,
|
||||
cache_dir=cache_path,
|
||||
force_filename=f"{name}.bin",
|
||||
)
|
||||
else:
|
||||
return hub_source
|
||||
elif source.startswith("https://"):
|
||||
logger.info("downloading model from: %s", source)
|
||||
return download_progress([(source, cache_name)])
|
||||
|
@ -223,8 +234,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
|
||||
try:
|
||||
if network_type == "inversion" and network_model == "concept":
|
||||
dest = hf_hub_download(
|
||||
repo_id=source, filename="learned_embeds.bin"
|
||||
dest = fetch_model(
|
||||
ctx,
|
||||
name,
|
||||
source,
|
||||
dest=path.join(ctx.model_path, network_type),
|
||||
format=network_format,
|
||||
hf_hub_fetch=True,
|
||||
hf_hub_filename="learned_embeds.bin",
|
||||
)
|
||||
else:
|
||||
dest = fetch_model(
|
||||
|
|
Loading…
Reference in New Issue