1
0
Fork 0

lint(api): clean up conversion code from original diffusers, drop pydantic dep

This commit is contained in:
Sean Sube 2023-02-11 12:36:54 -06:00
parent 84079e4490
commit 694d15547f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 86 additions and 55 deletions

View File

@ -1,11 +1,10 @@
import warnings
from argparse import ArgumentParser
from logging import getLogger
from os import environ, makedirs, path
from os import makedirs, path
from sys import exit
from typing import Dict, List, Optional, Tuple
import torch
from jsonschema import ValidationError, validate
from yaml import safe_load
@ -102,9 +101,6 @@ base_models: Models = {
],
}
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
training_device = "cuda" if torch.cuda.is_available() else "cpu"
def fetch_model(
ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
@ -228,14 +224,12 @@ def main() -> int:
args = parser.parse_args()
logger.info("CLI arguments: %s", args)
ctx = ConversionContext(
model_path, training_device, half=args.half, opset=args.opset, token=args.token
)
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
if not path.exists(model_path):
logger.info("Model path does not existing, creating: %s", model_path)
makedirs(model_path)
if not path.exists(ctx.model_path):
logger.info("Model path does not existing, creating: %s", ctx.model_path)
makedirs(ctx.model_path)
logger.info("Converting base models.")
convert_models(ctx, args, base_models)

View File

@ -15,6 +15,7 @@ import json
import os
import re
import shutil
import sys
import traceback
from logging import getLogger
from typing import Dict, List
@ -40,11 +41,10 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
LDMBertConfig,
LDMBertModel,
)
import sys
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfApi, hf_hub_download
from omegaconf import OmegaConf
from pydantic import BaseModel
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
@ -54,16 +54,12 @@ from transformers import (
)
from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext, ModelDict
from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml
logger = getLogger(__name__)
def sanitize_name(name):
    """Strip every character that is not safe to use in a model name.

    Alphanumeric characters are kept, along with ``.``, ``_``, ``-`` and the
    space character; everything else is dropped.
    """
    return "".join(filter(lambda ch: ch.isalnum() or ch in "._- ", name))
class DreamboothConfig(BaseModel):
class TrainingConfig():
adamw_weight_decay: float = 0.01
attention: str = "default"
cache_latents: bool = True
@ -145,10 +141,16 @@ class DreamboothConfig(BaseModel):
use_subdir: bool = False
v2: bool = False
def __init__(self, ctx: ConversionContext, model_name: str = "", scheduler: str = "ddim", v2: bool = False, src: str = "",
resolution: int = 512, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
ctx: ConversionContext,
model_name: str = "",
scheduler: str = "ddim",
v2: bool = False,
src: str = "",
resolution: int = 512,
**kwargs,
):
model_name = sanitize_name(model_name)
model_dir = os.path.join(ctx.cache_path, model_name)
working_dir = os.path.join(model_dir, "working")
@ -164,6 +166,10 @@ class DreamboothConfig(BaseModel):
self.scheduler = scheduler
self.v2 = v2
# avoid pydantic dep for this one fn
for k, v in kwargs.items():
setattr(self, k, v)
def save(self, backup=False):
"""
Save the config file
@ -960,9 +966,7 @@ def replace_symlinks(path, base):
replace_symlinks(os.path.join(path, subpath), base)
def download_model(db_config: DreamboothConfig, token):
tmp_dir = os.path.join(db_config.model_dir, "src")
def download_model(db_config: TrainingConfig, token):
hub_url = db_config.src
if "http" in hub_url or "huggingface.co" in hub_url:
hub_url = "/".join(hub_url.split("/")[-2:])
@ -976,7 +980,7 @@ def download_model(db_config: DreamboothConfig, token):
)
if repo_info.sha is None:
logger.warning("Unable to fetch repo?")
logger.warning("Unable to fetch repo info: %s", hub_url)
return None, None
siblings = repo_info.siblings
@ -1018,7 +1022,6 @@ def download_model(db_config: DreamboothConfig, token):
files_to_fetch = None
cache_dir = tmp_dir
if model_file is not None:
files_to_fetch = [model_file]
elif len(diffusion_files):
@ -1043,7 +1046,6 @@ def download_model(db_config: DreamboothConfig, token):
filename=repo_file,
repo_type="model",
revision=repo_info.sha,
cache_dir=cache_dir,
token=token
)
replace_symlinks(out, db_config.model_dir)
@ -1163,12 +1165,8 @@ def extract_checkpoint(
upcast_attention = False
msg = None
if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
logger.warning("Please provide a URL and token for huggingface models.")
return
# Create empty config
db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
src=checkpoint_file if not from_hub else new_model_url)
original_config_file = None
@ -1243,7 +1241,7 @@ def extract_checkpoint(
logger.debug("UNet using v2 parameters.")
v2 = True
except Exception as e:
logger.error("Exception loading unet!", traceback.format_exception(e))
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
if v2 and not is_512:
prediction_type = "v_prediction"
@ -1275,7 +1273,7 @@ def extract_checkpoint(
return
logger.debug(f"Trying to load: {original_config_file}")
original_config = OmegaConf.load(original_config_file)
original_config = load_yaml(original_config_file)
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
@ -1382,7 +1380,7 @@ def extract_checkpoint(
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
scheduler=scheduler)
except Exception as e:
logger.error("Exception setting up output: %s", traceback.format_exception(e))
logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
pipe = None
if pipe is None or db_config is None:
@ -1397,12 +1395,12 @@ def extract_checkpoint(
scheduler = db_config.scheduler
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
if original_config_file is not None and os.path.exists(original_config_file):
logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir)
shutil.copy(original_config_file, db_config.model_dir)
basename = os.path.basename(original_config_file)
new_ex_path = os.path.join(db_config.model_dir, basename)
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name)
if os.path.exists(new_name):
os.remove(new_name)
os.rename(new_ex_path, new_name)
@ -1446,7 +1444,7 @@ def convert_diffusion_original(
working_name = os.path.join(ctx.cache_path, torch_name, "working")
if os.path.exists(torch_path):
logger.info("Torch pipeline already exists, reusing.")
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
else:
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
extract_checkpoint(ctx, torch_name, source, from_hub=False)

View File

@ -1,13 +1,14 @@
import shutil
from functools import partial
from logging import getLogger
from os import path
from os import environ, path
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import requests
import torch
from tqdm.auto import tqdm
from yaml import safe_load
logger = getLogger(__name__)
@ -19,21 +20,28 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
class ConversionContext:
    """Settings shared by the model conversion steps.

    Every argument is optional; anything left unset falls back to a default
    derived from the environment.

    Attributes:
        model_path: root directory for models. Defaults to the
            ``ONNX_WEB_MODEL_PATH`` environment variable, or ``../models``
            when that variable is not set.
        cache_path: cache directory. Defaults to ``.cache`` inside
            ``model_path``.
        training_device: device name. Defaults to ``"cuda"`` when
            ``torch.cuda.is_available()`` reports a device, otherwise
            ``"cpu"``.
        map_location: ``torch.device`` built from ``training_device``.
        half, opset, token: stored exactly as passed in.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: Optional[str] = None,
        cache_path: Optional[str] = None,
        half: Optional[bool] = False,
        opset: Optional[int] = None,
        token: Optional[str] = None,
    ) -> None:
        # An explicit argument wins over the environment variable, which in
        # turn wins over the relative default.
        self.model_path = model_path or environ.get(
            "ONNX_WEB_MODEL_PATH", path.join("..", "models")
        )
        # Derive the cache from the *resolved* model path so this also works
        # when model_path was not passed in.
        self.cache_path = cache_path or path.join(self.model_path, ".cache")
        self.half = half
        self.opset = opset
        self.token = token

        if device is not None:
            self.training_device = device
        else:
            self.training_device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build the torch device from the resolved name, never from a
        # possibly-None argument.
        self.map_location = torch.device(self.training_device)
def download_progress(urls: List[Tuple[str, str]]):
for url, dest in urls:
@ -135,3 +143,35 @@ def source_format(model: Dict) -> Optional[str]:
return ext
return None
class Config(object):
    """Attribute-style view over a dictionary.

    Each key of the mapping becomes an attribute. Values that are themselves
    dictionaries are wrapped recursively, so nested sections can be read as
    ``config.model.params``. Other values, including lists, are stored
    unchanged. Iterating a ``Config`` yields its attribute names.
    """

    def __init__(self, kwargs):
        # Work from a snapshot of the input so attributes are never rebound
        # on the dict that is currently being iterated.
        for k, v in dict(kwargs).items():
            type(self).config_from_key(self, k, v)

    def __iter__(self):
        yield from self.__dict__

    @classmethod
    def config_from_key(cls, target, k, v):
        """Set attribute ``k`` on ``target``, wrapping dict values.

        Nested dictionaries are wrapped in ``cls`` rather than a hard-coded
        ``Config`` so that subclasses keep their type on nested sections.
        """
        if isinstance(v, dict):
            setattr(target, k, cls(v))
        else:
            setattr(target, k, v)
def load_yaml(file: str) -> "Config":
    """Load a YAML file and expose its contents as a ``Config``.

    Args:
        file: path of the YAML document to read.

    Returns:
        A ``Config`` wrapping the top-level mapping of the document. An empty
        document yields an empty ``Config``.
    """
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(file, "r", encoding="utf-8") as f:
        data = safe_load(f)

    # safe_load returns None for an empty document, which Config cannot wrap.
    return Config(data if data is not None else {})
# Punctuation that is kept in sanitized names, in addition to alphanumerics.
safe_chars = "._-"


def sanitize_name(name):
    """Return ``name`` with every character outside the safe set removed.

    A character is kept when it is alphanumeric or listed in ``safe_chars``.
    """
    kept = [ch for ch in name if ch.isalnum() or ch in safe_chars]
    return "".join(kept)

View File

@ -7,6 +7,7 @@ accelerate
diffusers
onnx
onnxruntime
safetensors
transformers
#### Upscaling and face correction
@ -22,8 +23,3 @@ flask
flask-cors
jsonschema
pyyaml
# TODO: get rid of these
omegaconf
pydantic
safetensors

View File

@ -217,7 +217,7 @@ make sure it is named correctly.
#### Model sources
You can provide an absolute or relative path to a local model, and there are a few pre-defined sources from which models can be downloaded.
You can provide an absolute or relative path to a local model, an HTTPS URL, or use one of the pre-defined sources:
- `huggingface://`
- https://huggingface.co/models?other=stable-diffusion
@ -229,9 +229,12 @@ You can provide an absolute or relative path to a local model, and there are a f
- does not require an account
- `https://`
- any other HTTPS source
- `../models/.cache/your-model.safetensors`
- relative or absolute paths
If the model's `source` does not include a file extension like `.safetensors` or `.ckpt`, make sure to indicate the
file format using the `format` key.
If the model is a single file and the `source` does not include a file extension like `.safetensors` or `.ckpt`, make
sure to indicate the file format using the `format` key. You do not need to provide the `format` for directories and
models from the HuggingFace hub.
## Tabs