feat(api): parse LoRA weights from prompt
This commit is contained in:
parent
45166f281e
commit
a7f77a033d
|
@ -61,7 +61,7 @@ def fix_node_name(key: str):
|
||||||
|
|
||||||
def merge_lora(
|
def merge_lora(
|
||||||
base_name: str,
|
base_name: str,
|
||||||
lora_names: str,
|
lora_names: List[str],
|
||||||
dest_type: Literal["text_encoder", "unet"],
|
dest_type: Literal["text_encoder", "unet"],
|
||||||
lora_weights: "np.NDArray[np.float64]" = None,
|
lora_weights: "np.NDArray[np.float64]" = None,
|
||||||
):
|
):
|
||||||
|
|
|
@ -37,7 +37,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
||||||
|
|
||||||
from ..convert.diffusion.lora import buffer_external_data_tensors, merge_lora
|
from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
|
||||||
from ..params import DeviceParams, Size
|
from ..params import DeviceParams, Size
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
@ -107,26 +107,6 @@ def get_tile_latents(
|
||||||
return full_latents[:, :, y:yt, x:xt]
|
return full_latents[:, :, y:yt, x:xt]
|
||||||
|
|
||||||
|
|
||||||
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
    """Strip ``<lora:name:weight>`` tokens from a prompt.

    Scans *prompt* for LoRA tokens, removes each one, and collects the
    LoRA names. The weight captured by the pattern is currently ignored.

    Args:
        prompt: The raw prompt text, possibly containing LoRA tokens.

    Returns:
        A tuple of (prompt with all LoRA tokens removed, list of the
        LoRA names in the order they were found).
    """
    lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")
    log = getLogger(__name__)

    names: List[str] = []
    cleaned = prompt

    # repeatedly search-and-cut so offsets stay valid after each removal
    match = lora_expr.search(cleaned)
    while match:
        log.debug("found LoRA token in prompt: %s", match)
        names.append(match.group(1))
        cleaned = cleaned[: match.start()] + cleaned[match.end() :]
        match = lora_expr.search(cleaned)

    return (cleaned, names)
|
|
||||||
|
|
||||||
|
|
||||||
def optimize_pipeline(
|
def optimize_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
|
@ -177,7 +157,7 @@ def load_pipeline(
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
lpw: bool,
|
lpw: bool,
|
||||||
inversion: Optional[str],
|
inversion: Optional[str],
|
||||||
loras: Optional[List[str]] = None,
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
):
|
):
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
pipe_key = (
|
pipe_key = (
|
||||||
|
@ -247,46 +227,38 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# test LoRA blending
|
# test LoRA blending
|
||||||
lora_models = [
|
lora_names, lora_weights = zip(*loras)
|
||||||
path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras
|
lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
|
||||||
]
|
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||||
|
|
||||||
# blend and load text encoder
|
if len(lora_models) > 0:
|
||||||
blended_text_encoder = merge_lora(
|
# blend and load text encoder
|
||||||
path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder"
|
blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder", lora_weights=lora_weights)
|
||||||
)
|
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
||||||
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
blended_text_encoder
|
text_encoder_opts = SessionOptions()
|
||||||
)
|
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
text_encoder_opts = SessionOptions()
|
OnnxRuntimeModel.load_model(
|
||||||
text_encoder_opts.add_external_initializers(
|
text_encoder_model.SerializeToString(),
|
||||||
list(text_encoder_names), list(text_encoder_values)
|
provider=device.ort_provider(),
|
||||||
)
|
sess_options=text_encoder_opts,
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
)
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
text_encoder_model.SerializeToString(),
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=text_encoder_opts,
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# blend and load unet
|
# blend and load unet
|
||||||
blended_unet = merge_lora(
|
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
|
||||||
path.join(model, "unet", "model.onnx"), lora_models, "unet"
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
unet_opts = SessionOptions()
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||||
unet_opts = SessionOptions()
|
components["unet"] = OnnxRuntimeModel(
|
||||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
OnnxRuntimeModel.load_model(
|
||||||
components["unet"] = OnnxRuntimeModel(
|
unet_model.SerializeToString(),
|
||||||
OnnxRuntimeModel.load_model(
|
provider=device.ort_provider(),
|
||||||
unet_model.SerializeToString(),
|
sess_options=unet_opts,
|
||||||
provider=device.ort_provider(),
|
)
|
||||||
sess_options=unet_opts,
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
|
|
|
@ -14,7 +14,8 @@ from ..server import ServerContext
|
||||||
from ..upscale import run_upscale_correction
|
from ..upscale import run_upscale_correction
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline
|
from .load import get_latents_from_seed, load_pipeline
|
||||||
|
from .utils import get_loras_from_prompt
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from re import compile
|
from re import compile
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
|
@ -128,3 +128,20 @@ def expand_prompt(
|
||||||
|
|
||||||
logger.debug("expanded prompt shape: %s", prompt_embeds.shape)
|
logger.debug("expanded prompt shape: %s", prompt_embeds.shape)
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
|
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    """Extract ``<lora:name:weight>`` tokens from a prompt.

    Each token is removed from the prompt and parsed into a
    ``(name, weight)`` pair, preserving the order of appearance.

    Args:
        prompt: The raw prompt text, possibly containing LoRA tokens.

    Returns:
        A tuple of (prompt with all LoRA tokens removed, list of
        (lora_name, weight) pairs).
    """
    logger = getLogger(__name__)
    remaining_prompt = prompt

    # NOTE: the weight group only admits valid float literals ("1", "1.",
    # "0.5", ".5"). The previous pattern ([\.|\d]+) put `|` inside the
    # character class, so it matched a literal pipe and strings such as
    # "1.2.3", making float() below raise ValueError. Malformed tokens are
    # now simply left in the prompt instead of crashing.
    lora_expr = compile(r"\<lora:(\w+):(\d+\.?\d*|\.\d+)\>")

    loras = []
    next_match = lora_expr.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found LoRA token in prompt: %s", next_match)
        name, weight = next_match.groups()
        loras.append((name, float(weight)))

        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = lora_expr.search(remaining_prompt)

    return (remaining_prompt, loras)
|
||||||
|
|
Loading…
Reference in New Issue