1
0
Fork 0
onnx-web/api/onnx_web/convert/client/huggingface.py

44 lines
1.1 KiB
Python
Raw Normal View History

2023-12-10 00:46:47 +00:00
from logging import getLogger
2023-12-10 01:15:28 +00:00
from typing import Optional
2023-12-10 00:46:47 +00:00
from huggingface_hub.file_download import hf_hub_download
2023-12-10 01:15:28 +00:00
from ..utils import ConversionContext, remove_prefix
from .base import BaseClient
logger = getLogger(__name__)


class HuggingfaceClient(BaseClient):
    """
    Source client for models referenced with the `huggingface://` scheme.

    Only the embeddings file (`learned_embeds.bin`) is fetched here, and only
    when `download` is called with `embeds=True`; otherwise the repository id
    is returned to the caller without downloading anything.
    """

    # client identifier and the URL scheme this client handles
    name = "huggingface"
    protocol = "huggingface://"

    # optional Hub access token, forwarded to hf_hub_download
    token: Optional[str]

    def __init__(self, token: Optional[str] = None):
        self.token = token

    def download(
        self,
        conversion: ConversionContext,
        name: str,
        source: str,
        format: Optional[str] = None,
        dest: Optional[str] = None,
        embeds: bool = False,
        **kwargs,
    ) -> str:
        """
        Resolve a `huggingface://` source, downloading it only for embeddings.

        Args:
            conversion: conversion context; its `cache_path` is the fallback
                download directory.
            name: model name, used to build the `force_filename` value.
            source: source string; passed through `remove_prefix` with this
                client's protocol before use.
            format: not used by this method; presumably accepted to match the
                BaseClient interface.
            dest: download directory; `conversion.cache_path` is used instead
                when `dest` is falsy (None or empty).
            embeds: when true, fetch `learned_embeds.bin` from the repository.
            **kwargs: ignored by this method.

        Returns:
            When `embeds` is true, the value returned by `hf_hub_download`
            (the local path of the downloaded file); otherwise the repository
            id produced by `remove_prefix`.
        """
        repo_id = remove_prefix(source, HuggingfaceClient.protocol)
        # logged on both paths, although only the embeds path downloads here
        logger.info("downloading model from Huggingface Hub: %s", repo_id)

        if not embeds:
            # nothing is fetched in this case; presumably the caller loads
            # the model from the repo id directly
            return repo_id

        cache_dir = dest or conversion.cache_path
        # NOTE(review): newer huggingface_hub releases deprecate
        # `force_filename` (the cache keeps Hub filenames as-is) — confirm
        # whether it still has an effect with the pinned version
        target_filename = f"{name}.bin"
        return hf_hub_download(
            repo_id=repo_id,
            filename="learned_embeds.bin",
            cache_dir=cache_dir,
            force_filename=target_filename,
            token=self.token,
        )