1
0
Fork 0
onnx-web/api/onnx_web/convert/client/__init__.py

55 lines
1.5 KiB
Python
Raw Permalink Normal View History

2023-12-10 05:40:47 +00:00
from typing import Callable, Dict, Optional
from logging import getLogger
from os import path
from ..utils import ConversionContext
2023-12-10 00:46:47 +00:00
from .base import BaseClient
from .civitai import CivitaiClient
from .file import FileClient
from .http import HttpClient
from .huggingface import HuggingfaceClient
logger = getLogger(__name__)

# Registry mapping a source-URI prefix to a client factory.
# fetch_model walks this dict in insertion order, picks the first key that is a
# prefix of the source, and calls the factory with the ConversionContext to build
# the client, so each value must accept one ConversionContext argument.
model_sources: Dict[str, Callable[[ConversionContext], BaseClient]] = {
    CivitaiClient.protocol: CivitaiClient,
    FileClient.protocol: FileClient,
    HttpClient.insecure_protocol: HttpClient,
    HttpClient.protocol: HttpClient,
    HuggingfaceClient.protocol: HuggingfaceClient,
}
def add_model_source(
    proto: str, client: Callable[[ConversionContext], BaseClient]
) -> None:
    """Register a client factory for an additional source protocol.

    Args:
        proto: URI prefix that ``fetch_model`` matches against the start of a
            model source string.
        client: factory (a client class such as ``FileClient``) that
            ``fetch_model`` calls with the ``ConversionContext`` to build the
            client instance.

    Raises:
        ValueError: if ``proto`` is already registered; existing entries are
            never overwritten.
    """
    # model_sources is only mutated in place (never rebound), so no ``global``
    # declaration is required.
    if proto in model_sources:
        raise ValueError(f"protocol has already been taken: {proto}")

    model_sources[proto] = client
def fetch_model(
    conversion: ConversionContext,
    name: str,
    source: str,
    format: Optional[str] = None,
    dest: Optional[str] = None,
    **kwargs,
) -> str:
    """Dispatch a model source to the client registered for its protocol.

    A source that starts with ``os.path.sep`` or ``"."`` is treated as a local
    path and prefixed with ``FileClient.protocol`` first. The registry
    ``model_sources`` is then scanned in insertion order, and the first entry
    whose key is a prefix of the source wins. That factory is called with
    ``conversion``, and the resulting client's ``download`` result is returned.

    If no prefix matches, a warning is logged and ``source`` is returned
    unchanged (including any file protocol added above).

    Args:
        conversion: passed to the client factory and again to ``download``.
        name: forwarded to ``download``.
        source: model URI or local path.
        format: forwarded to ``download`` as ``format=``.
        dest: forwarded to ``download`` as ``dest=``.
        **kwargs: forwarded to ``download`` unchanged.

    Returns:
        The value returned by the selected client's ``download``, or
        ``source`` when no protocol matched.

    NOTE(review): only the ``os.path.sep`` / ``"."`` prefixes mark a local
    path, so on Windows drive-letter paths and forward-slash roots skip the
    file protocol and reach the unknown-protocol branch. Confirm whether
    FileClient should handle them.
    """
    # TODO: switch to urlparse's default scheme
    looks_local = source.startswith((path.sep, "."))
    if looks_local:
        logger.info("adding file protocol to local path source: %s", source)
        source = FileClient.protocol + source

    # Lazily take the first (prefix, factory) pair whose prefix matches; the
    # tuple wrapper keeps "no match" (None) distinct from any registered value.
    match = next(
        (
            (prefix, factory)
            for prefix, factory in model_sources.items()
            if source.startswith(prefix)
        ),
        None,
    )

    if match is None:
        logger.warning("unknown model protocol, using path as provided: %s", source)
        return source

    _, factory = match
    client = factory(conversion)
    return client.download(
        conversion, name, source, format=format, dest=dest, **kwargs
    )