fix(api): allow PTH tensor files, add helper to check extension
This commit is contained in:
parent
1351b2f3ff
commit
ebdfa78737
|
@ -185,7 +185,14 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
|||
|
||||
|
||||
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "bin"]
|
||||
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
|
||||
|
||||
|
||||
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
|
||||
_name, ext = path.splitext(name)
|
||||
ext = ext.strip(".")
|
||||
|
||||
return (name in exts, ext)
|
||||
|
||||
|
||||
def source_format(model: Dict) -> Optional[str]:
|
||||
|
@ -193,8 +200,8 @@ def source_format(model: Dict) -> Optional[str]:
|
|||
return model["format"]
|
||||
|
||||
if "source" in model:
|
||||
_name, ext = path.splitext(model["source"])
|
||||
if ext in MODEL_FORMATS:
|
||||
valid, ext = check_ext(model["source"], MODEL_FORMATS)
|
||||
if valid:
|
||||
return ext
|
||||
|
||||
return None
|
||||
|
|
Loading…
Reference in New Issue