
feat(api): support for SDXL LoRAs

Sean Sube 2023-08-25 23:33:41 -05:00
parent 0ad250251e
commit bbacfd1ca0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 196 additions and 26 deletions

View File

@@ -57,16 +57,12 @@ class BlendImg2ImgStage(BaseStage):
         )

         pipe_params = {}
-        if params.is_control():
-            pipe_params["controlnet_conditioning_scale"] = strength
-        elif params.is_lpw():
-            pipe_params["strength"] = strength
-        elif params.is_panorama():
-            pipe_params["strength"] = strength
-        elif pipe_type == "img2img" or pipe_type == "img2img-sdxl":
-            pipe_params["strength"] = strength
-        elif pipe_type == "pix2pix":
+        if params.is_pix2pix():
             pipe_params["image_guidance_scale"] = strength
+        elif params.is_control():
+            pipe_params["controlnet_conditioning_scale"] = strength
+        else:
+            pipe_params["strength"] = strength

         outputs = []
         for source in sources:

View File

@@ -1,17 +1,19 @@
 from argparse import ArgumentParser
 from logging import getLogger
 from os import path
-from typing import Dict, List, Literal, Tuple, Union
+from re import sub
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 import numpy as np
 import torch
-from onnx import ModelProto, load, numpy_helper
+from onnx import ModelProto, load, numpy_helper, save_model
 from onnx.checker import check_model
 from onnx.external_data_helper import (
     convert_model_to_external_data,
+    set_external_data,
     write_external_data_tensors,
 )
 from onnx.helper import tensor_dtype_to_np_dtype
 from onnxruntime import InferenceSession, OrtValue, SessionOptions

 from ...server.context import ServerContext
@@ -77,6 +79,67 @@ def fix_node_name(key: str):
     return fixed_name


+def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
+    fixed = {}
+
+    for key, value in keys.items():
+        root, *rest = key.split(".")
+        logger.debug("fixing XL node name: %s -> %s", key, root)  # TODO: move to trace
+
+        if root.startswith("input"):
+            block = "down_blocks"
+        elif root.startswith("middle"):
+            block = "mid_block"  # not plural
+        elif root.startswith("output"):
+            block = "up_blocks"
+        elif root.startswith("text_model"):
+            block = "text_model"
+        else:
+            logger.warning("unknown XL key name: %s", key)
+            fixed[key] = value
+            continue
+
+        suffix = None
+        for s in ["fc1", "fc2", "ff_net_0_proj", "ff_net_2", "proj", "proj_in", "proj_out", "to_k", "to_out_0", "to_q", "to_v"]:
+            if root.endswith(s):
+                suffix = s
+
+        if suffix is None:
+            logger.warning("new XL key type: %s", root)
+            continue
+
+        logger.debug("searching for XL node: /%s/*/%s", block, suffix)
+        if block == "text_model":
+            matches = [
+                node for node in nodes
+                if fix_node_name(node.name) == f"{root}_MatMul"
+            ]
+        else:
+            matches = [
+                node for node in nodes
+                if node.name.startswith(f"/{block}")
+                and fix_node_name(node.name).endswith(f"{suffix}_MatMul")  # needs to be fixed because some places use to_out.0
+            ]
+
+        if len(matches) == 0:
+            logger.warning("no matches for XL key: %s", root)
+            continue
+
+        name: str = matches[0].name
+        name = fix_node_name(name.rstrip("/MatMul"))
+        if name.endswith("proj_o"):
+            # rstrip("/MatMul") strips a character set rather than a suffix,
+            # so it also eats the trailing "ut" of "proj_out"; restore it
+            name = f"{name}ut"
logger.debug("matching XL key with node: %s -> %s", key, matches[0].name)
fixed[name] = value
nodes.remove(matches[0])
return fixed
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
return (
min(x, shape[2] - 1),
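
fix_xl_names bridges two naming schemes: SDXL LoRA keys use Torch-style module paths (input_blocks, middle_block, output_blocks), while the exported ONNX graph names its MatMul nodes by graph position. A minimal sketch of applying the remap, with a hypothetical model path and an invented LoRA-style key for illustration:

    from onnx import load
    import numpy as np

    base_model = load("unet/model.onnx")  # hypothetical path
    nodes = list(base_model.graph.node)

    # hypothetical LoRA-style key; real keys come from the blending loop in blend_loras
    blended = {"input_blocks_4_1_proj_in": np.zeros((320, 320), dtype=np.float32)}

    # keys are rewritten to the matching node's fixed name, e.g. something
    # like "down_blocks_1_attentions_0_proj_in" (exact names depend on the export)
    fixed = fix_xl_names(blended, nodes)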
@@ -84,11 +147,13 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
     )


 def blend_loras(
     _conversion: ServerContext,
     base_name: Union[str, ModelProto],
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],
+    model_index: Optional[int] = None,
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
@@ -98,7 +163,10 @@ def blend_loras(
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]

     if model_type == "text_encoder":
-        lora_prefix = "lora_te_"
+        if model_index is None:
+            lora_prefix = "lora_te_"
+        else:
+            lora_prefix = f"lora_te{model_index}_"
     else:
         lora_prefix = f"lora_{model_type}_"
@@ -313,11 +381,15 @@ def blend_loras(
         else:
             blended[base_key] = np_weights

+    # rewrite node names for XL
+    nodes = list(base_model.graph.node)
+    blended = fix_xl_names(blended, nodes)
+
     logger.trace(
-        "updating %s of %s initializers: %s",
+        "updating %s of %s initializers, %s missed",
         len(blended.keys()),
         len(base_model.graph.initializer),
-        list(blended.keys()),
+        len(nodes)
     )

     fixed_initializer_names = [
@@ -419,8 +491,13 @@ def blend_loras(
                     onnx_weights.shape,
                 )

-                blended = onnx_weights + weights.transpose()
-                logger.trace("blended weight shape: %s", blended.shape)
+                t_weights = weights.transpose()
+                if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
+                    logger.warning("weight shapes do not match for %s: %s vs %s", matmul_key, weights.shape, onnx_weights.shape)
+                    t_weights = interp_to_match(weights, onnx_weights).transpose()
+
+                blended = onnx_weights + t_weights
+                logger.debug("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)

                 # replace the original initializer
                 updated_node = numpy_helper.from_array(
@@ -442,9 +519,29 @@ def blend_loras(
     if len(unmatched_keys) > 0:
         logger.warning("could not find nodes for some keys: %s", unmatched_keys)

+    # if model_type == "unet":
+    #     save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
+
     return base_model


+from scipy import interpolate
+
+
+def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
+    res_x = np.linspace(0, 1, resize.shape[0])
+    res_y = np.linspace(0, 1, resize.shape[1])
+
+    ref_x = np.linspace(0, 1, ref.shape[0])
+    ref_y = np.linspace(0, 1, ref.shape[1])
+    logger.debug("dims: %s, %s, %s, %s", resize.shape[0], resize.shape[1], ref.shape[0], ref.shape[1])
+
+    f = interpolate.RegularGridInterpolator((ref_x, ref_y), ref, method='linear')
+    xg, yg = np.meshgrid(res_x, res_y)
+    output = f((xg, yg))
+    logger.debug("weights after interpolation: %s", output.shape)
+    return output
+
+
 if __name__ == "__main__":
     context = ConversionContext.from_environ()
     parser = ArgumentParser()
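
interp_to_match is the fallback for LoRA deltas whose shape does not match the base initializer even after transposing. A minimal usage sketch with made-up shapes; note that the function samples through np.meshgrid's default "xy" indexing, so its result comes back transposed relative to the target, which is why the caller above transposes it once more:

    import numpy as np

    lora_delta = np.random.rand(768, 320).astype(np.float32)  # hypothetical LoRA weights
    onnx_weights = np.zeros((1280, 960), dtype=np.float32)    # hypothetical base weights

    resized = interp_to_match(lora_delta, onnx_weights).transpose()
    assert resized.shape == onnx_weights.shape  # (1280, 960)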

View File

@@ -45,6 +45,7 @@ from .version_safe_diffusers import (
     StableDiffusionPipeline,
     UniPCMultistepScheduler,
 )
+from ..torch_before_ort import InferenceSession

 logger = getLogger(__name__)
@@ -95,6 +96,7 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
     return None


+from optimum.onnxruntime.modeling_diffusion import ORTModelUnet, ORTModelTextEncoder

 def load_pipeline(
     server: ServerContext,
@@ -244,6 +246,7 @@ def load_pipeline(
                 text_encoder,
                 list(zip(lora_models, lora_weights)),
                 "text_encoder",
+                1 if params.is_xl() else None,
             )
             (text_encoder, text_encoder_data) = buffer_external_data_tensors(
                 text_encoder
@@ -253,6 +256,16 @@ def load_pipeline(
             text_encoder_opts.add_external_initializers(
                 list(text_encoder_names), list(text_encoder_values)
             )
+
+            if params.is_xl():
+                text_encoder_session = InferenceSession(
+                    text_encoder.SerializeToString(),
+                    providers=[device.ort_provider("text-encoder")],
+                    sess_options=text_encoder_opts,
+                )
+                text_encoder_session._model_path = path.join(model, "text_encoder")
+                components["text_encoder"] = ORTModelTextEncoder(text_encoder_session, text_encoder)
+            else:
                 components["text_encoder"] = OnnxRuntimeModel(
                     OnnxRuntimeModel.load_model(
                         text_encoder.SerializeToString(),
@@ -261,6 +274,32 @@ def load_pipeline(
                     )
                 )

+            if params.is_xl():
+                text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
+                text_encoder2 = blend_loras(
+                    server,
+                    text_encoder2,
+                    list(zip(lora_models, lora_weights)),
+                    "text_encoder",
+                    2,
+                )
+                (text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
+                    text_encoder2
+                )
+                text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
+                text_encoder2_opts = device.sess_options(cache=False)
+                text_encoder2_opts.add_external_initializers(
+                    list(text_encoder2_names), list(text_encoder2_values)
+                )
+
+                text_encoder2_session = InferenceSession(
+                    text_encoder2.SerializeToString(),
+                    providers=[device.ort_provider("text-encoder")],
+                    sess_options=text_encoder2_opts,
+                )
+                text_encoder2_session._model_path = path.join(model, "text_encoder_2")
+                components["text_encoder_2"] = ORTModelTextEncoder(text_encoder2_session, text_encoder2)
+
             # blend and load unet
             unet = path.join(model, unet_type, ONNX_MODEL)
             blended_unet = blend_loras(
@@ -273,6 +312,16 @@ def load_pipeline(
             unet_names, unet_values = zip(*unet_data)
             unet_opts = device.sess_options(cache=False)
             unet_opts.add_external_initializers(list(unet_names), list(unet_values))
+
+            if params.is_xl():
+                unet_session = InferenceSession(
+                    unet_model.SerializeToString(),
+                    providers=[device.ort_provider("unet")],
+                    sess_options=unet_opts,
+                )
+                unet_session._model_path = path.join(model, "unet")
+                components["unet"] = ORTModelUnet(unet_session, unet_model)
+            else:
                 components["unet"] = OnnxRuntimeModel(
                     OnnxRuntimeModel.load_model(
                         unet_model.SerializeToString(),
@@ -344,6 +393,15 @@ def load_pipeline(
         **components,
     )

+    # make sure XL models are actually being used
+    # TODO: why is this needed?
+    logger.info("text encoder matches: %s, %s", pipe.text_encoder == components["text_encoder"], type(pipe.text_encoder))
+    pipe.text_encoder = components["text_encoder"]
+    logger.info("text encoder 2 matches: %s, %s", pipe.text_encoder_2 == components["text_encoder_2"], type(pipe.text_encoder_2))
+    pipe.text_encoder_2 = components["text_encoder_2"]
+    logger.info("unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet))
+    pipe.unet = components["unet"]
+
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)

View File

@@ -48,6 +48,12 @@ def run_txt2img_pipeline(
     else:
         tile_size = params.tiles

+    # split prompts for each stage
+    if "||" in params.prompt:
+        txt_prompt, img_prompt = params.prompt.split("||")
+    else:
+        txt_prompt = img_prompt = params.prompt
+
     # prepare the chain pipeline and first stage
     chain = ChainPipeline()
     chain.stage(
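
A short illustration of the new "||" delimiter, assuming a two-part prompt; split does not trim whitespace around the delimiter, so the slices keep their surrounding spaces:

    prompt = "a painting of a fox || sharp focus, detailed"
    if "||" in prompt:
        txt_prompt, img_prompt = prompt.split("||")
    else:
        txt_prompt = img_prompt = prompt

    print(repr(txt_prompt))  # 'a painting of a fox '
    print(repr(img_prompt))  # ' sharp focus, detailed'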

View File

@@ -436,3 +436,11 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
         return np.reshape(flat_tile[indices], tile.shape)
     else:
         return tile
+
+
+def slice_prompt(prompt: str, slice: int) -> str:
+    if "||" in prompt:
+        parts = prompt.split("||")
+        return parts[min(slice, len(parts) - 1)]
+    else:
+        return prompt
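
slice_prompt generalizes the same delimiter to any number of stages, clamping the index so stages past the last "||" reuse the final slice. A hypothetical usage:

    prompt = "base scene || first highres pass || second highres pass"

    slice_prompt(prompt, 0)          # 'base scene '
    slice_prompt(prompt, 2)          # ' second highres pass'
    slice_prompt(prompt, 9)          # ' second highres pass' (clamped)
    slice_prompt("no delimiter", 3)  # 'no delimiter'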

View File

@@ -4,7 +4,7 @@ from logging import getLogger
 from os import path
 from struct import pack
 from time import time
-from typing import Any, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from piexif import ExifIFD, ImageIFD, dump
 from piexif.helper import UserComment
@@ -57,8 +57,10 @@ def json_params(
     upscale: Optional[UpscaleParams] = None,
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
+    parent: Dict = None,
 ) -> Any:
     json = {
+        "input_size": size.tojson(),
         "outputs": outputs,
         "params": params.tojson(),
     }
@@ -66,6 +68,7 @@ def json_params(
         json["params"]["model"] = path.basename(params.model)
         json["params"]["scheduler"] = params.scheduler

+    # calculate final output size
     output_size = size
     if border is not None:
         json["border"] = border.tojson()
@@ -79,7 +82,6 @@ def json_params(
         json["upscale"] = upscale.tojson()
         output_size = upscale.resize(output_size)

-    json["input_size"] = size.tojson()
     json["size"] = output_size.tojson()

     return json
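
With input_size recorded up front and size computed after the border and upscale adjustments, the saved metadata now distinguishes the requested size from the final output. A hypothetical result for a 512x512 request upscaled 2x (the exact field layout depends on Size.tojson and is assumed here):

    {
        "input_size": {"width": 512, "height": 512},
        "outputs": ["output-1.png"],
        "params": {"model": "stable-diffusion", "scheduler": "ddim"},
        "size": {"width": 1024, "height": 1024},
    }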

View File

@@ -282,6 +282,9 @@ class ImageParams:
     def is_panorama(self):
         return self.pipeline == "panorama"

+    def is_pix2pix(self):
+        return self.pipeline == "pix2pix"
+
     def is_xl(self):
         return self.pipeline.endswith("-sdxl")
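
A small sketch of the new predicates, assuming an ImageParams instance named params; any pipeline name ending in "-sdxl" counts as XL:

    params.pipeline = "pix2pix"
    assert params.is_pix2pix() and not params.is_xl()

    params.pipeline = "txt2img-sdxl"
    assert params.is_xl()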