feat(api): support for SDXL LoRAs
This commit is contained in:
parent
0ad250251e
commit
bbacfd1ca0
|
@ -57,16 +57,12 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe_params = {}
|
pipe_params = {}
|
||||||
if params.is_control():
|
if params.is_pix2pix():
|
||||||
pipe_params["controlnet_conditioning_scale"] = strength
|
|
||||||
elif params.is_lpw():
|
|
||||||
pipe_params["strength"] = strength
|
|
||||||
elif params.is_panorama():
|
|
||||||
pipe_params["strength"] = strength
|
|
||||||
elif pipe_type == "img2img" or pipe_type == "img2img-sdxl":
|
|
||||||
pipe_params["strength"] = strength
|
|
||||||
elif pipe_type == "pix2pix":
|
|
||||||
pipe_params["image_guidance_scale"] = strength
|
pipe_params["image_guidance_scale"] = strength
|
||||||
|
elif params.is_control():
|
||||||
|
pipe_params["controlnet_conditioning_scale"] = strength
|
||||||
|
else:
|
||||||
|
pipe_params["strength"] = strength
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for source in sources:
|
for source in sources:
|
||||||
|
|
|
@ -1,17 +1,19 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Dict, List, Literal, Tuple, Union
|
from re import sub
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from onnx import ModelProto, load, numpy_helper
|
from onnx import ModelProto, load, numpy_helper, save_model
|
||||||
from onnx.checker import check_model
|
from onnx.checker import check_model
|
||||||
from onnx.external_data_helper import (
|
from onnx.external_data_helper import (
|
||||||
convert_model_to_external_data,
|
convert_model_to_external_data,
|
||||||
set_external_data,
|
set_external_data,
|
||||||
write_external_data_tensors,
|
write_external_data_tensors,
|
||||||
)
|
)
|
||||||
|
from onnx.helper import tensor_dtype_to_np_dtype
|
||||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||||
|
|
||||||
from ...server.context import ServerContext
|
from ...server.context import ServerContext
|
||||||
|
@ -77,6 +79,67 @@ def fix_node_name(key: str):
|
||||||
return fixed_name
|
return fixed_name
|
||||||
|
|
||||||
|
|
||||||
|
def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]) -> Dict[str, Any]:
    """
    Re-key blended SDXL LoRA weights by the ONNX MatMul node each one applies to.

    Each key is reduced to its root (the text before the first "."). The root's
    prefix selects a block (`down_blocks`, `mid_block`, `up_blocks` or
    `text_model`) and its ending selects a layer suffix. The first remaining
    node that matches supplies the new key.

    Args:
        keys: mapping of LoRA key to its value; values are passed through
            untouched.
        nodes: candidate graph nodes, each exposing a `name` attribute. This
            list is MUTATED: every matched node is removed, so what is left
            afterwards are the nodes that no key claimed.

    Returns:
        A new dict keyed by fixed node name. Keys with an unrecognised prefix
        are copied through unchanged; keys with an unrecognised suffix or with
        no matching node are dropped after a warning.
    """
    # layer-type endings recognised at the end of a key root
    known_suffixes = (
        "fc1",
        "fc2",
        "ff_net_0_proj",
        "ff_net_2",
        "proj",
        "proj_in",
        "proj_out",
        "to_k",
        "to_out_0",
        "to_q",
        "to_v",
    )
    matmul_suffix = "/MatMul"

    fixed = {}

    for key, value in keys.items():
        root = key.split(".")[0]
        logger.debug("fixing XL node name: %s -> %s", key, root)  # TODO: move to trace

        if root.startswith("input"):
            block = "down_blocks"
        elif root.startswith("middle"):
            block = "mid_block"  # not plural
        elif root.startswith("output"):
            block = "up_blocks"
        elif root.startswith("text_model"):
            block = "text_model"
        else:
            logger.warning("unknown XL key name: %s", key)
            fixed[key] = value
            continue

        # NOTE(review): the LAST matching entry wins, so a root ending in
        # "ff_net_0_proj" is searched with the shorter "proj" suffix. Kept as-is
        # to preserve matching order; confirm whether longest-match was intended.
        suffix = None
        for candidate in known_suffixes:
            if root.endswith(candidate):
                suffix = candidate

        if suffix is None:
            logger.warning("new XL key type: %s", root)
            continue

        logger.debug("searching for XL node: /%s/*/%s", block, suffix)
        if block == "text_model":
            expected = f"{root}_MatMul"
            match = next(
                (node for node in nodes if fix_node_name(node.name) == expected),
                None,
            )
        else:
            block_prefix = f"/{block}"
            # compare against the fixed name because some places use to_out.0
            expected_end = f"{suffix}_MatMul"
            match = next(
                (
                    node
                    for node in nodes
                    if node.name.startswith(block_prefix)
                    and fix_node_name(node.name).endswith(expected_end)
                ),
                None,
            )

        if match is None:
            logger.warning("no matches for XL key: %s", root)
            continue

        # Remove exactly one trailing "/MatMul". str.rstrip() treats its argument
        # as a set of characters and would also eat the "ut" of "proj_out".
        name: str = match.name
        if name.endswith(matmul_suffix):
            name = name[: -len(matmul_suffix)]
        name = fix_node_name(name)

        logger.debug("matching XL key with node: %s -> %s", key, match.name)

        fixed[name] = value
        nodes.remove(match)

    return fixed
|
||||||
|
|
||||||
|
|
||||||
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
|
def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
|
||||||
return (
|
return (
|
||||||
min(x, shape[2] - 1),
|
min(x, shape[2] - 1),
|
||||||
|
@ -84,11 +147,13 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def blend_loras(
|
def blend_loras(
|
||||||
_conversion: ServerContext,
|
_conversion: ServerContext,
|
||||||
base_name: Union[str, ModelProto],
|
base_name: Union[str, ModelProto],
|
||||||
loras: List[Tuple[str, float]],
|
loras: List[Tuple[str, float]],
|
||||||
model_type: Literal["text_encoder", "unet"],
|
model_type: Literal["text_encoder", "unet"],
|
||||||
|
model_index: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# always load to CPU for blending
|
# always load to CPU for blending
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
@ -98,7 +163,10 @@ def blend_loras(
|
||||||
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
|
||||||
|
|
||||||
if model_type == "text_encoder":
|
if model_type == "text_encoder":
|
||||||
lora_prefix = "lora_te_"
|
if model_index is None:
|
||||||
|
lora_prefix = "lora_te_"
|
||||||
|
else:
|
||||||
|
lora_prefix = f"lora_te{model_index}_"
|
||||||
else:
|
else:
|
||||||
lora_prefix = f"lora_{model_type}_"
|
lora_prefix = f"lora_{model_type}_"
|
||||||
|
|
||||||
|
@ -313,11 +381,15 @@ def blend_loras(
|
||||||
else:
|
else:
|
||||||
blended[base_key] = np_weights
|
blended[base_key] = np_weights
|
||||||
|
|
||||||
|
# rewrite node names for XL
|
||||||
|
nodes = list(base_model.graph.node)
|
||||||
|
blended = fix_xl_names(blended, nodes)
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"updating %s of %s initializers: %s",
|
"updating %s of %s initializers, %s missed",
|
||||||
len(blended.keys()),
|
len(blended.keys()),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
list(blended.keys()),
|
len(nodes)
|
||||||
)
|
)
|
||||||
|
|
||||||
fixed_initializer_names = [
|
fixed_initializer_names = [
|
||||||
|
@ -419,8 +491,13 @@ def blend_loras(
|
||||||
onnx_weights.shape,
|
onnx_weights.shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
blended = onnx_weights + weights.transpose()
|
t_weights = weights.transpose()
|
||||||
logger.trace("blended weight shape: %s", blended.shape)
|
if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
|
||||||
|
logger.warning("weight shapes do not match for %s: %s vs %s", matmul_key, weights.shape, onnx_weights.shape)
|
||||||
|
t_weights = interp_to_match(weights, onnx_weights).transpose()
|
||||||
|
|
||||||
|
blended = onnx_weights + t_weights
|
||||||
|
logger.debug("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(
|
updated_node = numpy_helper.from_array(
|
||||||
|
@ -442,9 +519,29 @@ def blend_loras(
|
||||||
if len(unmatched_keys) > 0:
|
if len(unmatched_keys) > 0:
|
||||||
logger.warning("could not find nodes for some keys: %s", unmatched_keys)
|
logger.warning("could not find nodes for some keys: %s", unmatched_keys)
|
||||||
|
|
||||||
|
# if model_type == "unet":
|
||||||
|
# save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
|
||||||
|
|
||||||
return base_model
|
return base_model
|
||||||
|
|
||||||
|
|
||||||
|
from scipy import interpolate
|
||||||
|
|
||||||
|
def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
    """
    Linearly resample the 2-D array `ref` onto a grid sized after `resize`.

    Both arrays are treated as spanning the unit square [0, 1] x [0, 1]. Only
    the first two dimensions of `resize` are read; its values are never used.

    Note that the sample grid comes from `np.meshgrid` with its default "xy"
    indexing, so the result has shape `(resize.shape[1], resize.shape[0])`,
    i.e. transposed relative to `resize`.
    """
    src_rows, src_cols = ref.shape[0], ref.shape[1]
    dst_rows, dst_cols = resize.shape[0], resize.shape[1]
    logger.debug("dims: %s, %s, %s, %s", dst_rows, dst_cols, src_rows, src_cols)

    # interpolator over the source data, with both axes normalised to [0, 1]
    sampler = interpolate.RegularGridInterpolator(
        (np.linspace(0, 1, src_rows), np.linspace(0, 1, src_cols)),
        ref,
        method="linear",
    )

    # evaluate at every point of the normalised target grid
    grid_x, grid_y = np.meshgrid(
        np.linspace(0, 1, dst_rows),
        np.linspace(0, 1, dst_cols),
    )
    output = sampler((grid_x, grid_y))
    logger.debug("weights after interpolation: %s", output.shape)

    return output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
context = ConversionContext.from_environ()
|
context = ConversionContext.from_environ()
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
|
|
|
@ -45,6 +45,7 @@ from .version_safe_diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
UniPCMultistepScheduler,
|
UniPCMultistepScheduler,
|
||||||
)
|
)
|
||||||
|
from ..torch_before_ort import InferenceSession
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -95,6 +96,7 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
from optimum.onnxruntime.modeling_diffusion import ORTModelUnet, ORTModelTextEncoder
|
||||||
|
|
||||||
def load_pipeline(
|
def load_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
|
@ -244,6 +246,7 @@ def load_pipeline(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
list(zip(lora_models, lora_weights)),
|
list(zip(lora_models, lora_weights)),
|
||||||
"text_encoder",
|
"text_encoder",
|
||||||
|
1 if params.is_xl() else None,
|
||||||
)
|
)
|
||||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
||||||
text_encoder
|
text_encoder
|
||||||
|
@ -253,13 +256,49 @@ def load_pipeline(
|
||||||
text_encoder_opts.add_external_initializers(
|
text_encoder_opts.add_external_initializers(
|
||||||
list(text_encoder_names), list(text_encoder_values)
|
list(text_encoder_names), list(text_encoder_values)
|
||||||
)
|
)
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
if params.is_xl():
|
||||||
|
text_encoder_session = InferenceSession(
|
||||||
text_encoder.SerializeToString(),
|
text_encoder.SerializeToString(),
|
||||||
provider=device.ort_provider("text-encoder"),
|
providers=[device.ort_provider("text-encoder")],
|
||||||
sess_options=text_encoder_opts,
|
sess_options=text_encoder_opts,
|
||||||
)
|
)
|
||||||
)
|
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||||
|
components["text_encoder"] = ORTModelTextEncoder(text_encoder_session, text_encoder)
|
||||||
|
else:
|
||||||
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
text_encoder.SerializeToString(),
|
||||||
|
provider=device.ort_provider("text-encoder"),
|
||||||
|
sess_options=text_encoder_opts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
||||||
|
text_encoder2 = blend_loras(
|
||||||
|
server,
|
||||||
|
text_encoder2,
|
||||||
|
list(zip(lora_models, lora_weights)),
|
||||||
|
"text_encoder",
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
|
||||||
|
text_encoder2
|
||||||
|
)
|
||||||
|
text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
|
||||||
|
text_encoder2_opts = device.sess_options(cache=False)
|
||||||
|
text_encoder2_opts.add_external_initializers(
|
||||||
|
list(text_encoder2_names), list(text_encoder2_values)
|
||||||
|
)
|
||||||
|
|
||||||
|
text_encoder2_session = InferenceSession(
|
||||||
|
text_encoder2.SerializeToString(),
|
||||||
|
providers=[device.ort_provider("text-encoder")],
|
||||||
|
sess_options=text_encoder2_opts,
|
||||||
|
)
|
||||||
|
text_encoder2_session._model_path = path.join(model, "text_encoder_2")
|
||||||
|
components["text_encoder_2"] = ORTModelTextEncoder(text_encoder2_session, text_encoder2)
|
||||||
|
|
||||||
# blend and load unet
|
# blend and load unet
|
||||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
unet = path.join(model, unet_type, ONNX_MODEL)
|
||||||
|
@ -273,13 +312,23 @@ def load_pipeline(
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = device.sess_options(cache=False)
|
unet_opts = device.sess_options(cache=False)
|
||||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||||
components["unet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
if params.is_xl():
|
||||||
|
unet_session = InferenceSession(
|
||||||
unet_model.SerializeToString(),
|
unet_model.SerializeToString(),
|
||||||
provider=device.ort_provider("unet"),
|
providers=[device.ort_provider("unet")],
|
||||||
sess_options=unet_opts,
|
sess_options=unet_opts,
|
||||||
)
|
)
|
||||||
)
|
unet_session._model_path = path.join(model, "unet")
|
||||||
|
components["unet"] = ORTModelUnet(unet_session, unet_model)
|
||||||
|
else:
|
||||||
|
components["unet"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
unet_model.SerializeToString(),
|
||||||
|
provider=device.ort_provider("unet"),
|
||||||
|
sess_options=unet_opts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# make sure a UNet has been loaded
|
# make sure a UNet has been loaded
|
||||||
if not params.is_xl() and "unet" not in components:
|
if not params.is_xl() and "unet" not in components:
|
||||||
|
@ -344,6 +393,15 @@ def load_pipeline(
|
||||||
**components,
|
**components,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# make sure XL models are actually being used
|
||||||
|
# TODO: why is this needed?
|
||||||
|
logger.info("text encoder matches: %s, %s", pipe.text_encoder == components["text_encoder"], type(pipe.text_encoder))
|
||||||
|
pipe.text_encoder = components["text_encoder"]
|
||||||
|
logger.info("text encoder 2 matches: %s, %s", pipe.text_encoder_2 == components["text_encoder_2"], type(pipe.text_encoder_2))
|
||||||
|
pipe.text_encoder_2 = components["text_encoder_2"]
|
||||||
|
logger.info("unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet))
|
||||||
|
pipe.unet = components["unet"]
|
||||||
|
|
||||||
if not server.show_progress:
|
if not server.show_progress:
|
||||||
pipe.set_progress_bar_config(disable=True)
|
pipe.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,12 @@ def run_txt2img_pipeline(
|
||||||
else:
|
else:
|
||||||
tile_size = params.tiles
|
tile_size = params.tiles
|
||||||
|
|
||||||
|
# split prompts for each stage
|
||||||
|
if "||" in params.prompt:
|
||||||
|
txt_prompt, img_prompt = params.prompt.split("||")
|
||||||
|
else:
|
||||||
|
txt_prompt = img_prompt = params.prompt
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
chain.stage(
|
chain.stage(
|
||||||
|
|
|
@ -436,3 +436,11 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
|
||||||
return np.reshape(flat_tile[indices], tile.shape)
|
return np.reshape(flat_tile[indices], tile.shape)
|
||||||
else:
|
else:
|
||||||
return tile
|
return tile
|
||||||
|
|
||||||
|
|
||||||
|
def slice_prompt(prompt: str, slice: int) -> str:
    """
    Pick one "||"-separated section of a prompt.

    A prompt without "||" is returned unchanged. Otherwise the section at index
    `slice` is returned, with indices past the end clamped to the last section.
    Negative indices are not clamped and follow normal list indexing.
    """
    segments = prompt.split("||")
    if len(segments) == 1:
        # no separator present, so there is nothing to slice
        return prompt

    last = len(segments) - 1
    return segments[slice if slice < last else last]
|
|
@ -4,7 +4,7 @@ from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from struct import pack
|
from struct import pack
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from piexif import ExifIFD, ImageIFD, dump
|
from piexif import ExifIFD, ImageIFD, dump
|
||||||
from piexif.helper import UserComment
|
from piexif.helper import UserComment
|
||||||
|
@ -57,8 +57,10 @@ def json_params(
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
|
parent: Dict = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
json = {
|
json = {
|
||||||
|
"input_size": size.tojson(),
|
||||||
"outputs": outputs,
|
"outputs": outputs,
|
||||||
"params": params.tojson(),
|
"params": params.tojson(),
|
||||||
}
|
}
|
||||||
|
@ -66,6 +68,7 @@ def json_params(
|
||||||
json["params"]["model"] = path.basename(params.model)
|
json["params"]["model"] = path.basename(params.model)
|
||||||
json["params"]["scheduler"] = params.scheduler
|
json["params"]["scheduler"] = params.scheduler
|
||||||
|
|
||||||
|
# calculate final output size
|
||||||
output_size = size
|
output_size = size
|
||||||
if border is not None:
|
if border is not None:
|
||||||
json["border"] = border.tojson()
|
json["border"] = border.tojson()
|
||||||
|
@ -79,7 +82,6 @@ def json_params(
|
||||||
json["upscale"] = upscale.tojson()
|
json["upscale"] = upscale.tojson()
|
||||||
output_size = upscale.resize(output_size)
|
output_size = upscale.resize(output_size)
|
||||||
|
|
||||||
json["input_size"] = size.tojson()
|
|
||||||
json["size"] = output_size.tojson()
|
json["size"] = output_size.tojson()
|
||||||
|
|
||||||
return json
|
return json
|
||||||
|
|
|
@ -282,6 +282,9 @@ class ImageParams:
|
||||||
def is_panorama(self):
|
def is_panorama(self):
|
||||||
return self.pipeline == "panorama"
|
return self.pipeline == "panorama"
|
||||||
|
|
||||||
|
def is_pix2pix(self):
    """Return True when this object's `pipeline` attribute equals "pix2pix"."""
    selected = self.pipeline
    return selected == "pix2pix"
|
||||||
|
|
||||||
def is_xl(self):
|
def is_xl(self):
|
||||||
return self.pipeline.endswith("-sdxl")
|
return self.pipeline.endswith("-sdxl")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue