
lint(api): clean up conversion code from original diffusers, drop pydantic dep

Sean Sube 2023-02-11 12:36:54 -06:00
parent 84079e4490
commit 694d15547f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 86 additions and 55 deletions

View File

@@ -1,11 +1,10 @@
 import warnings
 from argparse import ArgumentParser
 from logging import getLogger
-from os import environ, makedirs, path
+from os import makedirs, path
 from sys import exit
 from typing import Dict, List, Optional, Tuple

-import torch
 from jsonschema import ValidationError, validate
 from yaml import safe_load
@@ -102,9 +101,6 @@ base_models: Models = {
     ],
 }

-model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
-training_device = "cuda" if torch.cuda.is_available() else "cpu"

 def fetch_model(
     ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
@@ -228,14 +224,12 @@ def main() -> int:
     args = parser.parse_args()
     logger.info("CLI arguments: %s", args)

-    ctx = ConversionContext(
-        model_path, training_device, half=args.half, opset=args.opset, token=args.token
-    )
+    ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
     logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)

-    if not path.exists(model_path):
-        logger.info("Model path does not existing, creating: %s", model_path)
-        makedirs(model_path)
+    if not path.exists(ctx.model_path):
+        logger.info("Model path does not existing, creating: %s", ctx.model_path)
+        makedirs(ctx.model_path)

     logger.info("Converting base models.")
     convert_models(ctx, args, base_models)

View File

@@ -15,6 +15,7 @@ import json
 import os
 import re
 import shutil
+import sys
 import traceback
 from logging import getLogger
 from typing import Dict, List
@@ -40,11 +41,10 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
     LDMBertConfig,
     LDMBertModel,
 )
-import sys
 from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from huggingface_hub import HfApi, hf_hub_download
-from omegaconf import OmegaConf
-from pydantic import BaseModel
 from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
@ -54,16 +54,12 @@ from transformers import (
) )
from .diffusion_stable import convert_diffusion_stable from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext, ModelDict from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml
logger = getLogger(__name__) logger = getLogger(__name__)
def sanitize_name(name): class TrainingConfig():
return "".join(x for x in name if (x.isalnum() or x in "._- "))
class DreamboothConfig(BaseModel):
adamw_weight_decay: float = 0.01 adamw_weight_decay: float = 0.01
attention: str = "default" attention: str = "default"
cache_latents: bool = True cache_latents: bool = True
@@ -145,10 +141,16 @@ class DreamboothConfig(BaseModel):
     use_subdir: bool = False
     v2: bool = False

-    def __init__(self, ctx: ConversionContext, model_name: str = "", scheduler: str = "ddim", v2: bool = False, src: str = "",
-                 resolution: int = 512, **kwargs):
-
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        ctx: ConversionContext,
+        model_name: str = "",
+        scheduler: str = "ddim",
+        v2: bool = False,
+        src: str = "",
+        resolution: int = 512,
+        **kwargs,
+    ):
         model_name = sanitize_name(model_name)
         model_dir = os.path.join(ctx.cache_path, model_name)
         working_dir = os.path.join(model_dir, "working")
@@ -164,6 +166,10 @@ class DreamboothConfig(BaseModel):
         self.scheduler = scheduler
         self.v2 = v2

+        # avoid pydantic dep for this one fn
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
     def save(self, backup=False):
         """
         Save the config file
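The setattr loop added above is what stands in for pydantic's keyword handling. As a minimal, self-contained sketch of the pattern (the field names here are illustrative, not the class's full field list):

class TrainingConfig:
    # illustrative defaults; the real class declares many more fields
    resolution: int = 512
    v2: bool = False

    def __init__(self, **kwargs):
        # plain-Python replacement for pydantic's BaseModel.__init__:
        # accept arbitrary keyword fields, without type validation
        for k, v in kwargs.items():
            setattr(self, k, v)

config = TrainingConfig(resolution=768, scheduler="ddim")
print(config.resolution)  # 768, overridden per instance
print(config.v2)          # False, falls back to the class default

The trade-off is losing pydantic's validation, which the commit message accepts in exchange for dropping the dependency.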
@@ -960,9 +966,7 @@ def replace_symlinks(path, base):
         replace_symlinks(os.path.join(path, subpath), base)

-def download_model(db_config: DreamboothConfig, token):
-    tmp_dir = os.path.join(db_config.model_dir, "src")
-
+def download_model(db_config: TrainingConfig, token):
     hub_url = db_config.src
     if "http" in hub_url or "huggingface.co" in hub_url:
         hub_url = "/".join(hub_url.split("/")[-2:])
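For reference, the split-and-join above trims a full URL down to the `owner/repo` identifier that the hub API expects:

hub_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5"
print("/".join(hub_url.split("/")[-2:]))  # runwayml/stable-diffusion-v1-5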
@@ -976,7 +980,7 @@ def download_model(db_config: DreamboothConfig, token):
     )

     if repo_info.sha is None:
-        logger.warning("Unable to fetch repo?")
+        logger.warning("Unable to fetch repo info: %s", hub_url)
         return None, None

     siblings = repo_info.siblings
@@ -1018,7 +1022,6 @@ def download_model(db_config: DreamboothConfig, token):
     files_to_fetch = None
-    cache_dir = tmp_dir

     if model_file is not None:
         files_to_fetch = [model_file]
     elif len(diffusion_files):
@@ -1043,7 +1046,6 @@ def download_model(db_config: DreamboothConfig, token):
             filename=repo_file,
             repo_type="model",
             revision=repo_info.sha,
-            cache_dir=cache_dir,
             token=token
         )
         replace_symlinks(out, db_config.model_dir)
@@ -1163,12 +1165,8 @@ def extract_checkpoint(
     upcast_attention = False
     msg = None

-    if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
-        logger.warning("Please provide a URL and token for huggingface models.")
-        return
-
     # Create empty config
-    db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
-                                 src=checkpoint_file if not from_hub else new_model_url)
+    db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
+                               src=checkpoint_file if not from_hub else new_model_url)

     original_config_file = None
@@ -1243,7 +1241,7 @@ def extract_checkpoint(
             logger.debug("UNet using v2 parameters.")
             v2 = True
     except Exception as e:
-        logger.error("Exception loading unet!", traceback.format_exception(e))
+        logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))

     if v2 and not is_512:
         prediction_type = "v_prediction"
@@ -1275,7 +1273,7 @@ def extract_checkpoint(
                 return

     logger.debug(f"Trying to load: {original_config_file}")
-    original_config = OmegaConf.load(original_config_file)
+    original_config = load_yaml(original_config_file)

     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
@@ -1382,7 +1380,7 @@ def extract_checkpoint(
             pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
                                           scheduler=scheduler)
     except Exception as e:
-        logger.error("Exception setting up output: %s", traceback.format_exception(e))
+        logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
         pipe = None

     if pipe is None or db_config is None:
@@ -1397,12 +1395,12 @@ def extract_checkpoint(
     scheduler = db_config.scheduler
     required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
     if original_config_file is not None and os.path.exists(original_config_file):
-        logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
+        logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir)
         shutil.copy(original_config_file, db_config.model_dir)
         basename = os.path.basename(original_config_file)
         new_ex_path = os.path.join(db_config.model_dir, basename)
         new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
-        logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
+        logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name)
         if os.path.exists(new_name):
             os.remove(new_name)
         os.rename(new_ex_path, new_name)
@@ -1446,7 +1444,7 @@ def convert_diffusion_original(
     working_name = os.path.join(ctx.cache_path, torch_name, "working")

     if os.path.exists(torch_path):
-        logger.info("Torch pipeline already exists, reusing.")
+        logger.info("Torch pipeline already exists, reusing: %s", torch_path)
     else:
         logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
         extract_checkpoint(ctx, torch_name, source, from_hub=False)

View File

@@ -1,13 +1,14 @@
 import shutil
 from functools import partial
 from logging import getLogger
-from os import path
+from os import environ, path
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union

 import requests
 import torch
 from tqdm.auto import tqdm
+from yaml import safe_load

 logger = getLogger(__name__)
@ -19,21 +20,28 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
class ConversionContext: class ConversionContext:
def __init__( def __init__(
self, self,
model_path: str, model_path: Optional[str] = None,
device: str, device: Optional[str] = None,
cache_path: Optional[str] = None, cache_path: Optional[str] = None,
half: Optional[bool] = False, half: Optional[bool] = False,
opset: Optional[int] = None, opset: Optional[int] = None,
token: Optional[str] = None, token: Optional[str] = None,
) -> None: ) -> None:
self.model_path = model_path self.model_path = model_path or environ.get(
self.cache_path = cache_path or path.join(model_path, ".cache") "ONNX_WEB_MODEL_PATH", path.join("..", "models")
self.training_device = device )
self.map_location = torch.device(device) self.cache_path = cache_path or path.join(self.model_path, ".cache")
self.half = half self.half = half
self.opset = opset self.opset = opset
self.token = token self.token = token
if device is not None:
self.training_device = device
else:
self.training_device = "cuda" if torch.cuda.is_available() else "cpu"
self.map_location = torch.device(self.training_device)
def download_progress(urls: List[Tuple[str, str]]): def download_progress(urls: List[Tuple[str, str]]):
for url, dest in urls: for url, dest in urls:
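With both parameters now optional, callers like the conversion CLI in the first file can rely entirely on the fallbacks. A short sketch of the resulting behavior (the import path is assumed from the repo layout):

from onnx_web.convert.utils import ConversionContext  # module path assumed

# no path or device: model_path falls back to ONNX_WEB_MODEL_PATH (or ../models),
# and training_device is chosen by torch.cuda.is_available()
ctx = ConversionContext(half=False, opset=14)
print(ctx.model_path, ctx.training_device)

# explicit arguments still override the environment and autodetection
ctx = ConversionContext(model_path="/tmp/models", device="cpu")

This is the same logic that previously lived as module-level globals in the CLI, moved behind the constructor so other entry points get the defaults too.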
@@ -135,3 +143,35 @@ def source_format(model: Dict) -> Optional[str]:
             return ext

     return None

+class Config(object):
+    def __init__(self, kwargs):
+        self.__dict__.update(kwargs)
+        for k, v in self.__dict__.items():
+            Config.config_from_key(self, k, v)
+
+    def __iter__(self):
+        for k in self.__dict__.keys():
+            yield k
+
+    @classmethod
+    def config_from_key(cls, target, k, v):
+        if isinstance(v, dict):
+            tmp = Config(v)
+            setattr(target, k, tmp)
+        else:
+            setattr(target, k, v)
+
+def load_yaml(file: str) -> str:
+    with open(file, "r") as f:
+        data = safe_load(f.read())
+        return Config(data)
+
+safe_chars = "._-"
+
+def sanitize_name(name):
+    return "".join(x for x in name if (x.isalnum() or x in safe_chars))

View File

@@ -7,6 +7,7 @@ accelerate
 diffusers
 onnx
 onnxruntime
+safetensors
 transformers

 #### Upscaling and face correction
@ -22,8 +23,3 @@ flask
flask-cors flask-cors
jsonschema jsonschema
pyyaml pyyaml
# TODO: get rid of these
omegaconf
pydantic
safetensors

View File

@@ -217,7 +217,7 @@ make sure it is named correctly.

 #### Model sources

-You can provide an absolute or relative path to a local model, and there are a few pre-defined sources from which models can be downloaded.
+You can provide an absolute or relative path to a local model, an HTTPS URL, or use one of the pre-defined sources:

 - `huggingface://`
   - https://huggingface.co/models?other=stable-diffusion
@@ -229,9 +229,12 @@ You can provide an absolute or relative path to a local model, and there are a f
   - does not require an account
 - `https://`
   - any other HTTPS source
+- `../models/.cache/your-model.safetensors`
+  - relative or absolute paths

-If the model's `source` does not include a file extension like `.safetensors` or `.ckpt`, make sure to indicate the
-file format using the `format` key.
+If the model is a single file and the `source` does not include a file extension like `.safetensors` or `.ckpt`, make
+sure to indicate the file format using the `format` key. You do not need to provide the `format` for directories and
+models from the HuggingFace hub.

 ## Tabs
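As a hedged sketch of what the new guidance describes (the exact schema of the models file is not shown in this diff, so the field layout here is assumed), a single-file source without an extension would carry the format key alongside its name and source:

diffusion:
  - name: my-custom-model
    source: https://example.com/models/my-custom-model
    format: safetensors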