diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 3e058362..2c4eeb66 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -1,11 +1,10 @@
 import warnings
 from argparse import ArgumentParser
 from logging import getLogger
-from os import environ, makedirs, path
+from os import makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple
 
-import torch
 from jsonschema import ValidationError, validate
 from yaml import safe_load
 
@@ -102,9 +101,6 @@ base_models: Models = {
     ],
 }
 
-model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
-training_device = "cuda" if torch.cuda.is_available() else "cpu"
-
 
 def fetch_model(
     ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
@@ -228,14 +224,12 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)
 
-    ctx = ConversionContext(
-        model_path, training_device, half=args.half, opset=args.opset, token=args.token
-    )
+    ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
 
-    if not path.exists(model_path):
-        logger.info("Model path does not existing, creating: %s", model_path)
-        makedirs(model_path)
+    if not path.exists(ctx.model_path):
+        logger.info("Model path does not exist, creating: %s", ctx.model_path)
+        makedirs(ctx.model_path)
 
     logger.info("Converting base models.")
     convert_models(ctx, args, base_models)
diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py
index 35c9a4a2..292be414 100644
--- a/api/onnx_web/convert/diffusion_original.py
+++ b/api/onnx_web/convert/diffusion_original.py
@@ -15,6 +15,7 @@ import json
 import os
 import re
 import shutil
+import sys
 import traceback
 from logging import getLogger
 from typing import Dict, List
@@ -40,11 +41,9 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
     LDMBertConfig,
     LDMBertModel,
 )
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfApi, hf_hub_download
-from omegaconf import OmegaConf
-from pydantic import BaseModel
 from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
@@ -54,16 +53,12 @@ from transformers import (
 )
 
 from .diffusion_stable import convert_diffusion_stable
-from .utils import ConversionContext, ModelDict
+from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml
 
 logger = getLogger(__name__)
 
 
-def sanitize_name(name):
-    return "".join(x for x in name if (x.isalnum() or x in "._- "))
-
-
-class DreamboothConfig(BaseModel):
+class TrainingConfig:
     adamw_weight_decay: float = 0.01
     attention: str = "default"
     cache_latents: bool = True
@@ -145,10 +140,16 @@ class DreamboothConfig(BaseModel):
     use_subdir: bool = False
     v2: bool = False
 
-    def __init__(self, ctx: ConversionContext, model_name: str = "", scheduler: str = "ddim", v2: bool = False, src: str = "",
-                 resolution: int = 512, **kwargs):
-
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        ctx: ConversionContext,
+        model_name: str = "",
+        scheduler: str = "ddim",
+        v2: bool = False,
+        src: str = "",
+        resolution: int = 512,
+        **kwargs,
+    ):
         model_name = sanitize_name(model_name)
         model_dir = os.path.join(ctx.cache_path, model_name)
         working_dir = os.path.join(model_dir, "working")
@@ -164,6 +165,10 @@ class DreamboothConfig(BaseModel):
         self.scheduler = scheduler
         self.v2 = v2
 
+        # avoid pydantic dep for this one fn
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
     def save(self, backup=False):
         """
         Save the config file
@@ -960,9 +965,7 @@ def replace_symlinks(path, base):
             replace_symlinks(os.path.join(path, subpath), base)
 
 
-def download_model(db_config: DreamboothConfig, token):
-    tmp_dir = os.path.join(db_config.model_dir, "src")
-
+def download_model(db_config: TrainingConfig, token):
     hub_url = db_config.src
     if "http" in hub_url or "huggingface.co" in hub_url:
         hub_url = "/".join(hub_url.split("/")[-2:])
@@ -976,7 +979,7 @@ def download_model(db_config: DreamboothConfig, token):
     )
 
     if repo_info.sha is None:
-        logger.warning("Unable to fetch repo?")
+        logger.warning("Unable to fetch repo info: %s", hub_url)
         return None, None
 
     siblings = repo_info.siblings
@@ -1018,7 +1021,6 @@ def download_model(db_config: DreamboothConfig, token):
 
     files_to_fetch = None
 
-    cache_dir = tmp_dir
     if model_file is not None:
         files_to_fetch = [model_file]
     elif len(diffusion_files):
@@ -1043,7 +1045,6 @@ def download_model(db_config: DreamboothConfig, token):
             filename=repo_file,
             repo_type="model",
             revision=repo_info.sha,
-            cache_dir=cache_dir,
             token=token
         )
         replace_symlinks(out, db_config.model_dir)
@@ -1163,12 +1164,8 @@ def extract_checkpoint(
     upcast_attention = False
     msg = None
 
-    if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
-        logger.warning("Please provide a URL and token for huggingface models.")
-        return
-
     # Create empty config
-    db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
+    db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
                                  src=checkpoint_file if not from_hub else new_model_url)
 
     original_config_file = None
@@ -1243,7 +1240,7 @@ def extract_checkpoint(
                 logger.debug("UNet using v2 parameters.")
                 v2 = True
         except Exception as e:
-            logger.error("Exception loading unet!", traceback.format_exception(e))
+            logger.error("Exception loading unet: %s", traceback.format_exception(*sys.exc_info()))
 
         if v2 and not is_512:
             prediction_type = "v_prediction"
@@ -1275,7 +1272,7 @@ def extract_checkpoint(
             return
 
     logger.debug(f"Trying to load: {original_config_file}")
-    original_config = OmegaConf.load(original_config_file)
+    original_config = load_yaml(original_config_file)
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
@@ -1382,7 +1379,7 @@ def extract_checkpoint(
                 pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer,
                                               unet=unet, scheduler=scheduler)
     except Exception as e:
-        logger.error("Exception setting up output: %s", traceback.format_exception(e))
+        logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
         pipe = None
 
     if pipe is None or db_config is None:
@@ -1397,12 +1394,12 @@ def extract_checkpoint(
     scheduler = db_config.scheduler
     required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
     if original_config_file is not None and os.path.exists(original_config_file):
-        logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
+        logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir)
         shutil.copy(original_config_file, db_config.model_dir)
         basename = os.path.basename(original_config_file)
        new_ex_path = os.path.join(db_config.model_dir, basename)
         new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
-        logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
+        logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name)
         if os.path.exists(new_name):
             os.remove(new_name)
         os.rename(new_ex_path, new_name)
@@ -1446,7 +1443,7 @@ def convert_diffusion_original(
     working_name = os.path.join(ctx.cache_path, torch_name, "working")
 
     if os.path.exists(torch_path):
-        logger.info("Torch pipeline already exists, reusing.")
+        logger.info("Torch pipeline already exists, reusing: %s", torch_path)
     else:
         logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
         extract_checkpoint(ctx, torch_name, source, from_hub=False)
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index 141481ee..cedcd96b 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -1,13 +1,14 @@
 import shutil
 from functools import partial
 from logging import getLogger
-from os import path
+from os import environ, path
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
 import requests
 import torch
 from tqdm.auto import tqdm
+from yaml import safe_load
 
 logger = getLogger(__name__)
 
@@ -19,21 +20,28 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
 
 class ConversionContext:
     def __init__(
         self,
-        model_path: str,
-        device: str,
+        model_path: Optional[str] = None,
+        device: Optional[str] = None,
         cache_path: Optional[str] = None,
         half: Optional[bool] = False,
         opset: Optional[int] = None,
         token: Optional[str] = None,
     ) -> None:
-        self.model_path = model_path
-        self.cache_path = cache_path or path.join(model_path, ".cache")
-        self.training_device = device
-        self.map_location = torch.device(device)
+        self.model_path = model_path or environ.get(
+            "ONNX_WEB_MODEL_PATH", path.join("..", "models")
+        )
+        self.cache_path = cache_path or path.join(self.model_path, ".cache")
         self.half = half
         self.opset = opset
         self.token = token
 
+        if device is not None:
+            self.training_device = device
+        else:
+            self.training_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.map_location = torch.device(self.training_device)
+
 
 def download_progress(urls: List[Tuple[str, str]]):
@@ -135,3 +143,35 @@ def source_format(model: Dict) -> Optional[str]:
             return ext
 
     return None
+
+
+class Config(object):
+    def __init__(self, kwargs):
+        self.__dict__.update(kwargs)
+        for k, v in self.__dict__.items():
+            Config.config_from_key(self, k, v)
+
+    def __iter__(self):
+        for k in self.__dict__.keys():
+            yield k
+
+    @classmethod
+    def config_from_key(cls, target, k, v):
+        if isinstance(v, dict):
+            tmp = Config(v)
+            setattr(target, k, tmp)
+        else:
+            setattr(target, k, v)
+
+
+def load_yaml(file: str) -> Config:
+    with open(file, "r") as f:
+        data = safe_load(f.read())
+    return Config(data)
+
+
+safe_chars = "._-"
+
+
+def sanitize_name(name):
+    return "".join(x for x in name if (x.isalnum() or x in safe_chars))
diff --git a/api/requirements.txt b/api/requirements.txt
index d380bcd0..e08b7c31 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -7,6 +7,7 @@ accelerate
 diffusers
 onnx
 onnxruntime
+safetensors
 transformers
 
 #### Upscaling and face correction
@@ -22,8 +23,3 @@ flask
 flask-cors
 jsonschema
 pyyaml
-
-# TODO: get rid of these
-omegaconf
-pydantic
-safetensors
\ No newline at end of file
diff --git a/docs/user-guide.md b/docs/user-guide.md
index 5f196254..d9cd3c43 100644
--- a/docs/user-guide.md
+++ b/docs/user-guide.md
@@ -217,7 +217,7 @@ make sure it is named correctly.
 
 #### Model sources
 
-You can provide an absolute or relative path to a local model, and there are a few pre-defined sources from which models can be downloaded.
+You can provide an absolute or relative path to a local model, an HTTPS URL, or use one of the pre-defined sources:
 
 - `huggingface://`
   - https://huggingface.co/models?other=stable-diffusion
@@ -229,9 +229,12 @@ You can provide an absolute or relative path to a local model, and there are a f
   - does not require an account
 - `https://`
   - any other HTTPS source
+- `../models/.cache/your-model.safetensors`
+  - relative or absolute paths
 
-If the model's `source` does not include a file extension like `.safetensors` or `.ckpt`, make sure to indicate the
-file format using the `format` key.
+If the model is a single file and the `source` does not include a file extension like `.safetensors` or `.ckpt`, make
+sure to indicate the file format using the `format` key. You do not need to provide the `format` for directories and
+models from the HuggingFace hub.
 
 ## Tabs
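
A minimal sketch of how the refactored helpers in api/onnx_web/convert/utils.py behave after this patch, assuming the api/ package is importable as onnx_web; the nested dict stands in for a parsed YAML model config and is illustrative only:

    from onnx_web.convert.utils import Config, ConversionContext, sanitize_name

    # Config turns nested dicts into attribute access, mirroring the OmegaConf
    # calls that load_yaml replaces in diffusion_original.py
    original_config = Config({"model": {"params": {"timesteps": 1000}}})
    assert original_config.model.params.timesteps == 1000

    # ConversionContext now derives its own defaults: model_path falls back to
    # ONNX_WEB_MODEL_PATH (or ../models) and the device is chosen with
    # torch.cuda.is_available(), so the converter CLI can construct it bare
    ctx = ConversionContext()
    print(ctx.model_path, ctx.training_device)

    # sanitize_name keeps only alphanumerics and "._-", which now also strips
    # the spaces the old implementation allowed
    assert sanitize_name("Stable Diffusion v1.5!") == "StableDiffusionv1.5"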
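
And a sketch of the TrainingConfig class that replaces the pydantic-based DreamboothConfig, reusing the ctx above; the epochs keyword is a made-up example of the extra kwargs the class now stores directly:

    from onnx_web.convert.diffusion_original import TrainingConfig

    # extra keyword arguments land on the instance via setattr instead of
    # going through pydantic validation
    config = TrainingConfig(ctx, model_name="sample model", epochs=10)
    assert config.epochs == 10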