make blending happen once after conversion
This commit is contained in:
parent
32b2a76a0b
commit
c3979246df
|
@ -223,19 +223,22 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
ctx, name, model["source"], format=model_format
|
||||
)
|
||||
|
||||
converted = False
|
||||
if model_format in model_formats_original:
|
||||
convert_diffusion_original(
|
||||
converted, _dest = convert_diffusion_original(
|
||||
ctx,
|
||||
model,
|
||||
source,
|
||||
)
|
||||
else:
|
||||
convert_diffusion_diffusers(
|
||||
converted, _dest = convert_diffusion_diffusers(
|
||||
ctx,
|
||||
model,
|
||||
source,
|
||||
)
|
||||
|
||||
# make sure blending only happens once, not every run
|
||||
if converted:
|
||||
# keep track of which models have been blended
|
||||
blend_models = {}
|
||||
|
||||
|
@ -284,7 +287,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
|||
lora_source = fetch_model(
|
||||
ctx,
|
||||
f"{name}-lora-{lora_name}",
|
||||
lora_source
|
||||
lora_source,
|
||||
dest=lora_dest,
|
||||
)
|
||||
lora_weight = lora.get("weight", 1.0)
|
||||
|
|
|
@ -88,7 +88,7 @@ def convert_diffusion_diffusers(
|
|||
ctx: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
):
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
|
@ -111,7 +111,7 @@ def convert_diffusion_diffusers(
|
|||
|
||||
if path.exists(dest_path) and path.exists(model_index):
|
||||
logger.info("ONNX model already exists, skipping")
|
||||
return
|
||||
return (False, dest_path)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
source,
|
||||
|
@ -328,3 +328,5 @@ def convert_diffusion_diffusers(
|
|||
)
|
||||
|
||||
logger.info("ONNX pipeline is loadable")
|
||||
|
||||
return (True, dest_path)
|
||||
|
|
|
@ -18,7 +18,7 @@ import os
|
|||
import re
|
||||
import shutil
|
||||
from logging import getLogger
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers import (
|
||||
|
@ -1658,7 +1658,7 @@ def convert_diffusion_original(
|
|||
ctx: ConversionContext,
|
||||
model: ModelDict,
|
||||
source: str,
|
||||
):
|
||||
) -> Tuple[bool, str]:
|
||||
name = model["name"]
|
||||
source = source or model["source"]
|
||||
|
||||
|
@ -1670,7 +1670,7 @@ def convert_diffusion_original(
|
|||
|
||||
if os.path.exists(dest_path) and os.path.exists(dest_index):
|
||||
logger.info("ONNX pipeline already exists, skipping")
|
||||
return
|
||||
return (False, dest_path)
|
||||
|
||||
torch_name = name + "-torch"
|
||||
torch_path = os.path.join(ctx.cache_path, torch_name)
|
||||
|
@ -1698,10 +1698,12 @@ def convert_diffusion_original(
|
|||
if "vae" in model:
|
||||
del model["vae"]
|
||||
|
||||
convert_diffusion_diffusers(ctx, model, working_name)
|
||||
result = convert_diffusion_diffusers(ctx, model, working_name)
|
||||
|
||||
if "torch" in ctx.prune:
|
||||
logger.info("removing intermediate Torch models: %s", torch_path)
|
||||
shutil.rmtree(torch_path)
|
||||
|
||||
logger.info("ONNX pipeline saved to %s", name)
|
||||
return result
|
||||
|
||||
|
|
Loading…
Reference in New Issue