feat(api): add env vars for Civitai and Huggingface auth tokens
This commit is contained in:
parent
e6978243cc
commit
4edb32aaac
|
@ -22,9 +22,14 @@ class CivitaiClient(BaseClient):
|
||||||
root: str
|
root: str
|
||||||
token: Optional[str]
|
token: Optional[str]
|
||||||
|
|
||||||
def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT):
|
def __init__(
|
||||||
self.root = root
|
self,
|
||||||
self.token = token
|
conversion: ConversionContext,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
root=CIVITAI_ROOT,
|
||||||
|
):
|
||||||
|
self.root = conversion.get_setting("CIVITAI_ROOT", root)
|
||||||
|
self.token = conversion.get_setting("CIVITAI_TOKEN", token)
|
||||||
|
|
||||||
def download(
|
def download(
|
||||||
self,
|
self,
|
||||||
|
@ -35,9 +40,6 @@ class CivitaiClient(BaseClient):
|
||||||
dest: Optional[str] = None,
|
dest: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
|
||||||
TODO: download with auth token
|
|
||||||
"""
|
|
||||||
cache_paths = build_cache_paths(
|
cache_paths = build_cache_paths(
|
||||||
conversion,
|
conversion,
|
||||||
name,
|
name,
|
||||||
|
@ -51,4 +53,9 @@ class CivitaiClient(BaseClient):
|
||||||
|
|
||||||
source = self.root % (remove_prefix(source, CivitaiClient.protocol))
|
source = self.root % (remove_prefix(source, CivitaiClient.protocol))
|
||||||
logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
|
logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
|
||||||
|
|
||||||
|
if self.token:
|
||||||
|
logger.debug("adding Civitai token authentication")
|
||||||
|
source = f"{source}?token={self.token}"
|
||||||
|
|
||||||
return download_progress(source, cache_paths[0])
|
return download_progress(source, cache_paths[0])
|
||||||
|
|
|
@ -12,6 +12,9 @@ logger = getLogger(__name__)
|
||||||
class FileClient(BaseClient):
|
class FileClient(BaseClient):
|
||||||
protocol = "file://"
|
protocol = "file://"
|
||||||
|
|
||||||
|
def __init__(self, _conversion: ConversionContext):
|
||||||
|
pass
|
||||||
|
|
||||||
def download(
|
def download(
|
||||||
self,
|
self,
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
|
|
|
@ -19,7 +19,9 @@ class HttpClient(BaseClient):
|
||||||
|
|
||||||
headers: Dict[str, str]
|
headers: Dict[str, str]
|
||||||
|
|
||||||
def __init__(self, headers: Optional[Dict[str, str]] = None):
|
def __init__(
|
||||||
|
self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None
|
||||||
|
):
|
||||||
self.headers = headers or {}
|
self.headers = headers or {}
|
||||||
|
|
||||||
def download(
|
def download(
|
||||||
|
|
|
@ -15,8 +15,8 @@ class HuggingfaceClient(BaseClient):
|
||||||
|
|
||||||
token: Optional[str]
|
token: Optional[str]
|
||||||
|
|
||||||
def __init__(self, token: Optional[str] = None):
|
def __init__(self, conversion: ConversionContext, token: Optional[str] = None):
|
||||||
self.token = token
|
self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token)
|
||||||
|
|
||||||
def download(
|
def download(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from secrets import token_urlsafe
|
from secrets import token_urlsafe
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -42,6 +42,7 @@ class ServerContext:
|
||||||
feature_flags: List[str]
|
feature_flags: List[str]
|
||||||
plugins: List[str]
|
plugins: List[str]
|
||||||
debug: bool
|
debug: bool
|
||||||
|
env: Dict[str, str]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -67,6 +68,7 @@ class ServerContext:
|
||||||
feature_flags: Optional[List[str]] = None,
|
feature_flags: Optional[List[str]] = None,
|
||||||
plugins: Optional[List[str]] = None,
|
plugins: Optional[List[str]] = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
|
env: Dict[str, str] = environ,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -90,49 +92,49 @@ class ServerContext:
|
||||||
self.feature_flags = feature_flags or []
|
self.feature_flags = feature_flags or []
|
||||||
self.plugins = plugins or []
|
self.plugins = plugins or []
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
self.env = env
|
||||||
|
|
||||||
self.cache = ModelCache(self.cache_limit)
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls, env=environ):
|
||||||
memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None)
|
memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None)
|
||||||
if memory_limit is not None:
|
if memory_limit is not None:
|
||||||
memory_limit = int(memory_limit)
|
memory_limit = int(memory_limit)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
bundle_path=environ.get(
|
bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")),
|
||||||
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
||||||
),
|
output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
||||||
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
params_path=env.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||||
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"),
|
||||||
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
|
||||||
cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
|
|
||||||
any_platform=get_boolean(
|
any_platform=get_boolean(
|
||||||
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
||||||
),
|
),
|
||||||
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
|
block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"),
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
|
image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
|
||||||
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||||
show_progress=get_boolean(
|
show_progress=get_boolean(
|
||||||
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
||||||
),
|
),
|
||||||
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
|
optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"),
|
||||||
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
|
extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"),
|
||||||
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||||
memory_limit=memory_limit,
|
memory_limit=memory_limit,
|
||||||
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
|
admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None),
|
||||||
server_version=environ.get(
|
server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION),
|
||||||
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
|
||||||
),
|
|
||||||
worker_retries=int(
|
worker_retries=int(
|
||||||
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||||
),
|
),
|
||||||
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
|
feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"),
|
||||||
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
|
plugins=get_list(env, "ONNX_WEB_PLUGINS", ""),
|
||||||
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
|
debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_setting(self, flag: str, default: str) -> Optional[str]:
|
||||||
|
return self.env.get(f"ONNX_WEB_{flag}", default)
|
||||||
|
|
||||||
def has_feature(self, flag: str) -> bool:
|
def has_feature(self, flag: str) -> bool:
|
||||||
return flag in self.feature_flags
|
return flag in self.feature_flags
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue