From 4edb32aaac1a4881d962d2caf1e90b947558f58f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 11:08:02 -0600 Subject: [PATCH] feat(api): add env vars for Civitai and Huggingface auth tokens --- api/onnx_web/convert/client/civitai.py | 19 +++++--- api/onnx_web/convert/client/file.py | 3 ++ api/onnx_web/convert/client/http.py | 4 +- api/onnx_web/convert/client/huggingface.py | 4 +- api/onnx_web/server/context.py | 56 +++++++++++----------- 5 files changed, 50 insertions(+), 36 deletions(-) diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index 1ae5ecfd..ebccb223 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -22,9 +22,14 @@ class CivitaiClient(BaseClient): root: str token: Optional[str] - def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT): - self.root = root - self.token = token + def __init__( + self, + conversion: ConversionContext, + token: Optional[str] = None, + root=CIVITAI_ROOT, + ): + self.root = conversion.get_setting("CIVITAI_ROOT", root) + self.token = conversion.get_setting("CIVITAI_TOKEN", token) def download( self, @@ -35,9 +40,6 @@ class CivitaiClient(BaseClient): dest: Optional[str] = None, **kwargs, ) -> str: - """ - TODO: download with auth token - """ cache_paths = build_cache_paths( conversion, name, @@ -51,4 +53,9 @@ class CivitaiClient(BaseClient): source = self.root % (remove_prefix(source, CivitaiClient.protocol)) logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0]) + + if self.token: + logger.debug("adding Civitai token authentication") + source = f"{source}?token={self.token}" + return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 7457e28b..0ae41b46 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -12,6 +12,9 @@ logger = getLogger(__name__) class FileClient(BaseClient): protocol = "file://" + def __init__(self, _conversion: ConversionContext): + pass + def download( self, conversion: ConversionContext, diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py index 591e3a6c..151ebccd 100644 --- a/api/onnx_web/convert/client/http.py +++ b/api/onnx_web/convert/client/http.py @@ -19,7 +19,9 @@ class HttpClient(BaseClient): headers: Dict[str, str] - def __init__(self, headers: Optional[Dict[str, str]] = None): + def __init__( + self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None + ): self.headers = headers or {} def download( diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 02044e50..c35698e7 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -15,8 +15,8 @@ class HuggingfaceClient(BaseClient): token: Optional[str] - def __init__(self, token: Optional[str] = None): - self.token = token + def __init__(self, conversion: ConversionContext, token: Optional[str] = None): + self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token) def download( self, diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 034fc3c6..66a8946a 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -1,7 +1,7 @@ from logging import getLogger from os import environ, path from secrets import token_urlsafe -from typing import List, Optional +from typing import Dict, List, Optional import torch @@ -42,6 +42,7 @@ class ServerContext: feature_flags: List[str] plugins: List[str] debug: bool + env: Dict[str, str] def __init__( self, @@ -67,6 +68,7 @@ class ServerContext: feature_flags: Optional[List[str]] = None, plugins: Optional[List[str]] = None, debug: bool = False, + env: Dict[str, str] = environ, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -90,49 +92,49 @@ class ServerContext: self.feature_flags = feature_flags or [] self.plugins = plugins or [] self.debug = debug + self.env = env self.cache = ModelCache(self.cache_limit) @classmethod - def from_environ(cls): - memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None) + def from_environ(cls, env=environ): + memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None) if memory_limit is not None: memory_limit = int(memory_limit) return cls( - bundle_path=environ.get( - "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") - ), - model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), - output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), - params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), - cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"), + bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")), + model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), + output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), + params_path=env.get("ONNX_WEB_PARAMS_PATH", "."), + cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"), any_platform=get_boolean( - environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM + env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM ), - block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"), - default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), - image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), - cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), + block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"), + default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None), + image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), + cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), show_progress=get_boolean( - environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS + env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS ), - optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"), - extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"), - job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), + optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"), + extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"), + job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), memory_limit=memory_limit, - admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), - server_version=environ.get( - "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION - ), + admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None), + server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION), worker_retries=int( - environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) + env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) ), - feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"), - plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""), - debug=get_boolean(environ, "ONNX_WEB_DEBUG", False), + feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"), + plugins=get_list(env, "ONNX_WEB_PLUGINS", ""), + debug=get_boolean(env, "ONNX_WEB_DEBUG", False), ) + def get_setting(self, flag: str, default: str) -> Optional[str]: + return self.env.get(f"ONNX_WEB_{flag}", default) + def has_feature(self, flag: str) -> bool: return flag in self.feature_flags