fix(api): sanitize filenames in user input
This commit is contained in:
parent
c99aa67220
commit
12fb7f52bb
|
@ -53,7 +53,8 @@ from transformers import (
|
|||
CLIPVisionConfig,
|
||||
)
|
||||
|
||||
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
||||
from ...utils import sanitize_name
|
||||
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml
|
||||
from .diffusers import convert_diffusion_diffusers
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -189,13 +189,6 @@ def load_yaml(file: str) -> Config:
|
|||
return Config(data)
|
||||
|
||||
|
||||
safe_chars = "._-"
|
||||
|
||||
|
||||
def sanitize_name(name):
|
||||
return "".join(x for x in name if (x.isalnum() or x in safe_chars))
|
||||
|
||||
|
||||
def remove_prefix(name, prefix):
|
||||
if name.startswith(prefix):
|
||||
return name[len(prefix) :]
|
||||
|
|
|
@ -28,6 +28,7 @@ from ..utils import (
|
|||
get_from_map,
|
||||
get_not_empty,
|
||||
get_size,
|
||||
sanitize_name,
|
||||
)
|
||||
from ..worker.pool import DevicePoolExecutor
|
||||
from .config import (
|
||||
|
@ -428,6 +429,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
|||
if output_file is None:
|
||||
return error_reply("output name is required")
|
||||
|
||||
output_file = sanitize_name(output_file)
|
||||
cancel = pool.cancel(output_file)
|
||||
|
||||
return ready_reply(cancel)
|
||||
|
@ -438,6 +440,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
|||
if output_file is None:
|
||||
return error_reply("output name is required")
|
||||
|
||||
output_file = sanitize_name(output_file)
|
||||
done, progress = pool.done(output_file)
|
||||
|
||||
if done is None:
|
||||
|
|
|
@ -10,6 +10,8 @@ from .params import DeviceParams, SizeChart
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
SAFE_CHARS = "._-"
|
||||
|
||||
|
||||
def base_join(base: str, tail: str) -> str:
|
||||
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
|
||||
|
@ -100,3 +102,7 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
|
|||
(mem_total - mem_free),
|
||||
mem_total,
|
||||
)
|
||||
|
||||
|
||||
def sanitize_name(name):
|
||||
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
|
||||
|
|
Loading…
Reference in New Issue