feat(api): add converter to extract ZIP archives (#437)
This commit is contained in:
parent
193f82e7b4
commit
f3e1beaa71
|
@ -12,6 +12,7 @@ from transformers import CLIPTokenizer
|
|||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from ..server.plugin import load_plugins
|
||||
from ..utils import load_config
|
||||
from .archive import convert_extract_archive
|
||||
from .client import add_model_source, fetch_model
|
||||
from .client.huggingface import HuggingfaceClient
|
||||
from .correction.gfpgan import convert_correction_gfpgan
|
||||
|
@ -50,6 +51,7 @@ ModelDict = Dict[str, Union[float, int, str]]
|
|||
Models = Dict[str, List[ModelDict]]
|
||||
|
||||
model_converters: Dict[str, Any] = {
|
||||
"archive": convert_extract_archive,
|
||||
"img2img": convert_diffusion_diffusers,
|
||||
"img2img-sdxl": convert_diffusion_diffusers_xl,
|
||||
"inpaint": convert_diffusion_diffusers,
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Any, Dict
|
||||
from zipfile import ZipFile
|
||||
|
||||
from regex import match
|
||||
|
||||
from .client import fetch_model
|
||||
from .utils import ConversionContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def convert_extract_archive(
|
||||
conversion: ConversionContext, model: Dict[str, Any], format: str
|
||||
):
|
||||
name = str(model.get("name")).strip()
|
||||
source = model.get("source")
|
||||
|
||||
dest_path = path.join(conversion.model_path, name)
|
||||
|
||||
logger.info("extracting archived model %s: %s -> %s/", name, source, dest_path)
|
||||
|
||||
if path.exists(dest_path):
|
||||
logger.info("destination path already exists, skipping extraction")
|
||||
|
||||
cache_path = fetch_model(conversion, name, model["source"], format=format)
|
||||
|
||||
with ZipFile(cache_path) as zip:
|
||||
names = zip.namelist()
|
||||
if not all([is_safe(name) for name in names]):
|
||||
raise ValueError("archive contains unsafe filenames")
|
||||
|
||||
logger.debug("archive is valid, extracting all files: %s", names)
|
||||
zip.extractall(path=dest_path)
|
||||
|
||||
|
||||
SAFE_NAME = r"^[-_a-zA-Z/\\\.]+$"
|
||||
|
||||
|
||||
def is_safe(name: str):
|
||||
return match(SAFE_NAME, name)
|
Loading…
Reference in New Issue