1
0
Fork 0

fix(api): allow PTH tensor files, add helper to check extension

This commit is contained in:
Sean Sube 2023-10-06 19:03:15 -05:00
parent 1351b2f3ff
commit ebdfa78737
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 10 additions and 3 deletions

View File

@ -185,7 +185,14 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"] MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"] RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
_name, ext = path.splitext(name)
ext = ext.strip(".")
return (name in exts, ext)
def source_format(model: Dict) -> Optional[str]: def source_format(model: Dict) -> Optional[str]:
@ -193,8 +200,8 @@ def source_format(model: Dict) -> Optional[str]:
return model["format"] return model["format"]
if "source" in model: if "source" in model:
_name, ext = path.splitext(model["source"]) valid, ext = check_ext(model["source"], MODEL_FORMATS)
if ext in MODEL_FORMATS: if valid:
return ext return ext
return None return None