feat(api): add feature flags, move panoramic tile feature into flags
This commit is contained in:
parent
2a27c3ffd1
commit
def8ad73c5
|
@ -70,7 +70,7 @@ def run_txt2img_pipeline(
|
|||
stage,
|
||||
)
|
||||
|
||||
if server.panorama_tiles:
|
||||
if server.has_feature("panorama-highres"):
|
||||
highres_size = tile_size * highres.scale
|
||||
|
||||
first_upscale, after_upscale = split_upscale(upscale)
|
||||
|
|
|
@ -16,7 +16,6 @@ DEFAULT_JOB_LIMIT = 10
|
|||
DEFAULT_IMAGE_FORMAT = "png"
|
||||
DEFAULT_SERVER_VERSION = "v0.10.0"
|
||||
DEFAULT_SHOW_PROGRESS = True
|
||||
DEFAULT_PANORAMA_TILES = False
|
||||
DEFAULT_WORKER_RETRIES = 3
|
||||
|
||||
|
||||
|
@ -40,7 +39,7 @@ class ServerContext:
|
|||
admin_token: str
|
||||
server_version: str
|
||||
worker_retries: int
|
||||
panorama_tiles: bool
|
||||
feature_flags: List[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -63,7 +62,7 @@ class ServerContext:
|
|||
admin_token: Optional[str] = None,
|
||||
server_version: Optional[str] = DEFAULT_SERVER_VERSION,
|
||||
worker_retries: Optional[int] = DEFAULT_WORKER_RETRIES,
|
||||
panorama_tiles: Optional[bool] = DEFAULT_PANORAMA_TILES,
|
||||
feature_flags: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -84,7 +83,7 @@ class ServerContext:
|
|||
self.admin_token = admin_token or token_urlsafe()
|
||||
self.server_version = server_version
|
||||
self.worker_retries = worker_retries
|
||||
self.panorama_tiles = panorama_tiles
|
||||
self.feature_flags = feature_flags or []
|
||||
|
||||
self.cache = ModelCache(self.cache_limit)
|
||||
|
||||
|
@ -124,11 +123,12 @@ class ServerContext:
|
|||
worker_retries=int(
|
||||
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||
),
|
||||
panorama_tiles=get_boolean(
|
||||
environ, "ONNX_WEB_PANORAMA_TILES", DEFAULT_PANORAMA_TILES
|
||||
),
|
||||
feature_flags=environ.get("ONNX_WEB_FEATURE_FLAGS", "").split(","),
|
||||
)
|
||||
|
||||
def has_feature(self, flag: str) -> bool:
|
||||
return flag in self.feature_flags
|
||||
|
||||
def torch_dtype(self):
|
||||
if "torch-fp16" in self.optimizations:
|
||||
return torch.float16
|
||||
|
|
Loading…
Reference in New Issue