From def8ad73c549b5d891c20c77296c1609e293f73e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 Nov 2023 22:53:43 -0600 Subject: [PATCH] feat(api): add feature flags, move panoramic tile feature into flags --- api/onnx_web/diffusers/run.py | 2 +- api/onnx_web/server/context.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 60e80e65..11594468 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -70,7 +70,7 @@ def run_txt2img_pipeline( stage, ) - if server.panorama_tiles: + if server.has_feature("panorama-highres"): highres_size = tile_size * highres.scale first_upscale, after_upscale = split_upscale(upscale) diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index eb7f5a2a..98bf6af5 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -16,7 +16,6 @@ DEFAULT_JOB_LIMIT = 10 DEFAULT_IMAGE_FORMAT = "png" DEFAULT_SERVER_VERSION = "v0.10.0" DEFAULT_SHOW_PROGRESS = True -DEFAULT_PANORAMA_TILES = False DEFAULT_WORKER_RETRIES = 3 @@ -40,7 +39,7 @@ class ServerContext: admin_token: str server_version: str worker_retries: int - panorama_tiles: bool + feature_flags: List[str] def __init__( self, @@ -63,7 +62,7 @@ class ServerContext: admin_token: Optional[str] = None, server_version: Optional[str] = DEFAULT_SERVER_VERSION, worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES, - panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES, + feature_flags: Optional[List[str]] = None, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -84,7 +83,7 @@ class ServerContext: self.admin_token = admin_token or token_urlsafe() self.server_version = server_version self.worker_retries = worker_retries - self.panorama_tiles = panorama_tiles + self.feature_flags = feature_flags or [] self.cache = ModelCache(self.cache_limit) @@ -124,11 +123,12 @@ class ServerContext: worker_retries=int( environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) ), - panorama_tiles=get_boolean( - environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES - ), + feature_flags=environ.get("ONNX_WEB_FEATURE_FLAGS", "").split(","), ) + def has_feature(self, flag: str) -> bool: + return flag in self.feature_flags + def torch_dtype(self): if "torch-fp16" in self.optimizations: return torch.float16