1
0
Fork 0
onnx-web/api/onnx_web/convert/client/__init__.py

52 lines
1.5 KiB
Python

from .base import BaseClient
from .civitai import CivitaiClient
from .file import FileClient
from .http import HttpClient
from .huggingface import HuggingfaceClient
from ..utils import ConversionContext
from typing import Dict, Optional
from logging import getLogger
from os import path
logger = getLogger(__name__)
model_sources: Dict[str, BaseClient] = {
CivitaiClient.protocol: CivitaiClient,
FileClient.protocol: FileClient,
HttpClient.insecure_protocol: HttpClient,
HttpClient.protocol: HttpClient,
HuggingfaceClient.protocol: HuggingfaceClient,
}
def add_model_source(proto: str, client: BaseClient):
global model_sources
if proto in model_sources:
raise ValueError("protocol has already been taken")
model_sources[proto] = client
def fetch_model(
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str] = None,
dest: Optional[str] = None,
) -> str:
# TODO: switch to urlparse's default scheme
if source.startswith(path.sep) or source.startswith("."):
logger.info("adding file protocol to local path source: %s", source)
source = FileClient.protocol + source
for proto, client_type in model_sources.items():
if source.startswith(proto):
# TODO: fix type of client_type
client: BaseClient = client_type()
return client.download(conversion, name, source, format=format, dest=dest)
logger.warning("unknown model protocol, using path as provided: %s", source)
return source