feat(api): parse LoRA weights from prompt
This commit is contained in:
parent
45166f281e
commit
a7f77a033d
|
@ -61,7 +61,7 @@ def fix_node_name(key: str):
|
|||
|
||||
def merge_lora(
|
||||
base_name: str,
|
||||
lora_names: str,
|
||||
lora_names: List[str],
|
||||
dest_type: Literal["text_encoder", "unet"],
|
||||
lora_weights: "np.NDArray[np.float64]" = None,
|
||||
):
|
||||
|
|
|
@ -37,7 +37,7 @@ try:
|
|||
except ImportError:
|
||||
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
||||
|
||||
from ..convert.diffusion.lora import buffer_external_data_tensors, merge_lora
|
||||
from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
|
||||
from ..params import DeviceParams, Size
|
||||
from ..server import ServerContext
|
||||
from ..utils import run_gc
|
||||
|
@ -107,26 +107,6 @@ def get_tile_latents(
|
|||
return full_latents[:, :, y:yt, x:xt]
|
||||
|
||||
|
||||
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
    """
    Strip `<lora:name:weight>` tokens out of a prompt and collect their names.

    Returns a tuple of the prompt with every matched token removed and the
    list of token names in the order they were found. The weight part of each
    token is matched but not returned.
    """
    token_pattern = compile(r"\<lora:(\w+):([\.|\d]+)\>")
    names: List[str] = []
    text = prompt

    while True:
        found = token_pattern.search(text)
        if found is None:
            break

        logger.debug("found LoRA token in prompt: %s", found)
        names.append(found.group(1))

        # cut the token out, then rescan the shortened prompt from the start
        head = text[: found.start()]
        tail = text[found.end() :]
        text = head + tail

    return (text, names)
|
||||
|
||||
|
||||
def optimize_pipeline(
|
||||
server: ServerContext,
|
||||
pipe: StableDiffusionPipeline,
|
||||
|
@ -177,7 +157,7 @@ def load_pipeline(
|
|||
device: DeviceParams,
|
||||
lpw: bool,
|
||||
inversion: Optional[str],
|
||||
loras: Optional[List[str]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
):
|
||||
loras = loras or []
|
||||
pipe_key = (
|
||||
|
@ -247,46 +227,38 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# test LoRA blending
|
||||
lora_models = [
|
||||
path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras
|
||||
]
|
||||
lora_names, lora_weights = zip(*loras)
|
||||
lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
|
||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||
|
||||
# blend and load text encoder
|
||||
blended_text_encoder = merge_lora(
|
||||
path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder"
|
||||
)
|
||||
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(
|
||||
blended_text_encoder
|
||||
)
|
||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||
text_encoder_opts = SessionOptions()
|
||||
text_encoder_opts.add_external_initializers(
|
||||
list(text_encoder_names), list(text_encoder_values)
|
||||
)
|
||||
components["text_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
text_encoder_model.SerializeToString(),
|
||||
provider=device.ort_provider(),
|
||||
sess_options=text_encoder_opts,
|
||||
if len(lora_models) > 0:
|
||||
# blend and load text encoder
|
||||
blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder", lora_weights=lora_weights)
|
||||
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||
text_encoder_opts = SessionOptions()
|
||||
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
||||
components["text_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
text_encoder_model.SerializeToString(),
|
||||
provider=device.ort_provider(),
|
||||
sess_options=text_encoder_opts,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# blend and load unet
|
||||
blended_unet = merge_lora(
|
||||
path.join(model, "unet", "model.onnx"), lora_models, "unet"
|
||||
)
|
||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||
unet_names, unet_values = zip(*unet_data)
|
||||
unet_opts = SessionOptions()
|
||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||
components["unet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
unet_model.SerializeToString(),
|
||||
provider=device.ort_provider(),
|
||||
sess_options=unet_opts,
|
||||
# blend and load unet
|
||||
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
|
||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||
unet_names, unet_values = zip(*unet_data)
|
||||
unet_opts = SessionOptions()
|
||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||
components["unet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
unet_model.SerializeToString(),
|
||||
provider=device.ort_provider(),
|
||||
sess_options=unet_opts,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
pipe = pipeline.from_pretrained(
|
||||
model,
|
||||
|
|
|
@ -14,7 +14,8 @@ from ..server import ServerContext
|
|||
from ..upscale import run_upscale_correction
|
||||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline
|
||||
from .load import get_latents_from_seed, load_pipeline
|
||||
from .utils import get_loras_from_prompt
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from logging import getLogger
|
||||
from math import ceil
|
||||
from re import compile
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from diffusers import OnnxStableDiffusionPipeline
|
||||
|
@ -128,3 +128,20 @@ def expand_prompt(
|
|||
|
||||
logger.debug("expanded prompt shape: %s", prompt_embeds.shape)
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    """
    Extract `<lora:name:weight>` tokens from a prompt.

    Each token whose weight is a plain decimal number (`1`, `1.`, `0.5`, `.5`)
    is removed from the prompt and recorded as a `(name, weight)` pair, in the
    order the tokens are found. Tokens with a malformed weight, such as
    `<lora:x:|>` or `<lora:x:1.2.3>`, are not treated as LoRA tokens and are
    left in the prompt unchanged.

    Returns a tuple of the remaining prompt and the list of pairs.
    """
    remaining_prompt = prompt
    # The weight group only accepts strings that float() can parse. The looser
    # class `[\.|\d]+` also matched `|` and repeated dots, which made the
    # float() conversion below raise ValueError.
    lora_expr = compile(r"\<lora:(\w+):(\d+(?:\.\d*)?|\.\d+)\>")

    loras: List[Tuple[str, float]] = []
    next_match = lora_expr.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found LoRA token in prompt: %s", next_match)
        name, weight = next_match.groups()
        loras.append((name, float(weight)))

        # remove this match and look for another, rescanning from the start
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = lora_expr.search(remaining_prompt)

    return (remaining_prompt, loras)
|
||||
|
|
Loading…
Reference in New Issue