1
0
Fork 0

feat(api): add env vars for Civitai and Huggingface auth tokens

This commit is contained in:
Sean Sube 2023-12-10 11:08:02 -06:00
parent e6978243cc
commit 4edb32aaac
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 50 additions and 36 deletions

View File

@ -22,9 +22,14 @@ class CivitaiClient(BaseClient):
root: str root: str
token: Optional[str] token: Optional[str]
def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT): def __init__(
self.root = root self,
self.token = token conversion: ConversionContext,
token: Optional[str] = None,
root=CIVITAI_ROOT,
):
self.root = conversion.get_setting("CIVITAI_ROOT", root)
self.token = conversion.get_setting("CIVITAI_TOKEN", token)
def download( def download(
self, self,
@ -35,9 +40,6 @@ class CivitaiClient(BaseClient):
dest: Optional[str] = None, dest: Optional[str] = None,
**kwargs, **kwargs,
) -> str: ) -> str:
"""
TODO: download with auth token
"""
cache_paths = build_cache_paths( cache_paths = build_cache_paths(
conversion, conversion,
name, name,
@ -51,4 +53,9 @@ class CivitaiClient(BaseClient):
source = self.root % (remove_prefix(source, CivitaiClient.protocol)) source = self.root % (remove_prefix(source, CivitaiClient.protocol))
logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0]) logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
if self.token:
logger.debug("adding Civitai token authentication")
source = f"{source}?token={self.token}"
return download_progress(source, cache_paths[0]) return download_progress(source, cache_paths[0])

View File

@ -12,6 +12,9 @@ logger = getLogger(__name__)
class FileClient(BaseClient): class FileClient(BaseClient):
protocol = "file://" protocol = "file://"
def __init__(self, _conversion: ConversionContext):
pass
def download( def download(
self, self,
conversion: ConversionContext, conversion: ConversionContext,

View File

@ -19,7 +19,9 @@ class HttpClient(BaseClient):
headers: Dict[str, str] headers: Dict[str, str]
def __init__(self, headers: Optional[Dict[str, str]] = None): def __init__(
self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None
):
self.headers = headers or {} self.headers = headers or {}
def download( def download(

View File

@ -15,8 +15,8 @@ class HuggingfaceClient(BaseClient):
token: Optional[str] token: Optional[str]
def __init__(self, token: Optional[str] = None): def __init__(self, conversion: ConversionContext, token: Optional[str] = None):
self.token = token self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token)
def download( def download(
self, self,

View File

@ -1,7 +1,7 @@
from logging import getLogger from logging import getLogger
from os import environ, path from os import environ, path
from secrets import token_urlsafe from secrets import token_urlsafe
from typing import List, Optional from typing import Dict, List, Optional
import torch import torch
@ -42,6 +42,7 @@ class ServerContext:
feature_flags: List[str] feature_flags: List[str]
plugins: List[str] plugins: List[str]
debug: bool debug: bool
env: Dict[str, str]
def __init__( def __init__(
self, self,
@ -67,6 +68,7 @@ class ServerContext:
feature_flags: Optional[List[str]] = None, feature_flags: Optional[List[str]] = None,
plugins: Optional[List[str]] = None, plugins: Optional[List[str]] = None,
debug: bool = False, debug: bool = False,
env: Dict[str, str] = environ,
) -> None: ) -> None:
self.bundle_path = bundle_path self.bundle_path = bundle_path
self.model_path = model_path self.model_path = model_path
@ -90,49 +92,49 @@ class ServerContext:
self.feature_flags = feature_flags or [] self.feature_flags = feature_flags or []
self.plugins = plugins or [] self.plugins = plugins or []
self.debug = debug self.debug = debug
self.env = env
self.cache = ModelCache(self.cache_limit) self.cache = ModelCache(self.cache_limit)
@classmethod @classmethod
def from_environ(cls): def from_environ(cls, env=environ):
memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None) memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None)
if memory_limit is not None: if memory_limit is not None:
memory_limit = int(memory_limit) memory_limit = int(memory_limit)
return cls( return cls(
bundle_path=environ.get( bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")),
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
), output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), params_path=env.get("ONNX_WEB_PARAMS_PATH", "."),
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"),
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
any_platform=get_boolean( any_platform=get_boolean(
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
), ),
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"), block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
show_progress=get_boolean( show_progress=get_boolean(
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
), ),
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"), optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"),
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"), extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"),
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
memory_limit=memory_limit, memory_limit=memory_limit,
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None),
server_version=environ.get( server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION),
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
),
worker_retries=int( worker_retries=int(
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
), ),
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"), feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"),
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""), plugins=get_list(env, "ONNX_WEB_PLUGINS", ""),
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False), debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
) )
def get_setting(self, flag: str, default: str) -> Optional[str]:
return self.env.get(f"ONNX_WEB_{flag}", default)
def has_feature(self, flag: str) -> bool: def has_feature(self, flag: str) -> bool:
return flag in self.feature_flags return flag in self.feature_flags