
feat(api): support for SDXL LoRAs

Sean Sube 2023-08-25 23:33:41 -05:00
parent 0ad250251e
commit bbacfd1ca0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 196 additions and 26 deletions

@@ -57,16 +57,12 @@ class BlendImg2ImgStage(BaseStage):
         )

         pipe_params = {}
-        if params.is_control():
-            pipe_params["controlnet_conditioning_scale"] = strength
-        elif params.is_lpw():
-            pipe_params["strength"] = strength
-        elif params.is_panorama():
-            pipe_params["strength"] = strength
-        elif pipe_type == "img2img" or pipe_type == "img2img-sdxl":
-            pipe_params["strength"] = strength
-        elif pipe_type == "pix2pix":
+        if params.is_pix2pix():
             pipe_params["image_guidance_scale"] = strength
+        elif params.is_control():
+            pipe_params["controlnet_conditioning_scale"] = strength
+        else:
+            pipe_params["strength"] = strength

         outputs = []
         for source in sources:
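
The reordered branch collapses the strength dispatch to three cases: pix2pix feeds strength into image_guidance_scale, ControlNet feeds it into controlnet_conditioning_scale, and every other pipeline (img2img, LPW, panorama, and the new SDXL variants) passes it through as plain strength. A minimal sketch of the new logic, with pipeline-name strings standing in for the real ImageParams predicates:

from typing import Dict

def strength_params(pipeline: str, strength: float) -> Dict[str, float]:
    # condensed view of the reordered branch; the string checks stand in
    # for params.is_pix2pix() and params.is_control() in onnx-web
    if pipeline == "pix2pix":
        return {"image_guidance_scale": strength}
    elif pipeline == "controlnet":
        return {"controlnet_conditioning_scale": strength}
    else:
        # img2img, lpw, panorama, and the sdxl variants all collapse here now
        return {"strength": strength}

assert strength_params("pix2pix", 0.75) == {"image_guidance_scale": 0.75}
assert strength_params("img2img-sdxl", 0.75) == {"strength": 0.75}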

@@ -1,17 +1,19 @@
 from argparse import ArgumentParser
 from logging import getLogger
 from os import path
-from typing import Dict, List, Literal, Tuple, Union
+from re import sub
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 import numpy as np
 import torch
-from onnx import ModelProto, load, numpy_helper
+from onnx import ModelProto, load, numpy_helper, save_model
 from onnx.checker import check_model
 from onnx.external_data_helper import (
     convert_model_to_external_data,
     set_external_data,
     write_external_data_tensors,
 )
+from onnx.helper import tensor_dtype_to_np_dtype
 from onnxruntime import InferenceSession, OrtValue, SessionOptions

 from ...server.context import ServerContext
@@ -77,6 +79,67 @@ def fix_node_name(key: str):
     return fixed_name

+def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
+    fixed = {}
+
+    for key, value in keys.items():
+        root, *rest = key.split(".")
+        logger.debug("fixing XL node name: %s -> %s", key, root)  # TODO: move to trace
+
+        if root.startswith("input"):
+            block = "down_blocks"
+        elif root.startswith("middle"):
+            block = "mid_block"  # not plural
+        elif root.startswith("output"):
+            block = "up_blocks"
+        elif root.startswith("text_model"):
+            block = "text_model"
+        else:
+            logger.warning("unknown XL key name: %s", key)
+            fixed[key] = value
+            continue
+
+        suffix = None
+        for s in [
+            "fc1",
+            "fc2",
+            "ff_net_0_proj",
+            "ff_net_2",
+            "proj",
+            "proj_in",
+            "proj_out",
+            "to_k",
+            "to_out_0",
+            "to_q",
+            "to_v",
+        ]:
+            if root.endswith(s):
+                suffix = s
+
+        if suffix is None:
+            logger.warning("new XL key type: %s", root)
+            continue
+
+        logger.debug("searching for XL node: /%s/*/%s", block, suffix)
+        if block == "text_model":
+            matches = [
+                node for node in nodes if fix_node_name(node.name) == f"{root}_MatMul"
+            ]
+        else:
+            matches = [
+                node
+                for node in nodes
+                if node.name.startswith(f"/{block}")
+                # match on the fixed name, because some graphs use to_out.0
+                and fix_node_name(node.name).endswith(f"{suffix}_MatMul")
+            ]
+
+        if len(matches) == 0:
+            logger.warning("no matches for XL key: %s", root)
+            continue
+
+        name: str = matches[0].name
+        # rstrip removes a trailing character set, not the literal suffix
+        # "/MatMul", so it can also eat the "ut" of a trailing "proj_out"
+        name = fix_node_name(name.rstrip("/MatMul"))
+
+        if name.endswith("proj_o"):
+            # restore the "ut" that rstrip stripped from "proj_out"
+            name = f"{name}ut"
+
+        logger.debug("matching XL key with node: %s -> %s", key, matches[0].name)
+
+        fixed[name] = value
+        nodes.remove(matches[0])
+
+    return fixed
+
 def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
     return (
         min(x, shape[2] - 1),
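
To see what fix_xl_names accomplishes, here is a self-contained toy reduced to the down_blocks/to_q case: a kohya-style LoRA key is matched against ONNX MatMul node paths and re-keyed to the cleaned node name. The node name and the cleaning helper are illustrative stand-ins, not taken from a real SDXL graph:

from types import SimpleNamespace
from typing import Any, Dict, List

def clean_name(name: str) -> str:
    # illustrative stand-in for fix_node_name: ONNX path -> underscore form
    return name.replace("/", "_").replace(".", "_").lstrip("_")

def toy_fix_xl_names(keys: Dict[str, Any], nodes: List[Any]) -> Dict[str, Any]:
    # same re-keying idea as fix_xl_names, reduced to the down_blocks/to_q case
    fixed = {}
    for key, value in keys.items():
        matches = [
            n for n in nodes
            if n.name.startswith("/down_blocks")
            and clean_name(n.name).endswith("to_q_MatMul")
        ]
        if matches:
            # drop the trailing "/MatMul" and clean the rest into the new key
            fixed[clean_name(matches[0].name.rsplit("/", 1)[0])] = value
    return fixed

nodes = [SimpleNamespace(name="/down_blocks.0/attentions.0/to_q/MatMul")]
print(toy_fix_xl_names({"input_blocks_0_attentions_0_to_q": "delta"}, nodes))
# {'down_blocks_0_attentions_0_to_q': 'delta'}

The toy sidesteps the rstrip("/MatMul") character-set pitfall noted in the comments above by splitting on the last "/" instead.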
@@ -84,11 +147,13 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
     )

 def blend_loras(
     _conversion: ServerContext,
     base_name: Union[str, ModelProto],
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],
+    model_index: Optional[int] = None,
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
@@ -98,7 +163,10 @@ def blend_loras(
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]

     if model_type == "text_encoder":
-        lora_prefix = "lora_te_"
+        if model_index is None:
+            lora_prefix = "lora_te_"
+        else:
+            lora_prefix = f"lora_te{model_index}_"
     else:
         lora_prefix = f"lora_{model_type}_"
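
With model_index threaded through, the two SDXL text encoders select the kohya-style lora_te1_/lora_te2_ key groups, while single-encoder models keep the old prefix. A quick sketch of the resulting prefixes:

from typing import Optional

def lora_prefix(model_type: str, model_index: Optional[int] = None) -> str:
    # mirrors the branch above
    if model_type == "text_encoder":
        return "lora_te_" if model_index is None else f"lora_te{model_index}_"
    return f"lora_{model_type}_"

assert lora_prefix("text_encoder") == "lora_te_"      # SD 1.x/2.x, single encoder
assert lora_prefix("text_encoder", 1) == "lora_te1_"  # SDXL, first encoder
assert lora_prefix("text_encoder", 2) == "lora_te2_"  # SDXL, second encoder
assert lora_prefix("unet") == "lora_unet_"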
@@ -313,11 +381,15 @@ def blend_loras(
         else:
             blended[base_key] = np_weights

+    # rewrite node names for XL
+    nodes = list(base_model.graph.node)
+    blended = fix_xl_names(blended, nodes)
+
     logger.trace(
-        "updating %s of %s initializers: %s",
+        "updating %s of %s initializers, %s missed",
         len(blended.keys()),
         len(base_model.graph.initializer),
-        list(blended.keys()),
+        len(nodes),
     )

     fixed_initializer_names = [
@@ -419,8 +491,13 @@ def blend_loras(
                     onnx_weights.shape,
                 )

-                blended = onnx_weights + weights.transpose()
-                logger.trace("blended weight shape: %s", blended.shape)
+                t_weights = weights.transpose()
+                if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
+                    logger.warning("weight shapes do not match for %s: %s vs %s", matmul_key, weights.shape, onnx_weights.shape)
+                    t_weights = interp_to_match(weights, onnx_weights).transpose()
+
+                blended = onnx_weights + t_weights
+                logger.debug("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)

                 # replace the original initializer
                 updated_node = numpy_helper.from_array(
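
The new guard only falls back to interpolation when neither the raw nor the transposed LoRA weights line up with the ONNX initializer; matching shapes take the existing path unchanged. A small numpy sketch of the decision:

import numpy as np

def needs_interp(weights: np.ndarray, onnx_weights: np.ndarray) -> bool:
    # true only when neither orientation fits the initializer
    return (
        weights.shape != onnx_weights.shape
        and weights.transpose().shape != onnx_weights.shape
    )

assert not needs_interp(np.zeros((320, 640)), np.zeros((640, 320)))  # transpose fits
assert needs_interp(np.zeros((768, 320)), np.zeros((1280, 320)))     # nothing fits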
@@ -442,9 +519,29 @@ def blend_loras(
     if len(unmatched_keys) > 0:
         logger.warning("could not find nodes for some keys: %s", unmatched_keys)

+    # if model_type == "unet":
+    #     save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
+
     return base_model

+
+from scipy import interpolate
+
+
+def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
+    res_x = np.linspace(0, 1, resize.shape[0])
+    res_y = np.linspace(0, 1, resize.shape[1])
+    ref_x = np.linspace(0, 1, ref.shape[0])
+    ref_y = np.linspace(0, 1, ref.shape[1])
+    logger.debug("dims: %s, %s, %s, %s", resize.shape[0], resize.shape[1], ref.shape[0], ref.shape[1])
+
+    f = interpolate.RegularGridInterpolator((ref_x, ref_y), ref, method="linear")
+    xg, yg = np.meshgrid(res_x, res_y)
+    output = f((xg, yg))
+    logger.debug("weights after interpolation: %s", output.shape)
+    return output
+

 if __name__ == "__main__":
     context = ConversionContext.from_environ()
     parser = ArgumentParser()
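interp_to_match treats the reference weights as samples on the unit square and re-samples them at the target initializer's resolution, so mismatched LoRA weights are stretched rather than dropped; bilinear resampling is lossy, making this a best-effort fallback rather than an exact conversion. A usage sketch, assuming interp_to_match above is in scope:

import numpy as np

ref = np.arange(24, dtype=np.float32).reshape(4, 6)  # hypothetical LoRA delta
target = np.zeros((8, 12), dtype=np.float32)         # initializer of a larger shape

out = interp_to_match(ref, target)
print(out.shape)  # (12, 8): np.meshgrid's default "xy" indexing swaps the axes;
                  # the .transpose() at the call site swaps them back to (8, 12)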
@@ -45,6 +45,7 @@ from .version_safe_diffusers import (
     StableDiffusionPipeline,
     UniPCMultistepScheduler,
 )
+from ..torch_before_ort import InferenceSession

 logger = getLogger(__name__)
@@ -95,6 +96,7 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
     return None

+from optimum.onnxruntime.modeling_diffusion import ORTModelUnet, ORTModelTextEncoder

 def load_pipeline(
     server: ServerContext,
@@ -244,6 +246,7 @@ def load_pipeline(
                 text_encoder,
                 list(zip(lora_models, lora_weights)),
                 "text_encoder",
+                1 if params.is_xl() else None,
             )
             (text_encoder, text_encoder_data) = buffer_external_data_tensors(
                 text_encoder
@@ -253,6 +256,16 @@ def load_pipeline(
             text_encoder_opts.add_external_initializers(
                 list(text_encoder_names), list(text_encoder_values)
             )
+
+            if params.is_xl():
+                text_encoder_session = InferenceSession(
+                    text_encoder.SerializeToString(),
+                    providers=[device.ort_provider("text-encoder")],
+                    sess_options=text_encoder_opts,
+                )
+                text_encoder_session._model_path = path.join(model, "text_encoder")
+                components["text_encoder"] = ORTModelTextEncoder(text_encoder_session, text_encoder)
+            else:
                 components["text_encoder"] = OnnxRuntimeModel(
                     OnnxRuntimeModel.load_model(
                         text_encoder.SerializeToString(),
if params.is_xl():
text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
text_encoder2 = blend_loras(
server,
text_encoder2,
list(zip(lora_models, lora_weights)),
"text_encoder",
2,
)
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
text_encoder2
)
text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
text_encoder2_opts = device.sess_options(cache=False)
text_encoder2_opts.add_external_initializers(
list(text_encoder2_names), list(text_encoder2_values)
)
text_encoder2_session = InferenceSession(
text_encoder2.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder2_opts,
)
text_encoder2_session._model_path = path.join(model, "text_encoder_2")
components["text_encoder_2"] = ORTModelTextEncoder(text_encoder2_session, text_encoder2)
# blend and load unet # blend and load unet
unet = path.join(model, unet_type, ONNX_MODEL) unet = path.join(model, unet_type, ONNX_MODEL)
blended_unet = blend_loras( blended_unet = blend_loras(
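
Both SDXL text encoders are blended from the same LoRA list but with different indices, so the lora_te1_* and lora_te2_* key groups of a single LoRA file each land on the right encoder. Schematically, with paths abbreviated from the calls above:

# encoder 1: lora_te1_* keys on SDXL, plain lora_te_* on SD 1.x/2.x
text_encoder = blend_loras(
    server, te1_path, list(zip(lora_models, lora_weights)),
    "text_encoder", 1 if params.is_xl() else None,
)
# encoder 2: lora_te2_* keys; only present for SDXL models
text_encoder2 = blend_loras(
    server, te2_path, list(zip(lora_models, lora_weights)),
    "text_encoder", 2,
)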
@@ -273,6 +312,16 @@ def load_pipeline(
         unet_names, unet_values = zip(*unet_data)
         unet_opts = device.sess_options(cache=False)
         unet_opts.add_external_initializers(list(unet_names), list(unet_values))
+
+        if params.is_xl():
+            unet_session = InferenceSession(
+                unet_model.SerializeToString(),
+                providers=[device.ort_provider("unet")],
+                sess_options=unet_opts,
+            )
+            unet_session._model_path = path.join(model, "unet")
+            components["unet"] = ORTModelUnet(unet_session, unet_model)
+        else:
         components["unet"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 unet_model.SerializeToString(),
@@ -344,6 +393,15 @@ def load_pipeline(
         **components,
     )

+    # make sure XL models are actually being used
+    # TODO: why is this needed?
+    logger.info("text encoder matches: %s, %s", pipe.text_encoder == components["text_encoder"], type(pipe.text_encoder))
+    pipe.text_encoder = components["text_encoder"]
+    logger.info("text encoder 2 matches: %s, %s", pipe.text_encoder_2 == components["text_encoder_2"], type(pipe.text_encoder_2))
+    pipe.text_encoder_2 = components["text_encoder_2"]
+    logger.info("unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet))
+    pipe.unet = components["unet"]
+
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)

@@ -48,6 +48,12 @@ def run_txt2img_pipeline(
     else:
         tile_size = params.tiles

+    # split prompts for each stage
+    if "||" in params.prompt:
+        txt_prompt, img_prompt = params.prompt.split("||")
+    else:
+        txt_prompt = img_prompt = params.prompt
+
     # prepare the chain pipeline and first stage
     chain = ChainPipeline()
     chain.stage(
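
The "||" delimiter lets one request carry a separate prompt for the txt2img stage and the follow-on img2img stage; without it, both stages share the same prompt. For example:

prompt = "a stone castle, detailed || a stone castle, oil painting"
txt_prompt, img_prompt = prompt.split("||")
# txt_prompt == "a stone castle, detailed "  (whitespace around "||" is kept)
# img_prompt == " a stone castle, oil painting"

Note that the two-name unpack assumes at most one delimiter; a prompt with two or more "||" would raise a ValueError here, while the slice_prompt helper in the next file clamps the index instead.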

@@ -436,3 +436,11 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
         return np.reshape(flat_tile[indices], tile.shape)
     else:
         return tile
+
+
+def slice_prompt(prompt: str, slice: int) -> str:
+    if "||" in prompt:
+        parts = prompt.split("||")
+        return parts[min(slice, len(parts) - 1)]
+    else:
+        return prompt
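
slice_prompt generalizes the same delimiter to any number of stages, clamping the index so later stages reuse the last sub-prompt:

prompt = "stage one || stage two"
assert slice_prompt(prompt, 0) == "stage one "
assert slice_prompt(prompt, 1) == " stage two"
assert slice_prompt(prompt, 5) == " stage two"   # index clamps to the last part
assert slice_prompt("no delimiter", 3) == "no delimiter"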

@@ -4,7 +4,7 @@ from logging import getLogger
 from os import path
 from struct import pack
 from time import time
-from typing import Any, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from piexif import ExifIFD, ImageIFD, dump
 from piexif.helper import UserComment
@@ -57,8 +57,10 @@ def json_params(
     upscale: Optional[UpscaleParams] = None,
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
+    parent: Dict = None,
 ) -> Any:
     json = {
+        "input_size": size.tojson(),
         "outputs": outputs,
         "params": params.tojson(),
     }
@@ -66,6 +68,7 @@ def json_params(
     json["params"]["model"] = path.basename(params.model)
     json["params"]["scheduler"] = params.scheduler

+    # calculate final output size
     output_size = size
     if border is not None:
         json["border"] = border.tojson()
@@ -79,7 +82,6 @@ def json_params(
         json["upscale"] = upscale.tojson()
         output_size = upscale.resize(output_size)

-    json["input_size"] = size.tojson()
     json["size"] = output_size.tojson()

     return json
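
Moving input_size to the top of the dict records the requested size alongside the final size; behavior is unchanged, since border and upscale adjustments are applied to the separate output_size variable. For a hypothetical 512x512 request with a 4x upscale, the relevant entries would look roughly like this (field layout assumed from size.tojson()):

{
    "input_size": {"width": 512, "height": 512},  # as requested
    "size": {"width": 2048, "height": 2048},      # after upscale.resize()
}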

@@ -282,6 +282,9 @@ class ImageParams:
     def is_panorama(self):
         return self.pipeline == "panorama"

+    def is_pix2pix(self):
+        return self.pipeline == "pix2pix"
+
     def is_xl(self):
         return self.pipeline.endswith("-sdxl")
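
Like the existing checks, the new predicate keys entirely off the pipeline name. A minimal stand-in showing how the predicates interact:

class FakeImageParams:
    """Minimal stand-in for ImageParams, just the predicates used here."""

    def __init__(self, pipeline: str):
        self.pipeline = pipeline

    def is_pix2pix(self) -> bool:
        return self.pipeline == "pix2pix"

    def is_xl(self) -> bool:
        return self.pipeline.endswith("-sdxl")

assert FakeImageParams("pix2pix").is_pix2pix()
assert FakeImageParams("txt2img-sdxl").is_xl()
assert FakeImageParams("panorama-sdxl").is_xl()  # any "-sdxl" pipeline counts as XL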