fix type of client instance
This commit is contained in:
parent
419b2811ef
commit
50db19922a
|
@ -1,17 +1,18 @@
|
|||
from typing import Callable, Dict, Optional
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
from ..utils import ConversionContext
|
||||
from .base import BaseClient
|
||||
from .civitai import CivitaiClient
|
||||
from .file import FileClient
|
||||
from .http import HttpClient
|
||||
from .huggingface import HuggingfaceClient
|
||||
from ..utils import ConversionContext
|
||||
from typing import Dict, Optional
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
model_sources: Dict[str, BaseClient] = {
|
||||
model_sources: Dict[str, Callable[[], BaseClient]] = {
|
||||
CivitaiClient.protocol: CivitaiClient,
|
||||
FileClient.protocol: FileClient,
|
||||
HttpClient.insecure_protocol: HttpClient,
|
||||
|
@ -44,9 +45,10 @@ def fetch_model(
|
|||
|
||||
for proto, client_type in model_sources.items():
|
||||
if source.startswith(proto):
|
||||
# TODO: fix type of client_type
|
||||
client: BaseClient = client_type()
|
||||
return client.download(conversion, name, source, format=format, dest=dest, **kwargs)
|
||||
client = client_type()
|
||||
return client.download(
|
||||
conversion, name, source, format=format, dest=dest, **kwargs
|
||||
)
|
||||
|
||||
logger.warning("unknown model protocol, using path as provided: %s", source)
|
||||
return source
|
||||
|
|
Loading…
Reference in New Issue