1
0
Fork 0

make blending happen once after conversion

This commit is contained in:
Sean Sube 2023-03-18 07:14:22 -05:00
parent 32b2a76a0b
commit c3979246df
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 88 additions and 81 deletions

View File

@ -223,19 +223,22 @@ def convert_models(ctx: ConversionContext, args, models: Models):
ctx, name, model["source"], format=model_format
)
converted = False
if model_format in model_formats_original:
convert_diffusion_original(
converted, _dest = convert_diffusion_original(
ctx,
model,
source,
)
else:
convert_diffusion_diffusers(
converted, _dest = convert_diffusion_diffusers(
ctx,
model,
source,
)
# make sure blending only happens once, not every run
if converted:
# keep track of which models have been blended
blend_models = {}
@ -284,7 +287,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
lora_source = fetch_model(
ctx,
f"{name}-lora-{lora_name}",
lora_source
lora_source,
dest=lora_dest,
)
lora_weight = lora.get("weight", 1.0)

View File

@ -88,7 +88,7 @@ def convert_diffusion_diffusers(
ctx: ConversionContext,
model: Dict,
source: str,
):
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
@ -111,7 +111,7 @@ def convert_diffusion_diffusers(
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping")
return
return (False, dest_path)
pipeline = StableDiffusionPipeline.from_pretrained(
source,
@ -328,3 +328,5 @@ def convert_diffusion_diffusers(
)
logger.info("ONNX pipeline is loadable")
return (True, dest_path)

View File

@ -18,7 +18,7 @@ import os
import re
import shutil
from logging import getLogger
from typing import Dict, List
from typing import Dict, List, Tuple
import torch
from diffusers import (
@ -1658,7 +1658,7 @@ def convert_diffusion_original(
ctx: ConversionContext,
model: ModelDict,
source: str,
):
) -> Tuple[bool, str]:
name = model["name"]
source = source or model["source"]
@ -1670,7 +1670,7 @@ def convert_diffusion_original(
if os.path.exists(dest_path) and os.path.exists(dest_index):
logger.info("ONNX pipeline already exists, skipping")
return
return (False, dest_path)
torch_name = name + "-torch"
torch_path = os.path.join(ctx.cache_path, torch_name)
@ -1698,10 +1698,12 @@ def convert_diffusion_original(
if "vae" in model:
del model["vae"]
convert_diffusion_diffusers(ctx, model, working_name)
result = convert_diffusion_diffusers(ctx, model, working_name)
if "torch" in ctx.prune:
logger.info("removing intermediate Torch models: %s", torch_path)
shutil.rmtree(torch_path)
logger.info("ONNX pipeline saved to %s", name)
return result