make blending happen once after conversion
Parent: 32b2a76a0b
Commit: c3979246df
@@ -223,97 +223,100 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                     ctx, name, model["source"], format=model_format
                 )

+                converted = False
                 if model_format in model_formats_original:
-                    convert_diffusion_original(
+                    converted, _dest = convert_diffusion_original(
                         ctx,
                         model,
                         source,
                     )
                 else:
-                    convert_diffusion_diffusers(
+                    converted, _dest = convert_diffusion_diffusers(
                         ctx,
                         model,
                         source,
                     )

-                # keep track of which models have been blended
-                blend_models = {}
+                # make sure blending only happens once, not every run
+                if converted:
+                    # keep track of which models have been blended
+                    blend_models = {}

-                inversion_dest = path.join(ctx.model_path, "inversion")
-                lora_dest = path.join(ctx.model_path, "lora")
+                    inversion_dest = path.join(ctx.model_path, "inversion")
+                    lora_dest = path.join(ctx.model_path, "lora")

-                for inversion in model.get("inversions", []):
-                    if "text_encoder" not in blend_models:
-                        blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
+                    for inversion in model.get("inversions", []):
+                        if "text_encoder" not in blend_models:
+                            blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))

-                    if "tokenizer" not in blend_models:
-                        blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, model), subfolder="tokenizer")
+                        if "tokenizer" not in blend_models:
+                            blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(path.join(ctx.model_path, model), subfolder="tokenizer")

-                    inversion_name = inversion["name"]
-                    inversion_source = inversion["source"]
-                    inversion_format = inversion.get("format", "embeddings")
-                    inversion_source = fetch_model(
-                        ctx,
-                        f"{name}-inversion-{inversion_name}",
-                        inversion_source,
-                        dest=inversion_dest,
-                    )
-                    inversion_token = inversion.get("token", inversion_name)
-                    inversion_weight = inversion.get("weight", 1.0)
+                        inversion_name = inversion["name"]
+                        inversion_source = inversion["source"]
+                        inversion_format = inversion.get("format", "embeddings")
+                        inversion_source = fetch_model(
+                            ctx,
+                            f"{name}-inversion-{inversion_name}",
+                            inversion_source,
+                            dest=inversion_dest,
+                        )
+                        inversion_token = inversion.get("token", inversion_name)
+                        inversion_weight = inversion.get("weight", 1.0)

-                    blend_textual_inversions(
-                        ctx,
-                        blend_models["text_encoder"],
-                        blend_models["tokenizer"],
-                        [inversion_source],
-                        [inversion_format],
-                        base_token=inversion_token,
-                        inversion_weights=[inversion_weight],
-                    )
+                        blend_textual_inversions(
+                            ctx,
+                            blend_models["text_encoder"],
+                            blend_models["tokenizer"],
+                            [inversion_source],
+                            [inversion_format],
+                            base_token=inversion_token,
+                            inversion_weights=[inversion_weight],
+                        )

-                for lora in model.get("loras", []):
-                    if "text_encoder" not in blend_models:
-                        blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))
+                    for lora in model.get("loras", []):
+                        if "text_encoder" not in blend_models:
+                            blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "text_encoder", "model.onnx"))

-                    if "unet" not in blend_models:
-                        blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "unet", "model.onnx"))
+                        if "unet" not in blend_models:
+                            blend_models["text_encoder"] = load_model(path.join(ctx.model_path, model, "unet", "model.onnx"))

-                    # load models if not loaded yet
-                    lora_name = lora["name"]
-                    lora_source = lora["source"]
-                    lora_source = fetch_model(
-                        ctx,
-                        f"{name}-lora-{lora_name}",
-                        lora_source
-                        dest=lora_dest,
-                    )
-                    lora_weight = lora.get("weight", 1.0)
+                        # load models if not loaded yet
+                        lora_name = lora["name"]
+                        lora_source = lora["source"]
+                        lora_source = fetch_model(
+                            ctx,
+                            f"{name}-lora-{lora_name}",
+                            lora_source,
+                            dest=lora_dest,
+                        )
+                        lora_weight = lora.get("weight", 1.0)

-                    blend_loras(
-                        ctx,
-                        blend_models["text_encoder"],
-                        [lora_name],
-                        [lora_source],
-                        "text_encoder",
-                        lora_weights=[lora_weight],
-                    )
+                        blend_loras(
+                            ctx,
+                            blend_models["text_encoder"],
+                            [lora_name],
+                            [lora_source],
+                            "text_encoder",
+                            lora_weights=[lora_weight],
+                        )

-                if "tokenizer" in blend_models:
-                    dest_path = path.join(ctx.model_path, model, "tokenizer")
-                    logger.debug("saving blended tokenizer to %s", dest_path)
-                    blend_models["tokenizer"].save_pretrained(dest_path)
+                    if "tokenizer" in blend_models:
+                        dest_path = path.join(ctx.model_path, model, "tokenizer")
+                        logger.debug("saving blended tokenizer to %s", dest_path)
+                        blend_models["tokenizer"].save_pretrained(dest_path)

-                for name in ["text_encoder", "unet"]:
-                    if name in blend_models:
-                        dest_path = path.join(ctx.model_path, model, name, "model.onnx")
-                        logger.debug("saving blended %s model to %s", name, dest_path)
-                        save_model(
-                            blend_models[name],
-                            dest_path,
-                            save_as_external_data=True,
-                            all_tensors_to_one_file=True,
-                            location="weights.pb",
-                        )
+                    for name in ["text_encoder", "unet"]:
+                        if name in blend_models:
+                            dest_path = path.join(ctx.model_path, model, name, "model.onnx")
+                            logger.debug("saving blended %s model to %s", name, dest_path)
+                            save_model(
+                                blend_models[name],
+                                dest_path,
+                                save_as_external_data=True,
+                                all_tensors_to_one_file=True,
+                                location="weights.pb",
+                            )


             except Exception:
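Taken together, this hunk changes the control flow from "convert, then always blend" to "convert, then blend only if something was actually converted". A reduced sketch of that flow, with hypothetical stand-ins for the converters and blenders (only the gating pattern is taken from the diff):

from typing import Dict, Tuple

def convert_model(name: str, cache: Dict[str, str]) -> Tuple[bool, str]:
    # stand-in for convert_diffusion_original/convert_diffusion_diffusers:
    # (False, dest) on a cache hit, (True, dest) after a fresh conversion
    if name in cache:
        return (False, cache[name])
    cache[name] = f"/models/{name}"
    return (True, cache[name])

def blend_extras(dest: str) -> None:
    # stand-in for blend_textual_inversions/blend_loras plus save_model;
    # these rewrite the converted ONNX weights in place, which is presumably
    # why repeating them on every run would keep re-applying the same deltas
    print(f"blending into {dest}")

cache: Dict[str, str] = {}
for _run in range(2):  # two launches of the conversion script
    converted, dest = convert_model("stable-diffusion", cache)
    if converted:  # the fix: blend once, right after a fresh conversion
        blend_extras(dest)

In the real hunk the destination is unpacked as _dest and left unused; the leading underscore just keeps the new tuple contract explicit without tripping lint warnings.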
@@ -88,7 +88,7 @@ def convert_diffusion_diffusers(
     ctx: ConversionContext,
     model: Dict,
     source: str,
-):
+) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
@@ -111,7 +111,7 @@ def convert_diffusion_diffusers(

     if path.exists(dest_path) and path.exists(model_index):
         logger.info("ONNX model already exists, skipping")
-        return
+        return (False, dest_path)

     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
@@ -328,3 +328,5 @@ def convert_diffusion_diffusers(
     )

     logger.info("ONNX pipeline is loadable")
+
+    return (True, dest_path)
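Note that the early return now carries the destination path even on the skip branch, so a caller always learns where the converted model lives and only the boolean varies. A standalone sketch of that guard, assuming nothing beyond the (bool, str) shape above (the temp-dir layout is made up for the example):

from os import path
from tempfile import mkdtemp
from typing import Tuple

def convert_stub(dest_path: str) -> Tuple[bool, str]:
    model_index = path.join(dest_path, "model_index.json")
    if path.exists(dest_path) and path.exists(model_index):
        # cache hit: report the existing location, but signal that no
        # conversion happened, so no blending should follow
        return (False, dest_path)
    # ... the real converter would export the ONNX pipeline here ...
    open(model_index, "w").close()
    return (True, dest_path)

dest = mkdtemp()
print(convert_stub(dest))  # (True, dest) on the first call
print(convert_stub(dest))  # (False, dest) on every later call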
@@ -18,7 +18,7 @@ import os
 import re
 import shutil
 from logging import getLogger
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch
 from diffusers import (
@@ -1658,7 +1658,7 @@ def convert_diffusion_original(
     ctx: ConversionContext,
     model: ModelDict,
     source: str,
-):
+) -> Tuple[bool, str]:
     name = model["name"]
     source = source or model["source"]

@@ -1670,7 +1670,7 @@ def convert_diffusion_original(

     if os.path.exists(dest_path) and os.path.exists(dest_index):
         logger.info("ONNX pipeline already exists, skipping")
-        return
+        return (False, dest_path)

     torch_name = name + "-torch"
     torch_path = os.path.join(ctx.cache_path, torch_name)
@@ -1698,10 +1698,12 @@ def convert_diffusion_original(
     if "vae" in model:
         del model["vae"]

-    convert_diffusion_diffusers(ctx, model, working_name)
+    result = convert_diffusion_diffusers(ctx, model, working_name)

     if "torch" in ctx.prune:
         logger.info("removing intermediate Torch models: %s", torch_path)
         shutil.rmtree(torch_path)

     logger.info("ONNX pipeline saved to %s", name)
+
+    return result
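convert_diffusion_original is a wrapper around convert_diffusion_diffusers, so forwarding result is what makes the tuple unpack in convert_models safe; if the wrapper still returned None, the new converted, _dest = convert_diffusion_original(...) line would raise a TypeError. A minimal sketch of that forwarding, with stubbed bodies and shortened, hypothetical names:

from typing import Tuple

def convert_diffusers_stub(working_name: str) -> Tuple[bool, str]:
    # pretend the intermediate checkpoint converted cleanly
    return (True, working_name)

def convert_original_stub(name: str) -> Tuple[bool, str]:
    working_name = name + "-torch"
    result = convert_diffusers_stub(working_name)
    # pruning of intermediate Torch models would happen here
    return result  # forward the inner (converted, dest) pair unchanged

converted, _dest = convert_original_stub("stable-diffusion")
assert converted  # the caller's blending gate sees the inner result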