feat(api): support for SDXL LoRAs
commit bbacfd1ca0
parent 0ad250251e
@@ -57,16 +57,12 @@ class BlendImg2ImgStage(BaseStage):
         )

         pipe_params = {}
-        if params.is_control():
-            pipe_params["controlnet_conditioning_scale"] = strength
-        elif params.is_lpw():
-            pipe_params["strength"] = strength
-        elif params.is_panorama():
-            pipe_params["strength"] = strength
-        elif pipe_type == "img2img" or pipe_type == "img2img-sdxl":
-            pipe_params["strength"] = strength
-        elif pipe_type == "pix2pix":
+        if params.is_pix2pix():
             pipe_params["image_guidance_scale"] = strength
+        elif params.is_control():
+            pipe_params["controlnet_conditioning_scale"] = strength
+        else:
+            pipe_params["strength"] = strength

         outputs = []
         for source in sources:
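For reference, the strength routing above reduces to this standalone mapping (a sketch; the is_* helpers are the ImageParams methods added at the end of this diff):

def strength_params(params, strength: float) -> dict:
    # pix2pix measures strength against the source image
    if params.is_pix2pix():
        return {"image_guidance_scale": strength}
    # ControlNet pipelines scale the conditioning signal instead
    if params.is_control():
        return {"controlnet_conditioning_scale": strength}
    # img2img, LPW, panorama, and SDXL all use plain denoising strength
    return {"strength": strength}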
@@ -1,17 +1,19 @@
 from argparse import ArgumentParser
 from logging import getLogger
 from os import path
-from typing import Dict, List, Literal, Tuple, Union
+from re import sub
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 import numpy as np
 import torch
-from onnx import ModelProto, load, numpy_helper
+from onnx import ModelProto, load, numpy_helper, save_model
 from onnx.checker import check_model
 from onnx.external_data_helper import (
     convert_model_to_external_data,
     set_external_data,
     write_external_data_tensors,
 )
+from onnx.helper import tensor_dtype_to_np_dtype
 from onnxruntime import InferenceSession, OrtValue, SessionOptions

 from ...server.context import ServerContext
@@ -77,6 +79,67 @@ def fix_node_name(key: str):
     return fixed_name


+def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]):
+    fixed = {}
+
+    for key, value in keys.items():
+        root, *rest = key.split(".")
+        logger.debug("fixing XL node name: %s -> %s", key, root)  # TODO: move to trace
+
+        if root.startswith("input"):
+            block = "down_blocks"
+        elif root.startswith("middle"):
+            block = "mid_block"  # not plural
+        elif root.startswith("output"):
+            block = "up_blocks"
+        elif root.startswith("text_model"):
+            block = "text_model"
+        else:
+            logger.warning("unknown XL key name: %s", key)
+            fixed[key] = value
+            continue
+
+        suffix = None
+        for s in ["fc1", "fc2", "ff_net_0_proj", "ff_net_2", "proj", "proj_in", "proj_out", "to_k", "to_out_0", "to_q", "to_v"]:
+            if root.endswith(s):
+                suffix = s
+
+        if suffix is None:
+            logger.warning("new XL key type: %s", root)
+            continue
+
+        logger.debug("searching for XL node: /%s/*/%s", block, suffix)
+        if block == "text_model":
+            matches = [
+                node for node in nodes
+                if fix_node_name(node.name) == f"{root}_MatMul"
+            ]
+        else:
+            matches = [
+                node for node in nodes
+                if node.name.startswith(f"/{block}")
+                and fix_node_name(node.name).endswith(f"{suffix}_MatMul")  # needs to be fixed because some places use to_out.0
+            ]
+
+        if len(matches) == 0:
+            logger.warning("no matches for XL key: %s", root)
+            continue
+
+        name: str = matches[0].name
+        # note: rstrip removes a trailing character set, not a suffix, so it
+        # can also eat the "ut" of "proj_out"; the check below restores it
+        name = fix_node_name(name.rstrip("/MatMul"))
+
+        if name.endswith("proj_o"):
+            name = f"{name}ut"
+
+        logger.debug("matching XL key with node: %s -> %s", key, matches[0].name)
+
+        fixed[name] = value
+        nodes.remove(matches[0])
+
+    return fixed
+
+
 def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
     return (
         min(x, shape[2] - 1),
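For reference, fix_xl_names translates SGM/kohya-style LoRA key roots (input_blocks, middle_block, output_blocks) into the diffusers-style node names found in the exported ONNX graph (down_blocks, mid_block, up_blocks). A standalone sketch of the classification step, with illustrative key roots:

KNOWN_SUFFIXES = [
    "fc1", "fc2", "ff_net_0_proj", "ff_net_2",
    "proj", "proj_in", "proj_out",
    "to_k", "to_out_0", "to_q", "to_v",
]

def classify(root: str):
    # map the SGM-style block prefix onto the diffusers block name
    if root.startswith("input"):
        block = "down_blocks"
    elif root.startswith("middle"):
        block = "mid_block"  # not plural
    elif root.startswith("output"):
        block = "up_blocks"
    elif root.startswith("text_model"):
        block = "text_model"
    else:
        return None  # unknown key, passed through unchanged above

    suffix = None
    for s in KNOWN_SUFFIXES:
        if root.endswith(s):
            suffix = s  # no break: the last matching entry wins
    return block, suffix

# illustrative key roots, not taken from a real LoRA file
print(classify("input_blocks_4_1_proj_in"))  # ('down_blocks', 'proj_in')
print(classify("middle_block_1_to_out_0"))   # ('mid_block', 'to_out_0')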
@@ -84,11 +147,13 @@ def kernel_slice(x: int, y: int, shape: Tuple[int, int, int, int]) -> Tuple[int, int]:
     )


 def blend_loras(
     _conversion: ServerContext,
     base_name: Union[str, ModelProto],
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],
+    model_index: Optional[int] = None,
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
@@ -98,7 +163,10 @@ def blend_loras(
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]

     if model_type == "text_encoder":
-        lora_prefix = "lora_te_"
+        if model_index is None:
+            lora_prefix = "lora_te_"
+        else:
+            lora_prefix = f"lora_te{model_index}_"
     else:
         lora_prefix = f"lora_{model_type}_"
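The new model_index parameter exists because SDXL ships two text encoders, and kohya-style LoRA files prefix their keys per encoder. A sketch of the prefix convention (key tails are illustrative):

# SD 1.x text encoder keys:  lora_te_text_model_encoder_layers_0_...
# SDXL text encoder 1 keys:  lora_te1_text_model_encoder_layers_0_...
# SDXL text encoder 2 keys:  lora_te2_text_model_encoder_layers_0_...
# UNet keys (all variants):  lora_unet_...

def lora_prefix(model_type: str, model_index=None) -> str:
    if model_type == "text_encoder":
        if model_index is None:
            return "lora_te_"
        return f"lora_te{model_index}_"
    return f"lora_{model_type}_"

assert lora_prefix("unet") == "lora_unet_"
assert lora_prefix("text_encoder", 2) == "lora_te2_"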
@@ -313,11 +381,15 @@ def blend_loras(
         else:
             blended[base_key] = np_weights

+    # rewrite node names for XL
+    nodes = list(base_model.graph.node)
+    blended = fix_xl_names(blended, nodes)
+
     logger.trace(
-        "updating %s of %s initializers: %s",
+        "updating %s of %s initializers, %s missed",
         len(blended.keys()),
         len(base_model.graph.initializer),
-        list(blended.keys()),
+        len(nodes)
     )

     fixed_initializer_names = [
@@ -419,8 +491,13 @@ def blend_loras(
                 onnx_weights.shape,
             )

-            blended = onnx_weights + weights.transpose()
-            logger.trace("blended weight shape: %s", blended.shape)
+            t_weights = weights.transpose()
+            if weights.shape != onnx_weights.shape and t_weights.shape != onnx_weights.shape:
+                logger.warning("weight shapes do not match for %s: %s vs %s", matmul_key, weights.shape, onnx_weights.shape)
+                t_weights = interp_to_match(weights, onnx_weights).transpose()
+
+            blended = onnx_weights + t_weights
+            logger.debug("blended weight shape: %s, %s", blended.shape, onnx_weights.dtype)

             # replace the original initializer
             updated_node = numpy_helper.from_array(
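ONNX MatMul initializers are usually stored transposed relative to the torch-layout LoRA product, hence the transpose; the new guard catches weights whose shape matches in neither orientation and resizes them instead of failing. A toy illustration:

import numpy as np

onnx_weights = np.zeros((768, 320))   # initializer as stored in the ONNX graph
weights = np.ones((320, 768))         # blended LoRA product in torch layout

t_weights = weights.transpose()
# if neither weights.shape nor t_weights.shape matched onnx_weights.shape,
# the code above would resize via interp_to_match(...).transpose() instead
blended = onnx_weights + t_weights    # shapes agree: (768, 320)
print(blended.shape)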
@@ -442,9 +519,29 @@ def blend_loras(
     if len(unmatched_keys) > 0:
         logger.warning("could not find nodes for some keys: %s", unmatched_keys)

+    # if model_type == "unet":
+    #     save_model(base_model, f"/tmp/lora_blend_{model_type}.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="weights.pb")
+
     return base_model


+from scipy import interpolate
+
+
+def interp_to_match(ref: np.ndarray, resize: np.ndarray) -> np.ndarray:
+    res_x = np.linspace(0, 1, resize.shape[0])
+    res_y = np.linspace(0, 1, resize.shape[1])
+    ref_x = np.linspace(0, 1, ref.shape[0])
+    ref_y = np.linspace(0, 1, ref.shape[1])
+    logger.debug("dims: %s, %s, %s, %s", resize.shape[0], resize.shape[1], ref.shape[0], ref.shape[1])
+
+    f = interpolate.RegularGridInterpolator((ref_x, ref_y), ref, method='linear')
+    xg, yg = np.meshgrid(res_x, res_y)
+    output = f((xg, yg))
+    logger.debug("weights after interpolation: %s", output.shape)
+
+    return output
+
+
 if __name__ == "__main__":
     context = ConversionContext.from_environ()
     parser = ArgumentParser()
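interp_to_match treats the reference matrix as samples of a function on the unit square and resamples it for the target shape. With meshgrid's default "xy" indexing the result comes back transposed relative to the target, which is why the call site above applies .transpose() afterwards. A toy run (shapes are illustrative):

import numpy as np
from scipy import interpolate

ref = np.arange(12, dtype=np.float32).reshape(3, 4)  # small LoRA delta
target_shape = (6, 8)                                # initializer shape to match

res_x = np.linspace(0, 1, target_shape[0])
res_y = np.linspace(0, 1, target_shape[1])
ref_x = np.linspace(0, 1, ref.shape[0])
ref_y = np.linspace(0, 1, ref.shape[1])

f = interpolate.RegularGridInterpolator((ref_x, ref_y), ref, method='linear')
xg, yg = np.meshgrid(res_x, res_y)   # "xy" indexing: arrays of shape (8, 6)
output = f((xg, yg))

print(output.shape)              # (8, 6): transposed relative to target_shape
print(output.transpose().shape)  # (6, 8): what the caller actually uses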
@@ -45,6 +45,7 @@ from .version_safe_diffusers import (
     StableDiffusionPipeline,
     UniPCMultistepScheduler,
 )
+from ..torch_before_ort import InferenceSession

 logger = getLogger(__name__)
@@ -95,6 +96,7 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:

     return None

+from optimum.onnxruntime.modeling_diffusion import ORTModelUnet, ORTModelTextEncoder

 def load_pipeline(
     server: ServerContext,
@@ -244,6 +246,7 @@ def load_pipeline(
                 text_encoder,
                 list(zip(lora_models, lora_weights)),
                 "text_encoder",
+                1 if params.is_xl() else None,
             )
             (text_encoder, text_encoder_data) = buffer_external_data_tensors(
                 text_encoder
@@ -253,13 +256,49 @@ def load_pipeline(
             text_encoder_opts.add_external_initializers(
                 list(text_encoder_names), list(text_encoder_values)
             )
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder.SerializeToString(),
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=text_encoder_opts,
-                )
-            )
+
+            if params.is_xl():
+                text_encoder_session = InferenceSession(
+                    text_encoder.SerializeToString(),
+                    providers=[device.ort_provider("text-encoder")],
+                    sess_options=text_encoder_opts,
+                )
+                text_encoder_session._model_path = path.join(model, "text_encoder")
+                components["text_encoder"] = ORTModelTextEncoder(text_encoder_session, text_encoder)
+            else:
+                components["text_encoder"] = OnnxRuntimeModel(
+                    OnnxRuntimeModel.load_model(
+                        text_encoder.SerializeToString(),
+                        provider=device.ort_provider("text-encoder"),
+                        sess_options=text_encoder_opts,
+                    )
+                )
+
+            if params.is_xl():
+                text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
+                text_encoder2 = blend_loras(
+                    server,
+                    text_encoder2,
+                    list(zip(lora_models, lora_weights)),
+                    "text_encoder",
+                    2,
+                )
+                (text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
+                    text_encoder2
+                )
+                text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
+                text_encoder2_opts = device.sess_options(cache=False)
+                text_encoder2_opts.add_external_initializers(
+                    list(text_encoder2_names), list(text_encoder2_values)
+                )
+
+                text_encoder2_session = InferenceSession(
+                    text_encoder2.SerializeToString(),
+                    providers=[device.ort_provider("text-encoder")],
+                    sess_options=text_encoder2_opts,
+                )
+                text_encoder2_session._model_path = path.join(model, "text_encoder_2")
+                components["text_encoder_2"] = ORTModelTextEncoder(text_encoder2_session, text_encoder2)

             # blend and load unet
             unet = path.join(model, unet_type, ONNX_MODEL)
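For SDXL the blended in-memory models are wrapped in optimum's ORT model parts rather than diffusers' OnnxRuntimeModel; assigning the private _model_path fills in state that optimum normally records during from_pretrained. A reduced sketch of the pattern (the provider string is illustrative, and _model_path is a private detail that may change between optimum versions):

from os import path
from onnxruntime import InferenceSession, SessionOptions
from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder

def wrap_text_encoder(model_dir: str, proto, opts: SessionOptions):
    session = InferenceSession(
        proto.SerializeToString(),           # blended in-memory ModelProto
        providers=["CPUExecutionProvider"],  # illustrative provider
        sess_options=opts,
    )
    # private attribute expected by optimum's model parts
    session._model_path = path.join(model_dir, "text_encoder")
    return ORTModelTextEncoder(session, proto)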
@@ -273,13 +312,23 @@ def load_pipeline(
             unet_names, unet_values = zip(*unet_data)
             unet_opts = device.sess_options(cache=False)
             unet_opts.add_external_initializers(list(unet_names), list(unet_values))
-            components["unet"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    unet_model.SerializeToString(),
-                    provider=device.ort_provider("unet"),
-                    sess_options=unet_opts,
-                )
-            )
+
+            if params.is_xl():
+                unet_session = InferenceSession(
+                    unet_model.SerializeToString(),
+                    providers=[device.ort_provider("unet")],
+                    sess_options=unet_opts,
+                )
+                unet_session._model_path = path.join(model, "unet")
+                components["unet"] = ORTModelUnet(unet_session, unet_model)
+            else:
+                components["unet"] = OnnxRuntimeModel(
+                    OnnxRuntimeModel.load_model(
+                        unet_model.SerializeToString(),
+                        provider=device.ort_provider("unet"),
+                        sess_options=unet_opts,
+                    )
+                )

         # make sure a UNet has been loaded
         if not params.is_xl() and "unet" not in components:
@@ -344,6 +393,15 @@ def load_pipeline(
             **components,
         )

+        # make sure XL models are actually being used
+        # TODO: why is this needed?
+        logger.info("text encoder matches: %s, %s", pipe.text_encoder == components["text_encoder"], type(pipe.text_encoder))
+        pipe.text_encoder = components["text_encoder"]
+        logger.info("text encoder 2 matches: %s, %s", pipe.text_encoder_2 == components["text_encoder_2"], type(pipe.text_encoder_2))
+        pipe.text_encoder_2 = components["text_encoder_2"]
+        logger.info("unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet))
+        pipe.unet = components["unet"]
+
         if not server.show_progress:
             pipe.set_progress_bar_config(disable=True)
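The reassignment above guards against the pipeline constructor quietly rebuilding the blended components (hence the TODO). A compact equivalent of the same workaround, using identity checks instead of equality:

# assumption: `pipe` was just built from `components`
for name in ("text_encoder", "text_encoder_2", "unet"):
    if getattr(pipe, name, None) is not components[name]:
        logger.info("%s was replaced by the pipeline constructor", name)
        setattr(pipe, name, components[name])  # force the blended component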
@@ -48,6 +48,12 @@ def run_txt2img_pipeline(
     else:
         tile_size = params.tiles

+    # split prompts for each stage
+    if "||" in params.prompt:
+        txt_prompt, img_prompt = params.prompt.split("||")
+    else:
+        txt_prompt = img_prompt = params.prompt
+
     # prepare the chain pipeline and first stage
     chain = ChainPipeline()
     chain.stage(
@@ -436,3 +436,11 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
         return np.reshape(flat_tile[indices], tile.shape)
     else:
         return tile
+
+
+def slice_prompt(prompt: str, slice: int) -> str:
+    if "||" in prompt:
+        parts = prompt.split("||")
+        return parts[min(slice, len(parts) - 1)]
+    else:
+        return prompt
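slice_prompt pairs with the "||" split added in run_txt2img_pipeline: stage N takes part N, clamped to the last part, so a single prompt still works for every stage. The split is literal, so whitespace around "||" is preserved:

prompt = "a painted cat || a painted cat, intricate detail"

assert slice_prompt(prompt, 0) == "a painted cat "
assert slice_prompt(prompt, 1) == " a painted cat, intricate detail"
assert slice_prompt(prompt, 9) == " a painted cat, intricate detail"  # clamped
assert slice_prompt("plain prompt", 3) == "plain prompt"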
@@ -4,7 +4,7 @@ from logging import getLogger
 from os import path
 from struct import pack
 from time import time
-from typing import Any, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from piexif import ExifIFD, ImageIFD, dump
 from piexif.helper import UserComment
@@ -57,8 +57,10 @@ def json_params(
     upscale: Optional[UpscaleParams] = None,
     border: Optional[Border] = None,
     highres: Optional[HighresParams] = None,
+    parent: Dict = None,
 ) -> Any:
     json = {
+        "input_size": size.tojson(),
         "outputs": outputs,
         "params": params.tojson(),
     }
@@ -66,6 +68,7 @@ def json_params(
     json["params"]["model"] = path.basename(params.model)
     json["params"]["scheduler"] = params.scheduler

+    # calculate final output size
     output_size = size
     if border is not None:
         json["border"] = border.tojson()
@@ -79,7 +82,6 @@ def json_params(
         json["upscale"] = upscale.tojson()
         output_size = upscale.resize(output_size)

-    json["input_size"] = size.tojson()
     json["size"] = output_size.tojson()

     return json
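With input_size recorded up front and size overwritten after border, highres, and upscale adjustments, the metadata separates the requested size from the final image size. The result is shaped roughly like this (field values are illustrative and the tojson layouts are assumptions):

{
    "input_size": {"width": 512, "height": 512},   # before border/upscale
    "outputs": ["output-1234.png"],
    "params": {"model": "some-model", "scheduler": "ddim"},  # abbreviated
    "size": {"width": 1024, "height": 1024},       # after upscaling
}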
@@ -282,6 +282,9 @@ class ImageParams:
     def is_panorama(self):
         return self.pipeline == "panorama"

     def is_pix2pix(self):
         return self.pipeline == "pix2pix"

+    def is_xl(self):
+        return self.pipeline.endswith("-sdxl")
+
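is_xl drives every SDXL branch in this commit and keys off the pipeline name suffix alone:

# pipeline names ending in "-sdxl" opt into the XL code paths
for pipeline in ("txt2img", "txt2img-sdxl", "img2img-sdxl", "panorama"):
    print(pipeline, "->", pipeline.endswith("-sdxl"))
# txt2img -> False, txt2img-sdxl -> True, img2img-sdxl -> True, panorama -> False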