1
0
Fork 0
onnx-web/api/onnx_web/convert/client/huggingface.py

71 lines
1.8 KiB
Python
Raw Normal View History

2023-12-10 00:46:47 +00:00
from logging import getLogger
from typing import Any, Optional
from huggingface_hub.file_download import hf_hub_download
from ..utils import (
ConversionContext,
build_cache_paths,
get_first_exists,
remove_prefix,
)
from .base import BaseClient
# Module-scoped logger, named after this module.
logger = getLogger(__name__)
class HuggingfaceClient(BaseClient):
    """
    Client that fetches model files from the Huggingface Hub.

    Sources are expected to carry the ``huggingface://`` protocol prefix,
    which is stripped before the remainder is used as the Hub ``repo_id``.
    """

    name = "huggingface"
    protocol = "huggingface://"

    # Callable used to fetch a file from the Hub (``hf_hub_download`` by
    # default; injectable, e.g. for tests). It is deliberately NOT stored as
    # ``self.download``: an instance attribute with that name would shadow the
    # ``download`` method below on every instance.
    download_fn: Any
    # Optional Hub access token, forwarded to ``download_fn`` when set.
    token: Optional[str]

    def __init__(self, token: Optional[str] = None, download=hf_hub_download):
        """
        :param token: optional Huggingface Hub access token.
        :param download: callable used to fetch files; it is called with
            ``repo_id``, ``filename``, ``cache_dir`` and ``force_filename``
            keyword arguments (plus ``token`` when one is configured).
        """
        self.download_fn = download
        self.token = token

    def download(
        self,
        conversion: ConversionContext,
        name: str,
        source: str,
        format: Optional[str] = None,
        dest: Optional[str] = None,
    ) -> str:
        """
        Return a local path for ``source``, downloading it if not cached.

        The candidate cache paths come from ``build_cache_paths`` using this
        client's name, ``format`` and ``dest``. If ``get_first_exists`` finds
        one of them, it is returned without any network access. Otherwise the
        protocol prefix is stripped from ``source``, and the file
        ``learned_embeds.bin`` is fetched from that repository into the first
        cache path, with ``force_filename`` set to ``f"{name}.bin"``.

        :param conversion: conversion context passed to ``build_cache_paths``.
        :param name: model name, used for the cache paths and the saved filename.
        :param source: ``huggingface://<repo_id>`` (the prefix is optional).
        :param format: optional model format, passed to ``build_cache_paths``.
        :param dest: optional destination, passed to ``build_cache_paths``.
        :returns: the cached path, or the path returned by the download callable.

        TODO: set fetch and filename based on the network type, e.g. only for
        ``network_type == "inversion" and network_model == "concept"``.
        """
        hf_hub_fetch = True
        hf_hub_filename = "learned_embeds.bin"

        cache_paths = build_cache_paths(
            conversion,
            name,
            client=HuggingfaceClient.name,
            format=format,
            dest=dest,
        )
        # NOTE(review): assumes get_first_exists returns a falsy value when no
        # candidate exists -- confirm in ..utils.
        cached = get_first_exists(cache_paths)
        if cached:
            return cached

        source = remove_prefix(source, HuggingfaceClient.protocol)
        logger.info("downloading model from Huggingface Hub: %s", source)

        if not hf_hub_fetch:
            # Currently unreachable (hf_hub_fetch is always True above).
            # TODO: download pretrained because load doesn't call from_pretrained anymore
            return source

        download_args = {
            "repo_id": source,
            "filename": hf_hub_filename,
            "cache_dir": cache_paths[0],
            "force_filename": f"{name}.bin",
        }
        # Only forward the token when one was configured, so the default call
        # is unchanged and injected callables without a ``token`` parameter
        # keep working.
        if self.token is not None:
            download_args["token"] = self.token

        # Return the path alone: the previous ``(path, False)`` tuple violated
        # the ``-> str`` contract shared with the cached/early-return paths.
        return self.download_fn(**download_args)