fix(api): add HF hub download to fetch logic for Inversion concepts
This commit is contained in:
parent
0732058aa8
commit
ae3bcf3b8b
|
@ -144,8 +144,11 @@ def fetch_model(
|
||||||
source: str,
|
source: str,
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
format: Optional[str] = None,
|
format: Optional[str] = None,
|
||||||
|
hf_hub_fetch: bool = False,
|
||||||
|
hf_hub_filename: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
cache_path = path.join(dest or ctx.cache_path, name)
|
cache_path = dest or ctx.cache_path
|
||||||
|
cache_name = path.join(cache_path, name)
|
||||||
|
|
||||||
# add an extension if possible, some of the conversion code checks for it
|
# add an extension if possible, some of the conversion code checks for it
|
||||||
if format is None:
|
if format is None:
|
||||||
|
@ -153,11 +156,11 @@ def fetch_model(
|
||||||
ext = path.basename(url.path)
|
ext = path.basename(url.path)
|
||||||
_filename, ext = path.splitext(ext)
|
_filename, ext = path.splitext(ext)
|
||||||
if ext is not None:
|
if ext is not None:
|
||||||
cache_name = cache_path + ext
|
cache_name = cache_name + ext
|
||||||
else:
|
else:
|
||||||
cache_name = cache_path
|
cache_name = cache_name
|
||||||
else:
|
else:
|
||||||
cache_name = f"{cache_path}.{format}"
|
cache_name = f"{cache_name}.{format}"
|
||||||
|
|
||||||
if path.exists(cache_name):
|
if path.exists(cache_name):
|
||||||
logger.debug("model already exists in cache, skipping fetch")
|
logger.debug("model already exists in cache, skipping fetch")
|
||||||
|
@ -176,7 +179,15 @@ def fetch_model(
|
||||||
hub_source = remove_prefix(source, model_source_huggingface)
|
hub_source = remove_prefix(source, model_source_huggingface)
|
||||||
logger.info("downloading model from Huggingface Hub: %s", hub_source)
|
logger.info("downloading model from Huggingface Hub: %s", hub_source)
|
||||||
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
|
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
|
||||||
return hub_source
|
if hf_hub_fetch:
|
||||||
|
return hf_hub_download(
|
||||||
|
repo_id=source,
|
||||||
|
filename=hf_hub_filename,
|
||||||
|
cache_dir=cache_path,
|
||||||
|
force_filename=f"{name}.bin",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return hub_source
|
||||||
elif source.startswith("https://"):
|
elif source.startswith("https://"):
|
||||||
logger.info("downloading model from: %s", source)
|
logger.info("downloading model from: %s", source)
|
||||||
return download_progress([(source, cache_name)])
|
return download_progress([(source, cache_name)])
|
||||||
|
@ -223,8 +234,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if network_type == "inversion" and network_model == "concept":
|
if network_type == "inversion" and network_model == "concept":
|
||||||
dest = hf_hub_download(
|
dest = fetch_model(
|
||||||
repo_id=source, filename="learned_embeds.bin"
|
ctx,
|
||||||
|
name,
|
||||||
|
source,
|
||||||
|
dest=path.join(ctx.model_path, network_type),
|
||||||
|
format=network_format,
|
||||||
|
hf_hub_fetch=True,
|
||||||
|
hf_hub_filename="learned_embeds.bin",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dest = fetch_model(
|
dest = fetch_model(
|
||||||
|
|
Loading…
Reference in New Issue