
start wiring LoRAs into prompt

Sean Sube 2023-03-14 22:10:33 -05:00
parent ce05e76947
commit 03f4e1b922
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 41 additions and 15 deletions

View File

@@ -8,7 +8,6 @@ import torch
 from onnx import ModelProto, load, numpy_helper
 from onnx.checker import check_model
 from onnx.external_data_helper import (
-    ExternalDataInfo,
     convert_model_to_external_data,
     set_external_data,
     write_external_data_tensors,
@@ -61,7 +60,6 @@ def fix_node_name(key: str):
 def merge_lora(
     base_name: str,
     lora_names: str,
-    dest_path: str,
     dest_type: Literal["text_encoder", "unet"],
     lora_weights: "np.NDArray[np.float64]" = None,
 ):
@@ -227,7 +225,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
-    blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
+    blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
     if args.dest is None or args.dest == "" or args.dest == "ort":
         # convert to external data and save to memory
         (bare_model, external_data) = buffer_external_data_tensors(blend_model)
@@ -247,3 +245,7 @@ if __name__ == "__main__":
             model_file.write(bare_model.SerializeToString())
 
         logger.info("successfully saved blended model: %s", dest_file)
+
+        check_model(dest_file)
+
+        logger.info("checked blended model")

View File

@@ -1,6 +1,6 @@
 from logging import getLogger
 from os import path
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import numpy as np
 from diffusers import (
@@ -106,6 +106,13 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
+def get_loras_from_prompt(prompt: str) -> List[str]:
+    return [
+        "arch",
+        "glass",
+    ]
+
+
 def optimize_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -156,7 +163,9 @@ def load_pipeline(
     device: DeviceParams,
     lpw: bool,
     inversion: Optional[str],
+    loras: Optional[List[str]] = None,
 ):
+    loras = loras or []
     pipe_key = (
         pipeline.__name__,
         model,
@@ -164,6 +173,7 @@ def load_pipeline(
         device.provider,
         lpw,
         inversion,
+        loras,
     )
     scheduler_key = (scheduler_name, model)
     scheduler_type = get_pipeline_schedulers()[scheduler_name]
@@ -223,26 +233,36 @@ def load_pipeline(
             )
 
         # test LoRA blending
-        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
-            "arch",
-            "glass",
-        ]]
-        logger.info("blending text encoder with LoRA models: %s", lora_models)
-        blended_text_encoder = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), lora_models, None, "text_encoder")
+        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras]
+        logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+
+        # blend and load text encoder
+        blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, None, "text_encoder")
         (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
         text_encoder_names, text_encoder_values = zip(*text_encoder_data)
         text_encoder_opts = SessionOptions()
         text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
-        components["text_encoder"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(text_encoder_model.SerializeToString(), provider=device.ort_provider(), sess_options=text_encoder_opts))
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder_model.SerializeToString(),
+                provider=device.ort_provider(),
+                sess_options=text_encoder_opts,
+            )
+        )
 
-        logger.info("blending unet with LoRA models: %s", lora_models)
-        blended_unet = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/unet/model.onnx"), lora_models, None, "unet")
+        # blend and load unet
+        blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, None, "unet")
         (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
         unet_names, unet_values = zip(*unet_data)
         unet_opts = SessionOptions()
         unet_opts.add_external_initializers(list(unet_names), list(unet_values))
-        components["unet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(unet_model.SerializeToString(), provider=device.ort_provider(), sess_options=unet_opts))
+        components["unet"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                unet_model.SerializeToString(),
+                provider=device.ort_provider(),
+                sess_options=unet_opts,
+            )
+        )
 
         pipe = pipeline.from_pretrained(
             model,
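get_loras_from_prompt is still a stub that always returns ["arch", "glass"], while the runners in the next file already pass it the user prompt. One way the real extraction could look, assuming an angle-bracket token syntax such as <lora:name> or <lora:name:weight> (the token format is an assumption, not something this commit defines):

import re
from typing import List

# hypothetical prompt syntax: "a glass arch <lora:arch> <lora:glass:0.5>"
LORA_TOKEN = re.compile(r"<lora:([-\w]+)(?::[\d.]+)?>")


def get_loras_from_prompt(prompt: str) -> List[str]:
    # collect the LoRA names referenced by <lora:...> tokens in the prompt
    return LORA_TOKEN.findall(prompt)

With that sketch, get_loras_from_prompt("a glass arch <lora:arch> <lora:glass:0.5>") returns ["arch", "glass"], matching the hard-coded stub; the tokens themselves would still need to be stripped from the prompt before it reaches the text encoder.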

View File

@@ -14,7 +14,7 @@ from ..server import ServerContext
 from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from ..worker import WorkerContext
-from .load import get_latents_from_seed, load_pipeline
+from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline
 
 logger = getLogger(__name__)
@@ -28,6 +28,7 @@ def run_txt2img_pipeline(
     upscale: UpscaleParams,
 ) -> None:
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
+    loras = get_loras_from_prompt(params.prompt)
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionPipeline,
@@ -36,6 +37,7 @@ def run_txt2img_pipeline(
         job.get_device(),
         params.lpw,
         params.inversion,
+        loras,
     )
 
     progress = job.get_progress_callback()
@@ -99,6 +101,7 @@ def run_img2img_pipeline(
     source: Image.Image,
     strength: float,
 ) -> None:
+    loras = get_loras_from_prompt(params.prompt)
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionImg2ImgPipeline,
@@ -107,6 +110,7 @@ def run_img2img_pipeline(
         job.get_device(),
         params.lpw,
         params.inversion,
+        loras,
     )
     progress = job.get_progress_callback()
     if params.lpw: