lint(api): clean up conversion code from original diffusers, drop pydantic dep
parent 84079e4490
commit 694d15547f

@@ -1,11 +1,10 @@
 import warnings
 from argparse import ArgumentParser
 from logging import getLogger
-from os import environ, makedirs, path
+from os import makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple

 import torch
 from jsonschema import ValidationError, validate
 from yaml import safe_load
@@ -102,9 +101,6 @@ base_models: Models = {
     ],
 }

-model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
-training_device = "cuda" if torch.cuda.is_available() else "cpu"


 def fetch_model(
     ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
@@ -228,14 +224,12 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)

-    ctx = ConversionContext(
-        model_path, training_device, half=args.half, opset=args.opset, token=args.token
-    )
+    ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)

-    if not path.exists(model_path):
-        logger.info("Model path does not existing, creating: %s", model_path)
-        makedirs(model_path)
+    if not path.exists(ctx.model_path):
+        logger.info("Model path does not existing, creating: %s", ctx.model_path)
+        makedirs(ctx.model_path)

     logger.info("Converting base models.")
     convert_models(ctx, args, base_models)

@@ -15,6 +15,7 @@ import json
 import os
 import re
 import shutil
+import sys
 import traceback
 from logging import getLogger
 from typing import Dict, List
@@ -40,11 +41,10 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
     LDMBertConfig,
     LDMBertModel,
 )
-import sys
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfApi, hf_hub_download
 from omegaconf import OmegaConf
-from pydantic import BaseModel
 from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
@@ -54,16 +54,12 @@ from transformers import (
 )

 from .diffusion_stable import convert_diffusion_stable
-from .utils import ConversionContext, ModelDict
+from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml

 logger = getLogger(__name__)


-def sanitize_name(name):
-    return "".join(x for x in name if (x.isalnum() or x in "._- "))
-
-
-class DreamboothConfig(BaseModel):
+class TrainingConfig():
     adamw_weight_decay: float = 0.01
     attention: str = "default"
     cache_latents: bool = True
@@ -145,10 +141,16 @@ class DreamboothConfig(BaseModel):
     use_subdir: bool = False
     v2: bool = False

-    def __init__(self, ctx: ConversionContext, model_name: str = "", scheduler: str = "ddim", v2: bool = False, src: str = "",
-                 resolution: int = 512, **kwargs):
-
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        ctx: ConversionContext,
+        model_name: str = "",
+        scheduler: str = "ddim",
+        v2: bool = False,
+        src: str = "",
+        resolution: int = 512,
+        **kwargs,
+    ):
         model_name = sanitize_name(model_name)
         model_dir = os.path.join(ctx.cache_path, model_name)
         working_dir = os.path.join(model_dir, "working")
@@ -164,6 +166,10 @@ class DreamboothConfig(BaseModel):
         self.scheduler = scheduler
         self.v2 = v2

+        # avoid pydantic dep for this one fn
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
     def save(self, backup=False):
         """
         Save the config file
@@ -960,9 +966,7 @@ def replace_symlinks(path, base):
             replace_symlinks(os.path.join(path, subpath), base)


-def download_model(db_config: DreamboothConfig, token):
-    tmp_dir = os.path.join(db_config.model_dir, "src")
-
+def download_model(db_config: TrainingConfig, token):
     hub_url = db_config.src
     if "http" in hub_url or "huggingface.co" in hub_url:
         hub_url = "/".join(hub_url.split("/")[-2:])
@@ -976,7 +980,7 @@ def download_model(db_config: DreamboothConfig, token):
     )

     if repo_info.sha is None:
-        logger.warning("Unable to fetch repo?")
+        logger.warning("Unable to fetch repo info: %s", hub_url)
         return None, None

     siblings = repo_info.siblings
@@ -1018,7 +1022,6 @@ def download_model(db_config: DreamboothConfig, token):

     files_to_fetch = None

-    cache_dir = tmp_dir
     if model_file is not None:
         files_to_fetch = [model_file]
     elif len(diffusion_files):
@@ -1043,7 +1046,6 @@ def download_model(db_config: DreamboothConfig, token):
             filename=repo_file,
             repo_type="model",
             revision=repo_info.sha,
-            cache_dir=cache_dir,
             token=token
         )
         replace_symlinks(out, db_config.model_dir)
@@ -1163,12 +1165,8 @@ def extract_checkpoint(
     upcast_attention = False
-    msg = None

-    if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
-        logger.warning("Please provide a URL and token for huggingface models.")
-        return

     # Create empty config
-    db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
+    db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
                                  src=checkpoint_file if not from_hub else new_model_url)

     original_config_file = None
@@ -1243,7 +1241,7 @@ def extract_checkpoint(
                 logger.debug("UNet using v2 parameters.")
                 v2 = True
         except Exception as e:
-            logger.error("Exception loading unet!", traceback.format_exception(e))
+            logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))

     if v2 and not is_512:
         prediction_type = "v_prediction"
@@ -1275,7 +1273,7 @@ def extract_checkpoint(
         return

     logger.debug(f"Trying to load: {original_config_file}")
-    original_config = OmegaConf.load(original_config_file)
+    original_config = load_yaml(original_config_file)

     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
@@ -1382,7 +1380,7 @@ def extract_checkpoint(
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
                                           scheduler=scheduler)
         except Exception as e:
-            logger.error("Exception setting up output: %s", traceback.format_exception(e))
+            logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
             pipe = None

     if pipe is None or db_config is None:
@@ -1397,12 +1395,12 @@ def extract_checkpoint(
     scheduler = db_config.scheduler
     required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
     if original_config_file is not None and os.path.exists(original_config_file):
-        logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
+        logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir)
         shutil.copy(original_config_file, db_config.model_dir)
         basename = os.path.basename(original_config_file)
         new_ex_path = os.path.join(db_config.model_dir, basename)
         new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
-        logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
+        logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name)
         if os.path.exists(new_name):
             os.remove(new_name)
         os.rename(new_ex_path, new_name)
@@ -1446,7 +1444,7 @@ def convert_diffusion_original(
     working_name = os.path.join(ctx.cache_path, torch_name, "working")

     if os.path.exists(torch_path):
-        logger.info("Torch pipeline already exists, reusing.")
+        logger.info("Torch pipeline already exists, reusing: %s", torch_path)
     else:
         logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
         extract_checkpoint(ctx, torch_name, source, from_hub=False)

@@ -1,13 +1,14 @@
 import shutil
 from functools import partial
 from logging import getLogger
-from os import path
+from os import environ, path
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union

 import requests
 import torch
 from tqdm.auto import tqdm
+from yaml import safe_load

 logger = getLogger(__name__)
@@ -19,21 +20,28 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
 class ConversionContext:
     def __init__(
         self,
-        model_path: str,
-        device: str,
+        model_path: Optional[str] = None,
+        device: Optional[str] = None,
         cache_path: Optional[str] = None,
         half: Optional[bool] = False,
         opset: Optional[int] = None,
         token: Optional[str] = None,
     ) -> None:
-        self.model_path = model_path
-        self.cache_path = cache_path or path.join(model_path, ".cache")
-        self.training_device = device
-        self.map_location = torch.device(device)
+        self.model_path = model_path or environ.get(
+            "ONNX_WEB_MODEL_PATH", path.join("..", "models")
+        )
+        self.cache_path = cache_path or path.join(self.model_path, ".cache")
         self.half = half
         self.opset = opset
         self.token = token
+
+        if device is not None:
+            self.training_device = device
+        else:
+            self.training_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.map_location = torch.device(self.training_device)


 def download_progress(urls: List[Tuple[str, str]]):
     for url, dest in urls:
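
For illustration only (not part of the diff): with the reworked constructor above, callers can omit `model_path` and `device` and rely on the `ONNX_WEB_MODEL_PATH` fallback and CUDA autodetection. A minimal sketch, assuming the module is importable as `onnx_web.convert.utils` (hypothetical path):

```python
from os import environ

from onnx_web.convert.utils import ConversionContext  # assumed import path

# Explicit arguments still take precedence, as before.
ctx = ConversionContext(model_path="/data/models", device="cpu", half=False, opset=14)

# With no path or device, model_path falls back to ONNX_WEB_MODEL_PATH (or ../models)
# and training_device is chosen from torch.cuda.is_available().
environ["ONNX_WEB_MODEL_PATH"] = "/data/models"
ctx = ConversionContext(half=True, opset=14)
print(ctx.model_path, ctx.training_device, ctx.map_location)
```
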
@@ -135,3 +143,35 @@ def source_format(model: Dict) -> Optional[str]:
         return ext

     return None
+
+
+class Config(object):
+    def __init__(self, kwargs):
+        self.__dict__.update(kwargs)
+        for k, v in self.__dict__.items():
+            Config.config_from_key(self, k, v)
+
+    def __iter__(self):
+        for k in self.__dict__.keys():
+            yield k
+
+    @classmethod
+    def config_from_key(cls, target, k, v):
+        if isinstance(v, dict):
+            tmp = Config(v)
+            setattr(target, k, tmp)
+        else:
+            setattr(target, k, v)
+
+
+def load_yaml(file: str) -> str:
+    with open(file, "r") as f:
+        data = safe_load(f.read())
+        return Config(data)
+
+
+safe_chars = "._-"
+def sanitize_name(name):
+    return "".join(x for x in name if (x.isalnum() or x in safe_chars))
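
For illustration only (not part of the diff): the `Config` wrapper added above preserves the attribute-style access that `OmegaConf.load` used to provide, which is why `original_config.model.params.timesteps` keeps working after the swap to `load_yaml`. A minimal sketch, with a hypothetical YAML file and an assumed import path:

```python
from onnx_web.convert.utils import load_yaml  # assumed import path

# Hypothetical config file, mimicking the fields read by extract_checkpoint.
with open("v1-inference.yaml", "w") as f:
    f.write(
        "model:\n"
        "  params:\n"
        "    timesteps: 1000\n"
        "    linear_start: 0.00085\n"
    )

original_config = load_yaml("v1-inference.yaml")

# Nested dicts are converted into nested Config objects, so attribute access
# works the same way it did with OmegaConf.
print(original_config.model.params.timesteps)     # 1000
print(original_config.model.params.linear_start)  # 0.00085
```
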
@@ -7,6 +7,7 @@ accelerate
 diffusers
 onnx
 onnxruntime
+safetensors
 transformers

 #### Upscaling and face correction
@@ -22,8 +23,3 @@ flask
 flask-cors
 jsonschema
 pyyaml
-
-# TODO: get rid of these
-omegaconf
-pydantic
-safetensors

@@ -217,7 +217,7 @@ make sure it is named correctly.

 #### Model sources

-You can provide an absolute or relative path to a local model, and there are a few pre-defined sources from which models can be downloaded.
+You can provide an absolute or relative path to a local model, an HTTPS URL, or use one of the pre-defined sources:

 - `huggingface://`
   - https://huggingface.co/models?other=stable-diffusion
@@ -229,9 +229,12 @@ You can provide an absolute or relative path to a local model, and there are a few pre-defined sources from which models can be downloaded.
   - does not require an account
 - `https://`
   - any other HTTPS source
+- `../models/.cache/your-model.safetensors`
+  - relative or absolute paths

-If the model's `source` does not include a file extension like `.safetensors` or `.ckpt`, make sure to indicate the
-file format using the `format` key.
+If the model is a single file and the `source` does not include a file extension like `.safetensors` or `.ckpt`, make
+sure to indicate the file format using the `format` key. You do not need to provide the `format` for directories and
+models from the HuggingFace hub.

 ## Tabs

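To make the documented `source` and `format` keys concrete, a hypothetical extra-model entry might look like the sketch below; the `name` key, the hub id, and the surrounding file layout are illustrative assumptions, not taken from this commit:

```python
# A HuggingFace hub model resolves to a directory, so no `format` key is needed.
diffusers_model = {
    "name": "stable-diffusion-v1-5",
    "source": "huggingface://runwayml/stable-diffusion-v1-5",
}

# A single-file checkpoint behind a URL with no file extension needs `format`.
checkpoint_model = {
    "name": "your-model",
    "source": "https://example.com/models/your-model",
    "format": "safetensors",
}
```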