1
0
Fork 0

fix(api): sanitize filenames in user input

This commit is contained in:
Sean Sube 2023-02-28 21:56:12 -06:00
parent c99aa67220
commit 12fb7f52bb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 11 additions and 8 deletions

View File

@ -53,7 +53,8 @@ from transformers import (
CLIPVisionConfig, CLIPVisionConfig,
) )
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name from ...utils import sanitize_name
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml
from .diffusers import convert_diffusion_diffusers from .diffusers import convert_diffusion_diffusers
logger = getLogger(__name__) logger = getLogger(__name__)

View File

@ -189,13 +189,6 @@ def load_yaml(file: str) -> Config:
return Config(data) return Config(data)
safe_chars = "._-"
def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in safe_chars))
def remove_prefix(name, prefix): def remove_prefix(name, prefix):
if name.startswith(prefix): if name.startswith(prefix):
return name[len(prefix) :] return name[len(prefix) :]

View File

@ -28,6 +28,7 @@ from ..utils import (
get_from_map, get_from_map,
get_not_empty, get_not_empty,
get_size, get_size,
sanitize_name,
) )
from ..worker.pool import DevicePoolExecutor from ..worker.pool import DevicePoolExecutor
from .config import ( from .config import (
@ -428,6 +429,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
if output_file is None: if output_file is None:
return error_reply("output name is required") return error_reply("output name is required")
output_file = sanitize_name(output_file)
cancel = pool.cancel(output_file) cancel = pool.cancel(output_file)
return ready_reply(cancel) return ready_reply(cancel)
@ -438,6 +440,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
if output_file is None: if output_file is None:
return error_reply("output name is required") return error_reply("output name is required")
output_file = sanitize_name(output_file)
done, progress = pool.done(output_file) done, progress = pool.done(output_file)
if done is None: if done is None:

View File

@ -10,6 +10,8 @@ from .params import DeviceParams, SizeChart
logger = getLogger(__name__) logger = getLogger(__name__)
SAFE_CHARS = "._-"
def base_join(base: str, tail: str) -> str: def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
@ -100,3 +102,7 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
(mem_total - mem_free), (mem_total - mem_free),
mem_total, mem_total,
) )
def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))