1
0
Fork 0

feat(api): add converter to extract ZIP archives (#437)

This commit is contained in:
Sean Sube 2023-12-16 22:14:43 -06:00
parent 193f82e7b4
commit f3e1beaa71
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 44 additions and 0 deletions

View File

@ -12,6 +12,7 @@ from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from ..server.plugin import load_plugins
from ..utils import load_config
from .archive import convert_extract_archive
from .client import add_model_source, fetch_model
from .client.huggingface import HuggingfaceClient
from .correction.gfpgan import convert_correction_gfpgan
@ -50,6 +51,7 @@ ModelDict = Dict[str, Union[float, int, str]]
Models = Dict[str, List[ModelDict]]
model_converters: Dict[str, Any] = {
"archive": convert_extract_archive,
"img2img": convert_diffusion_diffusers,
"img2img-sdxl": convert_diffusion_diffusers_xl,
"inpaint": convert_diffusion_diffusers,

View File

@ -0,0 +1,42 @@
from logging import getLogger
from os import path
from typing import Any, Dict
from zipfile import ZipFile
from regex import match
from .client import fetch_model
from .utils import ConversionContext
logger = getLogger(__name__)
def convert_extract_archive(
conversion: ConversionContext, model: Dict[str, Any], format: str
):
name = str(model.get("name")).strip()
source = model.get("source")
dest_path = path.join(conversion.model_path, name)
logger.info("extracting archived model %s: %s -> %s/", name, source, dest_path)
if path.exists(dest_path):
logger.info("destination path already exists, skipping extraction")
cache_path = fetch_model(conversion, name, model["source"], format=format)
with ZipFile(cache_path) as zip:
names = zip.namelist()
if not all([is_safe(name) for name in names]):
raise ValueError("archive contains unsafe filenames")
logger.debug("archive is valid, extracting all files: %s", names)
zip.extractall(path=dest_path)
SAFE_NAME = r"^[-_a-zA-Z/\\\.]+$"
def is_safe(name: str):
return match(SAFE_NAME, name)