fix(api): sanitize filenames in user input
This commit is contained in:
parent
c99aa67220
commit
12fb7f52bb
|
@ -53,7 +53,8 @@ from transformers import (
|
||||||
CLIPVisionConfig,
|
CLIPVisionConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
from ...utils import sanitize_name
|
||||||
|
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml
|
||||||
from .diffusers import convert_diffusion_diffusers
|
from .diffusers import convert_diffusion_diffusers
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
|
@ -189,13 +189,6 @@ def load_yaml(file: str) -> Config:
|
||||||
return Config(data)
|
return Config(data)
|
||||||
|
|
||||||
|
|
||||||
safe_chars = "._-"
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_name(name):
|
|
||||||
return "".join(x for x in name if (x.isalnum() or x in safe_chars))
|
|
||||||
|
|
||||||
|
|
||||||
def remove_prefix(name, prefix):
|
def remove_prefix(name, prefix):
|
||||||
if name.startswith(prefix):
|
if name.startswith(prefix):
|
||||||
return name[len(prefix) :]
|
return name[len(prefix) :]
|
||||||
|
|
|
@ -28,6 +28,7 @@ from ..utils import (
|
||||||
get_from_map,
|
get_from_map,
|
||||||
get_not_empty,
|
get_not_empty,
|
||||||
get_size,
|
get_size,
|
||||||
|
sanitize_name,
|
||||||
)
|
)
|
||||||
from ..worker.pool import DevicePoolExecutor
|
from ..worker.pool import DevicePoolExecutor
|
||||||
from .config import (
|
from .config import (
|
||||||
|
@ -428,6 +429,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
return error_reply("output name is required")
|
return error_reply("output name is required")
|
||||||
|
|
||||||
|
output_file = sanitize_name(output_file)
|
||||||
cancel = pool.cancel(output_file)
|
cancel = pool.cancel(output_file)
|
||||||
|
|
||||||
return ready_reply(cancel)
|
return ready_reply(cancel)
|
||||||
|
@ -438,6 +440,7 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
if output_file is None:
|
if output_file is None:
|
||||||
return error_reply("output name is required")
|
return error_reply("output name is required")
|
||||||
|
|
||||||
|
output_file = sanitize_name(output_file)
|
||||||
done, progress = pool.done(output_file)
|
done, progress = pool.done(output_file)
|
||||||
|
|
||||||
if done is None:
|
if done is None:
|
||||||
|
|
|
@ -10,6 +10,8 @@ from .params import DeviceParams, SizeChart
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
SAFE_CHARS = "._-"
|
||||||
|
|
||||||
|
|
||||||
def base_join(base: str, tail: str) -> str:
|
def base_join(base: str, tail: str) -> str:
|
||||||
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
|
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
|
||||||
|
@ -100,3 +102,7 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
|
||||||
(mem_total - mem_free),
|
(mem_total - mem_free),
|
||||||
mem_total,
|
mem_total,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_name(name):
|
||||||
|
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
|
||||||
|
|
Loading…
Reference in New Issue