2023-12-10 00:46:47 +00:00
|
|
|
from logging import getLogger
|
|
|
|
from typing import Dict, Optional
|
|
|
|
|
2023-12-10 00:04:34 +00:00
|
|
|
from ..utils import (
|
|
|
|
ConversionContext,
|
|
|
|
build_cache_paths,
|
|
|
|
download_progress,
|
|
|
|
get_first_exists,
|
|
|
|
remove_prefix,
|
|
|
|
)
|
|
|
|
from .base import BaseClient
|
|
|
|
|
|
|
|
# Module-level logger, named after the importing module's dotted path.
logger = getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class HttpClient(BaseClient):
    """Model download client for ``https://`` (and insecure ``http://``) URIs.

    Downloads land in the first path produced by ``build_cache_paths``. If
    ``get_first_exists`` reports an existing entry among those cache paths,
    that value is returned and nothing is downloaded.
    """

    # Client identifier, forwarded to build_cache_paths as ``client=``.
    name = "http"
    # Preferred scheme; URIs using it are logged at INFO level.
    protocol = "https://"
    # Accepted but discouraged scheme; URIs using it are logged as a warning.
    insecure_protocol = "http://"

    # Extra HTTP headers supplied by the caller. They are stored on the
    # instance; note that download() below does not forward them to
    # download_progress.
    headers: Dict[str, str]

    def __init__(self, headers: Optional[Dict[str, str]] = None):
        """Create an HTTP client.

        Args:
            headers: optional HTTP headers to keep on the client. ``None`` or
                any falsy value is replaced with an empty dict.
        """
        self.headers = headers or {}

    def download(
        self,
        conversion: ConversionContext,
        name: str,
        uri: str,
        format: Optional[str] = None,
        dest: Optional[str] = None,
    ) -> str:
        """Fetch ``uri`` into the conversion cache and return the local path.

        Args:
            conversion: conversion context used to build the cache paths.
            name: model name used to build the cache paths.
            uri: URL to fetch, normally ``https://...``. ``http://`` URLs
                are accepted but logged as insecure.
            format: optional model format, forwarded to ``build_cache_paths``.
            dest: optional destination, forwarded to ``build_cache_paths``.

        Returns:
            The cached entry found by ``get_first_exists`` if there is one,
            otherwise the result of ``download_progress(uri, cache_paths[0])``.
        """
        cache_paths = build_cache_paths(
            conversion,
            name,
            client=HttpClient.name,
            format=format,
            dest=dest,
        )

        cached = get_first_exists(cache_paths)
        if cached:
            return cached

        # The full URI, scheme included, is handed to download_progress for
        # every branch. Stripping "https://" would leave a bare host/path,
        # which is not a fetchable URL. (Assumes download_progress expects a
        # complete URL -- confirm against ..utils.) Binding nothing per branch
        # also removes the UnboundLocalError that http:// URIs used to hit.
        if uri.startswith(HttpClient.protocol):
            logger.info("downloading model from: %s", uri)
        elif uri.startswith(HttpClient.insecure_protocol):
            logger.warning("downloading model from insecure source: %s", uri)
        # URIs with neither prefix are passed through unchanged; whether they
        # can be fetched is up to download_progress.

        return download_progress(uri, cache_paths[0])
|