
start wiring LoRAs into prompt

Sean Sube 2023-03-14 22:10:33 -05:00
parent ce05e76947
commit 03f4e1b922
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 41 additions and 15 deletions

View File

@@ -8,7 +8,6 @@ import torch
from onnx import ModelProto, load, numpy_helper
from onnx.checker import check_model
from onnx.external_data_helper import (
ExternalDataInfo,
convert_model_to_external_data,
set_external_data,
write_external_data_tensors,
@@ -61,7 +60,6 @@ def fix_node_name(key: str):
def merge_lora(
base_name: str,
lora_names: str,
dest_path: str,
dest_type: Literal["text_encoder", "unet"],
lora_weights: "np.NDArray[np.float64]" = None,
):
@@ -227,7 +225,7 @@ if __name__ == "__main__":
args = parser.parse_args()
logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
if args.dest is None or args.dest == "" or args.dest == "ort":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
@@ -247,3 +245,7 @@ if __name__ == "__main__":
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved blended model: %s", dest_file)
check_model(dest_file)
logger.info("checked blended model")

View File

@@ -1,6 +1,6 @@
from logging import getLogger
from os import path
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple
import numpy as np
from diffusers import (
@@ -106,6 +106,13 @@ def get_tile_latents(
return full_latents[:, :, y:yt, x:xt]
def get_loras_from_prompt(prompt: str) -> List[str]:
    # hardcoded placeholder while the prompt wiring is under construction
    return [
        "arch",
        "glass",
    ]
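The stub above hardcodes its result, so every request loads the same two LoRAs. A sketch of what the parser might grow into, assuming a hypothetical <lora:name:weight> prompt token syntax (this commit does not define any syntax):

import re
from typing import List

# hypothetical prompt token: <lora:name:weight>
LORA_TOKEN = re.compile(r"<lora:([-\w]+):([.\d]+)>")

def get_loras_from_prompt(prompt: str) -> List[str]:
    # keep only the names for now; the weights could feed merge_lora later
    return [name for name, _weight in LORA_TOKEN.findall(prompt)]

Under that assumed syntax, a prompt like "stone arch, <lora:arch:1.0> <lora:glass:0.8>" would yield ["arch", "glass"].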
def optimize_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
@@ -156,7 +163,9 @@ def load_pipeline(
device: DeviceParams,
lpw: bool,
inversion: Optional[str],
loras: Optional[List[str]] = None,
):
loras = loras or []
pipe_key = (
pipeline.__name__,
model,
@@ -164,6 +173,7 @@
device.provider,
lpw,
inversion,
tuple(loras),  # lists are unhashable; keep the cache key hashable
)
scheduler_key = (scheduler_name, model)
scheduler_type = get_pipeline_schedulers()[scheduler_name]
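Including the LoRA names in pipe_key means a pipeline cached for one LoRA set is never reused for another. Since the key may end up as a dictionary key, the list is converted to a tuple above; a quick illustration of why:

# a list inside the key would raise TypeError: unhashable type: 'list'
cache = {}
cache[("txt2img", "model-a", ("arch",))] = "pipeline-1"
cache[("txt2img", "model-a", ("arch", "glass"))] = "pipeline-2"
assert len(cache) == 2  # distinct LoRA sets get distinct entries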
@@ -223,26 +233,36 @@ def load_pipeline(
)
# test LoRA blending
lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
"arch",
"glass",
]]
lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras]
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
logger.info("blending text encoder with LoRA models: %s", lora_models)
blended_text_encoder = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), lora_models, None, "text_encoder")
# blend and load text encoder
blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder")
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
text_encoder_opts = SessionOptions()
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
components["text_encoder"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(text_encoder_model.SerializeToString(), provider=device.ort_provider(), sess_options=text_encoder_opts))
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder_model.SerializeToString(),
provider=device.ort_provider(),
sess_options=text_encoder_opts,
)
)
logger.info("blending unet with LoRA models: %s", lora_models)
blended_unet = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/unet/model.onnx"), lora_models, None, "unet")
# blend and load unet
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data)
unet_opts = SessionOptions()
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
components["unet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(unet_model.SerializeToString(), provider=device.ort_provider(), sess_options=unet_opts))
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
unet_model.SerializeToString(),
provider=device.ort_provider(),
sess_options=unet_opts,
)
)
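The text encoder and unet branches repeat the same blend, buffer, and load steps; a hypothetical helper (a sketch only, not part of this commit) could factor them out using the names already in this file:

def load_blended_model(base_path: str, lora_models: List[str], dest_type: str, device: DeviceParams):
    # blend the LoRA weights into the base ONNX model, entirely in memory
    blended = merge_lora(base_path, lora_models, dest_type)
    (model_proto, external_data) = buffer_external_data_tensors(blended)
    names, values = zip(*external_data)
    opts = SessionOptions()
    opts.add_external_initializers(list(names), list(values))
    return OnnxRuntimeModel(
        OnnxRuntimeModel.load_model(
            model_proto.SerializeToString(),
            provider=device.ort_provider(),
            sess_options=opts,
        )
    )

Each branch would then collapse to a single call, e.g. components["unet"] = load_blended_model(path.join(model, "unet", "model.onnx"), lora_models, "unet", device).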
pipe = pipeline.from_pretrained(
model,

View File

@@ -14,7 +14,7 @@ from ..server import ServerContext
from ..upscale import run_upscale_correction
from ..utils import run_gc
from ..worker import WorkerContext
from .load import get_latents_from_seed, load_pipeline
from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline
logger = getLogger(__name__)
@@ -28,6 +28,7 @@ def run_txt2img_pipeline(
upscale: UpscaleParams,
) -> None:
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
loras = get_loras_from_prompt(params.prompt)
pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
@@ -36,6 +37,7 @@
job.get_device(),
params.lpw,
params.inversion,
loras,
)
progress = job.get_progress_callback()
@@ -99,6 +101,7 @@ def run_img2img_pipeline(
source: Image.Image,
strength: float,
) -> None:
loras = get_loras_from_prompt(params.prompt)
pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
@@ -107,6 +110,7 @@
job.get_device(),
params.lpw,
params.inversion,
loras,
)
progress = job.get_progress_callback()
if params.lpw: