diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py index 23a71177..6047b2b3 100644 --- a/api/onnx_web/convert/client/__init__.py +++ b/api/onnx_web/convert/client/__init__.py @@ -1,17 +1,18 @@ +from typing import Callable, Dict, Optional +from logging import getLogger +from os import path + +from ..utils import ConversionContext from .base import BaseClient from .civitai import CivitaiClient from .file import FileClient from .http import HttpClient from .huggingface import HuggingfaceClient -from ..utils import ConversionContext -from typing import Dict, Optional -from logging import getLogger -from os import path logger = getLogger(__name__) -model_sources: Dict[str, BaseClient] = { +model_sources: Dict[str, Callable[[], BaseClient]] = { CivitaiClient.protocol: CivitaiClient, FileClient.protocol: FileClient, HttpClient.insecure_protocol: HttpClient, @@ -44,9 +45,10 @@ def fetch_model( for proto, client_type in model_sources.items(): if source.startswith(proto): - # TODO: fix type of client_type - client: BaseClient = client_type() - return client.download(conversion, name, source, format=format, dest=dest, **kwargs) + client = client_type() + return client.download( + conversion, name, source, format=format, dest=dest, **kwargs + ) logger.warning("unknown model protocol, using path as provided: %s", source) return source