lint(api): clean up conversion code from original diffusers, drop pydantic dep
This commit is contained in:
parent
84079e4490
commit
694d15547f
|
@ -1,11 +1,10 @@
|
||||||
import warnings
|
import warnings
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, makedirs, path
|
from os import makedirs, path
|
||||||
from sys import exit
|
from sys import exit
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
|
||||||
from jsonschema import ValidationError, validate
|
from jsonschema import ValidationError, validate
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
|
@ -102,9 +101,6 @@ base_models: Models = {
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models"))
|
|
||||||
training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_model(
|
def fetch_model(
|
||||||
ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
|
ctx: ConversionContext, name: str, source: str, format: Optional[str] = None
|
||||||
|
@ -228,14 +224,12 @@ def main() -> int:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info("CLI arguments: %s", args)
|
logger.info("CLI arguments: %s", args)
|
||||||
|
|
||||||
ctx = ConversionContext(
|
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
|
||||||
model_path, training_device, half=args.half, opset=args.opset, token=args.token
|
|
||||||
)
|
|
||||||
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
logger.info("Converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||||
|
|
||||||
if not path.exists(model_path):
|
if not path.exists(ctx.model_path):
|
||||||
logger.info("Model path does not existing, creating: %s", model_path)
|
logger.info("Model path does not existing, creating: %s", ctx.model_path)
|
||||||
makedirs(model_path)
|
makedirs(ctx.model_path)
|
||||||
|
|
||||||
logger.info("Converting base models.")
|
logger.info("Converting base models.")
|
||||||
convert_models(ctx, args, base_models)
|
convert_models(ctx, args, base_models)
|
||||||
|
|
|
@ -15,6 +15,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
@ -40,11 +41,10 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||||
LDMBertConfig,
|
LDMBertConfig,
|
||||||
LDMBertModel,
|
LDMBertModel,
|
||||||
)
|
)
|
||||||
|
import sys
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoFeatureExtractor,
|
AutoFeatureExtractor,
|
||||||
BertTokenizerFast,
|
BertTokenizerFast,
|
||||||
|
@ -54,16 +54,12 @@ from transformers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .diffusion_stable import convert_diffusion_stable
|
from .diffusion_stable import convert_diffusion_stable
|
||||||
from .utils import ConversionContext, ModelDict
|
from .utils import ConversionContext, ModelDict, sanitize_name, load_yaml
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def sanitize_name(name):
|
class TrainingConfig():
|
||||||
return "".join(x for x in name if (x.isalnum() or x in "._- "))
|
|
||||||
|
|
||||||
|
|
||||||
class DreamboothConfig(BaseModel):
|
|
||||||
adamw_weight_decay: float = 0.01
|
adamw_weight_decay: float = 0.01
|
||||||
attention: str = "default"
|
attention: str = "default"
|
||||||
cache_latents: bool = True
|
cache_latents: bool = True
|
||||||
|
@ -145,10 +141,16 @@ class DreamboothConfig(BaseModel):
|
||||||
use_subdir: bool = False
|
use_subdir: bool = False
|
||||||
v2: bool = False
|
v2: bool = False
|
||||||
|
|
||||||
def __init__(self, ctx: ConversionContext, model_name: str = "", scheduler: str = "ddim", v2: bool = False, src: str = "",
|
def __init__(
|
||||||
resolution: int = 512, **kwargs):
|
self,
|
||||||
|
ctx: ConversionContext,
|
||||||
super().__init__(**kwargs)
|
model_name: str = "",
|
||||||
|
scheduler: str = "ddim",
|
||||||
|
v2: bool = False,
|
||||||
|
src: str = "",
|
||||||
|
resolution: int = 512,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
model_name = sanitize_name(model_name)
|
model_name = sanitize_name(model_name)
|
||||||
model_dir = os.path.join(ctx.cache_path, model_name)
|
model_dir = os.path.join(ctx.cache_path, model_name)
|
||||||
working_dir = os.path.join(model_dir, "working")
|
working_dir = os.path.join(model_dir, "working")
|
||||||
|
@ -164,6 +166,10 @@ class DreamboothConfig(BaseModel):
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.v2 = v2
|
self.v2 = v2
|
||||||
|
|
||||||
|
# avoid pydantic dep for this one fn
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
def save(self, backup=False):
|
def save(self, backup=False):
|
||||||
"""
|
"""
|
||||||
Save the config file
|
Save the config file
|
||||||
|
@ -960,9 +966,7 @@ def replace_symlinks(path, base):
|
||||||
replace_symlinks(os.path.join(path, subpath), base)
|
replace_symlinks(os.path.join(path, subpath), base)
|
||||||
|
|
||||||
|
|
||||||
def download_model(db_config: DreamboothConfig, token):
|
def download_model(db_config: TrainingConfig, token):
|
||||||
tmp_dir = os.path.join(db_config.model_dir, "src")
|
|
||||||
|
|
||||||
hub_url = db_config.src
|
hub_url = db_config.src
|
||||||
if "http" in hub_url or "huggingface.co" in hub_url:
|
if "http" in hub_url or "huggingface.co" in hub_url:
|
||||||
hub_url = "/".join(hub_url.split("/")[-2:])
|
hub_url = "/".join(hub_url.split("/")[-2:])
|
||||||
|
@ -976,7 +980,7 @@ def download_model(db_config: DreamboothConfig, token):
|
||||||
)
|
)
|
||||||
|
|
||||||
if repo_info.sha is None:
|
if repo_info.sha is None:
|
||||||
logger.warning("Unable to fetch repo?")
|
logger.warning("Unable to fetch repo info: %s", hub_url)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
siblings = repo_info.siblings
|
siblings = repo_info.siblings
|
||||||
|
@ -1018,7 +1022,6 @@ def download_model(db_config: DreamboothConfig, token):
|
||||||
|
|
||||||
files_to_fetch = None
|
files_to_fetch = None
|
||||||
|
|
||||||
cache_dir = tmp_dir
|
|
||||||
if model_file is not None:
|
if model_file is not None:
|
||||||
files_to_fetch = [model_file]
|
files_to_fetch = [model_file]
|
||||||
elif len(diffusion_files):
|
elif len(diffusion_files):
|
||||||
|
@ -1043,7 +1046,6 @@ def download_model(db_config: DreamboothConfig, token):
|
||||||
filename=repo_file,
|
filename=repo_file,
|
||||||
repo_type="model",
|
repo_type="model",
|
||||||
revision=repo_info.sha,
|
revision=repo_info.sha,
|
||||||
cache_dir=cache_dir,
|
|
||||||
token=token
|
token=token
|
||||||
)
|
)
|
||||||
replace_symlinks(out, db_config.model_dir)
|
replace_symlinks(out, db_config.model_dir)
|
||||||
|
@ -1163,12 +1165,8 @@ def extract_checkpoint(
|
||||||
upcast_attention = False
|
upcast_attention = False
|
||||||
msg = None
|
msg = None
|
||||||
|
|
||||||
if from_hub and (new_model_url == "" or new_model_url is None) and (new_model_token is None or new_model_token == ""):
|
|
||||||
logger.warning("Please provide a URL and token for huggingface models.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create empty config
|
# Create empty config
|
||||||
db_config = DreamboothConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
|
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
|
||||||
src=checkpoint_file if not from_hub else new_model_url)
|
src=checkpoint_file if not from_hub else new_model_url)
|
||||||
|
|
||||||
original_config_file = None
|
original_config_file = None
|
||||||
|
@ -1243,7 +1241,7 @@ def extract_checkpoint(
|
||||||
logger.debug("UNet using v2 parameters.")
|
logger.debug("UNet using v2 parameters.")
|
||||||
v2 = True
|
v2 = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Exception loading unet!", traceback.format_exception(e))
|
logger.error("Exception loading unet!", traceback.format_exception(*sys.exc_info()))
|
||||||
|
|
||||||
if v2 and not is_512:
|
if v2 and not is_512:
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
|
@ -1275,7 +1273,7 @@ def extract_checkpoint(
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug(f"Trying to load: {original_config_file}")
|
logger.debug(f"Trying to load: {original_config_file}")
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = load_yaml(original_config_file)
|
||||||
|
|
||||||
num_train_timesteps = original_config.model.params.timesteps
|
num_train_timesteps = original_config.model.params.timesteps
|
||||||
beta_start = original_config.model.params.linear_start
|
beta_start = original_config.model.params.linear_start
|
||||||
|
@ -1382,7 +1380,7 @@ def extract_checkpoint(
|
||||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
|
||||||
scheduler=scheduler)
|
scheduler=scheduler)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Exception setting up output: %s", traceback.format_exception(e))
|
logger.error("Exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
|
||||||
pipe = None
|
pipe = None
|
||||||
|
|
||||||
if pipe is None or db_config is None:
|
if pipe is None or db_config is None:
|
||||||
|
@ -1397,12 +1395,12 @@ def extract_checkpoint(
|
||||||
scheduler = db_config.scheduler
|
scheduler = db_config.scheduler
|
||||||
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
|
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
|
||||||
if original_config_file is not None and os.path.exists(original_config_file):
|
if original_config_file is not None and os.path.exists(original_config_file):
|
||||||
logger.warning("copying original config: %s -> %s", original_config_file, db_config.model_dir)
|
logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir)
|
||||||
shutil.copy(original_config_file, db_config.model_dir)
|
shutil.copy(original_config_file, db_config.model_dir)
|
||||||
basename = os.path.basename(original_config_file)
|
basename = os.path.basename(original_config_file)
|
||||||
new_ex_path = os.path.join(db_config.model_dir, basename)
|
new_ex_path = os.path.join(db_config.model_dir, basename)
|
||||||
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
|
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
|
||||||
logger.warning("copying model config to new name: %s -> %s", new_ex_path, new_name)
|
logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name)
|
||||||
if os.path.exists(new_name):
|
if os.path.exists(new_name):
|
||||||
os.remove(new_name)
|
os.remove(new_name)
|
||||||
os.rename(new_ex_path, new_name)
|
os.rename(new_ex_path, new_name)
|
||||||
|
@ -1446,7 +1444,7 @@ def convert_diffusion_original(
|
||||||
working_name = os.path.join(ctx.cache_path, torch_name, "working")
|
working_name = os.path.join(ctx.cache_path, torch_name, "working")
|
||||||
|
|
||||||
if os.path.exists(torch_path):
|
if os.path.exists(torch_path):
|
||||||
logger.info("Torch pipeline already exists, reusing.")
|
logger.info("Torch pipeline already exists, reusing: %s", torch_path)
|
||||||
else:
|
else:
|
||||||
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
logger.info("Converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
|
||||||
extract_checkpoint(ctx, torch_name, source, from_hub=False)
|
extract_checkpoint(ctx, torch_name, source, from_hub=False)
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
import shutil
|
import shutil
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import environ, path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
from yaml import safe_load
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -19,21 +20,28 @@ LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||||
class ConversionContext:
|
class ConversionContext:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: str,
|
model_path: Optional[str] = None,
|
||||||
device: str,
|
device: Optional[str] = None,
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
half: Optional[bool] = False,
|
half: Optional[bool] = False,
|
||||||
opset: Optional[int] = None,
|
opset: Optional[int] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_path = model_path
|
self.model_path = model_path or environ.get(
|
||||||
self.cache_path = cache_path or path.join(model_path, ".cache")
|
"ONNX_WEB_MODEL_PATH", path.join("..", "models")
|
||||||
self.training_device = device
|
)
|
||||||
self.map_location = torch.device(device)
|
self.cache_path = cache_path or path.join(self.model_path, ".cache")
|
||||||
self.half = half
|
self.half = half
|
||||||
self.opset = opset
|
self.opset = opset
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
self.training_device = device
|
||||||
|
else:
|
||||||
|
self.training_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
self.map_location = torch.device(self.training_device)
|
||||||
|
|
||||||
|
|
||||||
def download_progress(urls: List[Tuple[str, str]]):
|
def download_progress(urls: List[Tuple[str, str]]):
|
||||||
for url, dest in urls:
|
for url, dest in urls:
|
||||||
|
@ -135,3 +143,35 @@ def source_format(model: Dict) -> Optional[str]:
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Config(object):
|
||||||
|
def __init__(self, kwargs):
|
||||||
|
self.__dict__.update(kwargs)
|
||||||
|
for k, v in self.__dict__.items():
|
||||||
|
Config.config_from_key(self, k, v)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for k in self.__dict__.keys():
|
||||||
|
yield k
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_from_key(cls, target, k, v):
|
||||||
|
if isinstance(v, dict):
|
||||||
|
tmp = Config(v)
|
||||||
|
setattr(target, k, tmp)
|
||||||
|
else:
|
||||||
|
setattr(target, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_yaml(file: str) -> str:
|
||||||
|
with open(file, "r") as f:
|
||||||
|
data = safe_load(f.read())
|
||||||
|
return Config(data)
|
||||||
|
|
||||||
|
|
||||||
|
safe_chars = "._-"
|
||||||
|
def sanitize_name(name):
|
||||||
|
return "".join(x for x in name if (x.isalnum() or x in safe_chars))
|
||||||
|
|
|
@ -7,6 +7,7 @@ accelerate
|
||||||
diffusers
|
diffusers
|
||||||
onnx
|
onnx
|
||||||
onnxruntime
|
onnxruntime
|
||||||
|
safetensors
|
||||||
transformers
|
transformers
|
||||||
|
|
||||||
#### Upscaling and face correction
|
#### Upscaling and face correction
|
||||||
|
@ -22,8 +23,3 @@ flask
|
||||||
flask-cors
|
flask-cors
|
||||||
jsonschema
|
jsonschema
|
||||||
pyyaml
|
pyyaml
|
||||||
|
|
||||||
# TODO: get rid of these
|
|
||||||
omegaconf
|
|
||||||
pydantic
|
|
||||||
safetensors
|
|
|
@ -217,7 +217,7 @@ make sure it is named correctly.
|
||||||
|
|
||||||
#### Model sources
|
#### Model sources
|
||||||
|
|
||||||
You can provide an absolute or relative path to a local model, and there are a few pre-defined sources from which models can be downloaded.
|
You can provide an absolute or relative path to a local model, an HTTPS URL, or use one of the pre-defined sources:
|
||||||
|
|
||||||
- `huggingface://`
|
- `huggingface://`
|
||||||
- https://huggingface.co/models?other=stable-diffusion
|
- https://huggingface.co/models?other=stable-diffusion
|
||||||
|
@ -229,9 +229,12 @@ You can provide an absolute or relative path to a local model, and there are a f
|
||||||
- does not require an account
|
- does not require an account
|
||||||
- `https://`
|
- `https://`
|
||||||
- any other HTTPS source
|
- any other HTTPS source
|
||||||
|
- `../models/.cache/your-model.safetensors`
|
||||||
|
- relative or absolute paths
|
||||||
|
|
||||||
If the model's `source` does not include a file extension like `.safetensors` or `.ckpt`, make sure to indicate the
|
If the model is a single file and the `source` does not include a file extension like `.safetensors` or `.ckpt`, make
|
||||||
file format using the `format` key.
|
sure to indicate the file format using the `format` key. You do not need to provide the `format` for directories and
|
||||||
|
models from the HuggingFace hub.
|
||||||
|
|
||||||
## Tabs
|
## Tabs
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue