From 12fb7f52bb17208f82cf6b3274acb4b90bfdc717 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 28 Feb 2023 21:56:12 -0600 Subject: [PATCH] fix(api): sanitize filenames in user input --- api/onnx_web/convert/diffusion/original.py | 3 ++- api/onnx_web/convert/utils.py | 7 ------- api/onnx_web/server/api.py | 3 +++ api/onnx_web/utils.py | 6 ++++++ 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index c331ee74..dbcaa330 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -53,7 +53,8 @@ from transformers import ( CLIPVisionConfig, ) -from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name +from ...utils import sanitize_name +from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml from .diffusers import convert_diffusion_diffusers logger = getLogger(__name__) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 9b3edf66..295fdd87 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -189,13 +189,6 @@ def load_yaml(file: str) -> Config: return Config(data) -safe_chars = "._-" - - -def sanitize_name(name): - return "".join(x for x in name if (x.isalnum() or x in safe_chars)) - - def remove_prefix(name, prefix): if name.startswith(prefix): return name[len(prefix) :] diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 7f69c1c1..ed70c76f 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -28,6 +28,7 @@ from ..utils import ( get_from_map, get_not_empty, get_size, + sanitize_name, ) from ..worker.pool import DevicePoolExecutor from .config import ( @@ -428,6 +429,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor): if output_file is None: return error_reply("output name is required") + output_file = sanitize_name(output_file) cancel = pool.cancel(output_file) return ready_reply(cancel) @@ -438,6 +440,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): if output_file is None: return error_reply("output name is required") + output_file = sanitize_name(output_file) done, progress = pool.done(output_file) if done is None: diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index f794eca0..74f998f8 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -10,6 +10,8 @@ from .params import DeviceParams, SizeChart logger = getLogger(__name__) +SAFE_CHARS = "._-" + def base_join(base: str, tail: str) -> str: tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") @@ -100,3 +102,7 @@ def run_gc(devices: Optional[List[DeviceParams]] = None): (mem_total - mem_free), mem_total, ) + + +def sanitize_name(name): + return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))