
make blending happen once after conversion

Sean Sube 2023-03-18 07:14:22 -05:00
parent 32b2a76a0b
commit c3979246df
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 88 additions and 81 deletions


@@ -223,19 +223,22 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 ctx, name, model["source"], format=model_format
             )
+            converted = False
             if model_format in model_formats_original:
-                convert_diffusion_original(
+                converted, _dest = convert_diffusion_original(
                     ctx,
                     model,
                     source,
                 )
             else:
-                convert_diffusion_diffusers(
+                converted, _dest = convert_diffusion_diffusers(
                     ctx,
                     model,
                     source,
                 )
-            # keep track of which models have been blended
-            blend_models = {}
+            # make sure blending only happens once, not every run
+            if converted:
+                # keep track of which models have been blended
+                blend_models = {}
@@ -284,7 +287,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 lora_source = fetch_model(
                     ctx,
                     f"{name}-lora-{lora_name}",
-                    lora_source
+                    lora_source,
                     dest=lora_dest,
                 )
                 lora_weight = lora.get("weight", 1.0)

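The first hunk above is the behavioral change: convert_models now captures a converted flag from whichever converter ran and only enters the blending block when a new pipeline was actually produced. A minimal, self-contained sketch of that flow, with hypothetical convert_model and convert_all helpers standing in for the real converters and the surrounding loop:

# Sketch only: hypothetical names, not the real onnx-web functions.
from typing import Dict, List, Tuple

def convert_model(name: str, cache: Dict[str, str]) -> Tuple[bool, str]:
    # Skip models that already have a converted copy, the way the real
    # converters skip when the ONNX files already exist on disk.
    if name in cache:
        return (False, cache[name])
    dest = f"./models/{name}-onnx"
    cache[name] = dest
    return (True, dest)

def convert_all(names: List[str], cache: Dict[str, str]) -> None:
    for name in names:
        converted, dest = convert_model(name, cache)
        # make sure blending only happens once, not every run
        if converted:
            blend_models: Dict[str, str] = {}  # track which models have been blended
            print(f"blending extras into {dest}", blend_models)

cache: Dict[str, str] = {}
convert_all(["stable-diffusion"], cache)  # converts and blends
convert_all(["stable-diffusion"], cache)  # cached, so the blend step is not repeated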

@@ -88,7 +88,7 @@ def convert_diffusion_diffusers(
     ctx: ConversionContext,
     model: Dict,
     source: str,
-):
+) -> Tuple[bool, str]:
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
@@ -111,7 +111,7 @@ def convert_diffusion_diffusers(
     if path.exists(dest_path) and path.exists(model_index):
         logger.info("ONNX model already exists, skipping")
-        return
+        return (False, dest_path)
 
     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
@@ -328,3 +328,5 @@ def convert_diffusion_diffusers(
     )
 
     logger.info("ONNX pipeline is loadable")
+
+    return (True, dest_path)
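Both converters now share the same return contract: (False, dest_path) when an existing ONNX pipeline lets them skip the work, (True, dest_path) after a fresh export. A small sketch of that contract, with a hypothetical convert_stub in place of convert_diffusion_diffusers:

from os import path
from typing import Tuple

def convert_stub(dest_path: str, model_index: str) -> Tuple[bool, str]:
    # Hypothetical stand-in: report whether any conversion work happened.
    if path.exists(dest_path) and path.exists(model_index):
        # nothing to do, but the caller still learns where the model lives
        return (False, dest_path)
    # ... the real function exports the pipeline to ONNX here ...
    return (True, dest_path)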


@@ -18,7 +18,7 @@ import os
 import re
 import shutil
 from logging import getLogger
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import torch
 from diffusers import (
@@ -1658,7 +1658,7 @@ def convert_diffusion_original(
     ctx: ConversionContext,
     model: ModelDict,
     source: str,
-):
+) -> Tuple[bool, str]:
     name = model["name"]
     source = source or model["source"]
@@ -1670,7 +1670,7 @@ def convert_diffusion_original(
     if os.path.exists(dest_path) and os.path.exists(dest_index):
         logger.info("ONNX pipeline already exists, skipping")
-        return
+        return (False, dest_path)
 
     torch_name = name + "-torch"
     torch_path = os.path.join(ctx.cache_path, torch_name)
@@ -1698,10 +1698,12 @@ def convert_diffusion_original(
     if "vae" in model:
         del model["vae"]
 
-    convert_diffusion_diffusers(ctx, model, working_name)
+    result = convert_diffusion_diffusers(ctx, model, working_name)
 
     if "torch" in ctx.prune:
         logger.info("removing intermediate Torch models: %s", torch_path)
         shutil.rmtree(torch_path)
 
     logger.info("ONNX pipeline saved to %s", name)
+
+    return result
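convert_diffusion_original keeps the same contract by delegating the ONNX export to convert_diffusion_diffusers and returning its result unchanged, so the caller cannot tell which conversion path ran. A hypothetical sketch of that pass-through, with simplified stand-in functions:

from typing import Tuple

def export_onnx(source: str) -> Tuple[bool, str]:
    # Hypothetical stand-in for convert_diffusion_diffusers.
    return (True, source.replace("-torch", "-onnx"))

def convert_original(name: str) -> Tuple[bool, str]:
    # Hypothetical mirror of convert_diffusion_original's tail: extract a
    # temporary Torch pipeline, delegate the export, pass the result up.
    working_name = f"./cache/{name}-torch"
    result = export_onnx(working_name)
    # ... prune the intermediate Torch files here when requested ...
    return result

print(convert_original("stable-diffusion"))  # (True, './cache/stable-diffusion-onnx')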