diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index bf900d01..702045c0 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -17,6 +17,7 @@ from .utils import ( ConversionContext, download_progress, model_formats_original, + remove_prefix, source_format, tuple_to_correction, tuple_to_diffusion, @@ -157,14 +158,14 @@ def fetch_model( for proto in model_sources: api_name, api_root = model_sources.get(proto) if source.startswith(proto): - api_source = api_root % (source.removeprefix(proto)) + api_source = api_root % (remove_prefix(source, proto)) logger.info( "Downloading model from %s: %s -> %s", api_name, api_source, cache_name ) return download_progress([(api_source, cache_name)]) if source.startswith(model_source_huggingface): - hub_source = source.removeprefix(model_source_huggingface) + hub_source = remove_prefix(source, model_source_huggingface) logger.info("Downloading model from Huggingface Hub: %s", hub_source) # from_pretrained has a bunch of useful logic that snapshot_download by itself down not return hub_source diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 22186ad5..be3c2fbf 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -192,3 +192,10 @@ safe_chars = "._-" def sanitize_name(name): return "".join(x for x in name if (x.isalnum() or x in safe_chars)) + + +def remove_prefix(name, prefix): + if name.startswith(prefix): + return name[len(prefix) :] + + return name