feat(api): add feature flags, move panoramic tile feature into flags
This commit is contained in:
parent
2a27c3ffd1
commit
def8ad73c5
|
@ -70,7 +70,7 @@ def run_txt2img_pipeline(
|
||||||
stage,
|
stage,
|
||||||
)
|
)
|
||||||
|
|
||||||
if server.panorama_tiles:
|
if server.has_feature("panorama-highres"):
|
||||||
highres_size = tile_size * highres.scale
|
highres_size = tile_size * highres.scale
|
||||||
|
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
|
|
|
@ -16,7 +16,6 @@ DEFAULT_JOB_LIMIT = 10
|
||||||
DEFAULT_IMAGE_FORMAT = "png"
|
DEFAULT_IMAGE_FORMAT = "png"
|
||||||
DEFAULT_SERVER_VERSION = "v0.10.0"
|
DEFAULT_SERVER_VERSION = "v0.10.0"
|
||||||
DEFAULT_SHOW_PROGRESS = True
|
DEFAULT_SHOW_PROGRESS = True
|
||||||
DEFAULT_PANORAMA_TILES = False
|
|
||||||
DEFAULT_WORKER_RETRIES = 3
|
DEFAULT_WORKER_RETRIES = 3
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,7 +39,7 @@ class ServerContext:
|
||||||
admin_token: str
|
admin_token: str
|
||||||
server_version: str
|
server_version: str
|
||||||
worker_retries: int
|
worker_retries: int
|
||||||
panorama_tiles: bool
|
feature_flags: List[str]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -63,7 +62,7 @@ class ServerContext:
|
||||||
admin_token: Optional[str] = None,
|
admin_token: Optional[str] = None,
|
||||||
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
||||||
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
|
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
|
||||||
panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES,
|
feature_flags: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -84,7 +83,7 @@ class ServerContext:
|
||||||
self.admin_token = admin_token or token_urlsafe()
|
self.admin_token = admin_token or token_urlsafe()
|
||||||
self.server_version = server_version
|
self.server_version = server_version
|
||||||
self.worker_retries = worker_retries
|
self.worker_retries = worker_retries
|
||||||
self.panorama_tiles = panorama_tiles
|
self.feature_flags = feature_flags or []
|
||||||
|
|
||||||
self.cache = ModelCache(self.cache_limit)
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
|
@ -124,11 +123,12 @@ class ServerContext:
|
||||||
worker_retries=int(
|
worker_retries=int(
|
||||||
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||||
),
|
),
|
||||||
panorama_tiles=get_boolean(
|
feature_flags=environ.get("ONNX_WEB_FEATURE_FLAGS", "").split(","),
|
||||||
environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def has_feature(self, flag: str) -> bool:
|
||||||
|
return flag in self.feature_flags
|
||||||
|
|
||||||
def torch_dtype(self):
|
def torch_dtype(self):
|
||||||
if "torch-fp16" in self.optimizations:
|
if "torch-fp16" in self.optimizations:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
|
|
Loading…
Reference in New Issue