start wiring LoRAs into prompt

parent ce05e76947
commit 03f4e1b922
@@ -8,7 +8,6 @@ import torch
 from onnx import ModelProto, load, numpy_helper
 from onnx.checker import check_model
 from onnx.external_data_helper import (
-    ExternalDataInfo,
     convert_model_to_external_data,
     set_external_data,
     write_external_data_tensors,
@@ -61,7 +60,6 @@ def fix_node_name(key: str):
 def merge_lora(
     base_name: str,
     lora_names: str,
-    dest_path: str,
     dest_type: Literal["text_encoder", "unet"],
     lora_weights: "np.NDArray[np.float64]" = None,
 ):
@@ -227,7 +225,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)

-    blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
+    blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
     if args.dest is None or args.dest == "" or args.dest == "ort":
         # convert to external data and save to memory
         (bare_model, external_data) = buffer_external_data_tensors(blend_model)
@@ -247,3 +245,7 @@ if __name__ == "__main__":
         model_file.write(bare_model.SerializeToString())

     logger.info("successfully saved blended model: %s", dest_file)
+
+    check_model(dest_file)
+
+    logger.info("checked blended model")
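The hunk above adds a validation pass after the blended model is saved. As a standalone sketch of the same save-then-check pattern, wrapped in a hypothetical helper name:

from onnx import ModelProto
from onnx.checker import check_model


def save_and_check(bare_model: ModelProto, dest_file: str) -> None:
    # write the serialized model, then validate the graph on disk;
    # check_model raises onnx.checker.ValidationError if the graph is malformed
    with open(dest_file, "wb") as model_file:
        model_file.write(bare_model.SerializeToString())
    check_model(dest_file)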
@@ -1,6 +1,6 @@
 from logging import getLogger
 from os import path
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import numpy as np
 from diffusers import (
@@ -106,6 +106,13 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]


+def get_loras_from_prompt(prompt: str) -> List[str]:
+    return [
+        "arch",
+        "glass",
+    ]
+
+
 def optimize_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
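get_loras_from_prompt is a stub for now, returning a fixed list, but the commit title points at the goal: pulling LoRA names out of the prompt text. One possible shape for that, assuming a <lora:name:weight> token syntax that is not part of this commit:

import re
from typing import List

# hypothetical token format, e.g. "a glass arch <lora:arch:1.0> <lora:glass:0.8>"
LORA_TOKEN = re.compile(r"<lora:([-\w]+):([\d.]+)>")


def get_loras_from_prompt(prompt: str) -> List[str]:
    # collect the name from each LoRA token found in the prompt
    return [match.group(1) for match in LORA_TOKEN.finditer(prompt)]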
@@ -156,7 +163,9 @@ def load_pipeline(
     device: DeviceParams,
     lpw: bool,
     inversion: Optional[str],
+    loras: Optional[List[str]] = None,
 ):
+    loras = loras or []
     pipe_key = (
         pipeline.__name__,
         model,
@@ -164,6 +173,7 @@ def load_pipeline(
         device.provider,
         lpw,
         inversion,
+        loras,
     )
     scheduler_key = (scheduler_name, model)
     scheduler_type = get_pipeline_schedulers()[scheduler_name]
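Adding loras to pipe_key means a cached pipeline is reused only when the active LoRA set matches. Illustrative values only; note that if the key is ever hashed (for example, used as a dict key), the list would need a hashable form such as tuple(loras):

# made-up key values: a different LoRA set yields a different cache key
key_a = ("OnnxStableDiffusionPipeline", "model-a", "CPUExecutionProvider", False, None, ("arch",))
key_b = ("OnnxStableDiffusionPipeline", "model-a", "CPUExecutionProvider", False, None, ("arch", "glass"))
assert key_a != key_b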
@@ -223,26 +233,36 @@ def load_pipeline(
         )

         # test LoRA blending
-        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
-            "arch",
-            "glass",
-        ]]
+        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras]
+        logger.info("blending base model %s with LoRA models: %s", model, lora_models)

-        logger.info("blending text encoder with LoRA models: %s", lora_models)
-        blended_text_encoder = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), lora_models, None, "text_encoder")
+        # blend and load text encoder
+        blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, None, "text_encoder")
         (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
         text_encoder_names, text_encoder_values = zip(*text_encoder_data)
         text_encoder_opts = SessionOptions()
         text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
-        components["text_encoder"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(text_encoder_model.SerializeToString(), provider=device.ort_provider(), sess_options=text_encoder_opts))
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder_model.SerializeToString(),
+                provider=device.ort_provider(),
+                sess_options=text_encoder_opts,
+            )
+        )

-        logger.info("blending unet with LoRA models: %s", lora_models)
-        blended_unet = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/unet/model.onnx"), lora_models, None, "unet")
+        # blend and load unet
+        blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, None, "unet")
         (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
         unet_names, unet_values = zip(*unet_data)
         unet_opts = SessionOptions()
         unet_opts.add_external_initializers(list(unet_names), list(unet_values))
-        components["unet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(unet_model.SerializeToString(), provider=device.ort_provider(), sess_options=unet_opts))
+        components["unet"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                unet_model.SerializeToString(),
+                provider=device.ort_provider(),
+                sess_options=unet_opts,
+            )
+        )

     pipe = pipeline.from_pretrained(
         model,
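The text encoder and unet branches share one pattern: serialize the blended ModelProto, keep the heavy tensors in memory, and hand them to ONNX Runtime as external initializers, so the blended model never touches disk. A condensed sketch of that pattern, with illustrative names:

from onnxruntime import InferenceSession, SessionOptions


def load_blended(bare_model, external_data, provider: str) -> InferenceSession:
    # external_data: (tensor_name, OrtValue) pairs, as produced by
    # buffer_external_data_tensors in the hunk above
    names, values = zip(*external_data)
    opts = SessionOptions()
    opts.add_external_initializers(list(names), list(values))
    # with the weights kept external, the serialized graph itself stays small
    return InferenceSession(
        bare_model.SerializeToString(), sess_options=opts, providers=[provider]
    )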
@@ -14,7 +14,7 @@ from ..server import ServerContext
 from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from ..worker import WorkerContext
-from .load import get_latents_from_seed, load_pipeline
+from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline

 logger = getLogger(__name__)

@@ -28,6 +28,7 @@ def run_txt2img_pipeline(
     upscale: UpscaleParams,
 ) -> None:
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
+    loras = get_loras_from_prompt(params.prompt)
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionPipeline,
@@ -36,6 +37,7 @@ def run_txt2img_pipeline(
         job.get_device(),
         params.lpw,
         params.inversion,
+        loras,
     )
     progress = job.get_progress_callback()

@@ -99,6 +101,7 @@ def run_img2img_pipeline(
     source: Image.Image,
     strength: float,
 ) -> None:
+    loras = get_loras_from_prompt(params.prompt)
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionImg2ImgPipeline,
@@ -107,6 +110,7 @@ def run_img2img_pipeline(
         job.get_device(),
         params.lpw,
         params.inversion,
+        loras,
     )
     progress = job.get_progress_callback()
     if params.lpw:
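Both entry points now derive the LoRA list from the prompt before loading the pipeline, although with the stub as committed the prompt content does not matter yet:

# the stub ignores its argument for now, so any prompt yields the same list
loras = get_loras_from_prompt("a cathedral of glass and steel")
print(loras)  # ["arch", "glass"]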