make blending happen once after conversion
This commit is contained in:
parent
32b2a76a0b
commit
c3979246df
|
@ -223,19 +223,22 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
ctx, name, model["source"], format=model_format
|
ctx, name, model["source"], format=model_format
|
||||||
)
|
)
|
||||||
|
|
||||||
|
converted = False
|
||||||
if model_format in model_formats_original:
|
if model_format in model_formats_original:
|
||||||
convert_diffusion_original(
|
converted, _dest = convert_diffusion_original(
|
||||||
ctx,
|
ctx,
|
||||||
model,
|
model,
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
convert_diffusion_diffusers(
|
converted, _dest = convert_diffusion_diffusers(
|
||||||
ctx,
|
ctx,
|
||||||
model,
|
model,
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# make sure blending only happens once, not every run
|
||||||
|
if converted:
|
||||||
# keep track of which models have been blended
|
# keep track of which models have been blended
|
||||||
blend_models = {}
|
blend_models = {}
|
||||||
|
|
||||||
|
@ -284,7 +287,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
lora_source = fetch_model(
|
lora_source = fetch_model(
|
||||||
ctx,
|
ctx,
|
||||||
f"{name}-lora-{lora_name}",
|
f"{name}-lora-{lora_name}",
|
||||||
lora_source
|
lora_source,
|
||||||
dest=lora_dest,
|
dest=lora_dest,
|
||||||
)
|
)
|
||||||
lora_weight = lora.get("weight", 1.0)
|
lora_weight = lora.get("weight", 1.0)
|
||||||
|
|
|
@ -88,7 +88,7 @@ def convert_diffusion_diffusers(
|
||||||
ctx: ConversionContext,
|
ctx: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
source: str,
|
||||||
):
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
|
@ -111,7 +111,7 @@ def convert_diffusion_diffusers(
|
||||||
|
|
||||||
if path.exists(dest_path) and path.exists(model_index):
|
if path.exists(dest_path) and path.exists(model_index):
|
||||||
logger.info("ONNX model already exists, skipping")
|
logger.info("ONNX model already exists, skipping")
|
||||||
return
|
return (False, dest_path)
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
source,
|
source,
|
||||||
|
@ -328,3 +328,5 @@ def convert_diffusion_diffusers(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("ONNX pipeline is loadable")
|
logger.info("ONNX pipeline is loadable")
|
||||||
|
|
||||||
|
return (True, dest_path)
|
||||||
|
|
|
@ -18,7 +18,7 @@ import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
|
@ -1658,7 +1658,7 @@ def convert_diffusion_original(
|
||||||
ctx: ConversionContext,
|
ctx: ConversionContext,
|
||||||
model: ModelDict,
|
model: ModelDict,
|
||||||
source: str,
|
source: str,
|
||||||
):
|
) -> Tuple[bool, str]:
|
||||||
name = model["name"]
|
name = model["name"]
|
||||||
source = source or model["source"]
|
source = source or model["source"]
|
||||||
|
|
||||||
|
@ -1670,7 +1670,7 @@ def convert_diffusion_original(
|
||||||
|
|
||||||
if os.path.exists(dest_path) and os.path.exists(dest_index):
|
if os.path.exists(dest_path) and os.path.exists(dest_index):
|
||||||
logger.info("ONNX pipeline already exists, skipping")
|
logger.info("ONNX pipeline already exists, skipping")
|
||||||
return
|
return (False, dest_path)
|
||||||
|
|
||||||
torch_name = name + "-torch"
|
torch_name = name + "-torch"
|
||||||
torch_path = os.path.join(ctx.cache_path, torch_name)
|
torch_path = os.path.join(ctx.cache_path, torch_name)
|
||||||
|
@ -1698,10 +1698,12 @@ def convert_diffusion_original(
|
||||||
if "vae" in model:
|
if "vae" in model:
|
||||||
del model["vae"]
|
del model["vae"]
|
||||||
|
|
||||||
convert_diffusion_diffusers(ctx, model, working_name)
|
result = convert_diffusion_diffusers(ctx, model, working_name)
|
||||||
|
|
||||||
if "torch" in ctx.prune:
|
if "torch" in ctx.prune:
|
||||||
logger.info("removing intermediate Torch models: %s", torch_path)
|
logger.info("removing intermediate Torch models: %s", torch_path)
|
||||||
shutil.rmtree(torch_path)
|
shutil.rmtree(torch_path)
|
||||||
|
|
||||||
logger.info("ONNX pipeline saved to %s", name)
|
logger.info("ONNX pipeline saved to %s", name)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue