
feat(api): parse LoRA weights from prompt

Sean Sube 2023-03-15 08:30:31 -05:00
parent 45166f281e
commit a7f77a033d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 50 additions and 60 deletions

View File

@@ -61,7 +61,7 @@ def fix_node_name(key: str):
 def merge_lora(
     base_name: str,
-    lora_names: str,
+    lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
     lora_weights: "np.NDArray[np.float64]" = None,
 ):
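For reference, a hedged sketch of a call under the widened signature; the model and LoRA paths below are placeholders, not files from this commit:

    # Hypothetical call: a list of LoRA files plus one weight per file,
    # rather than the old single string.
    blended = merge_lora(
        "/models/stable-diffusion/unet/model.onnx",  # base_name
        [
            "/models/lora/style-a.safetensors",
            "/models/lora/style-b.safetensors",
        ],  # lora_names
        "unet",  # dest_type
        lora_weights=[0.5, 0.8],  # one weight per LoRA
    )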

View File

@@ -37,7 +37,7 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler

-from ..convert.diffusion.lora import buffer_external_data_tensors, merge_lora
+from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -107,26 +107,6 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]


-def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
-    remaining_prompt = prompt
-    lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")
-    loras = []
-
-    next_match = lora_expr.search(remaining_prompt)
-    while next_match is not None:
-        logger.debug("found LoRA token in prompt: %s", next_match)
-        name, weight = next_match.groups()
-        loras.append(name)
-
-        # remove this match and look for another
-        remaining_prompt = (
-            remaining_prompt[: next_match.start()]
-            + remaining_prompt[next_match.end() :]
-        )
-        next_match = lora_expr.search(remaining_prompt)
-
-    return (remaining_prompt, loras)
-
 def optimize_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -177,7 +157,7 @@ def load_pipeline(
     device: DeviceParams,
     lpw: bool,
     inversion: Optional[str],
-    loras: Optional[List[str]] = None,
+    loras: Optional[List[Tuple[str, float]]] = None,
 ):
     loras = loras or []
     pipe_key = (
@@ -247,23 +227,17 @@ def load_pipeline(
         )

         # test LoRA blending
-        lora_models = [
-            path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras
-        ]
+        lora_names, lora_weights = zip(*loras)
+        lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
         logger.info("blending base model %s with LoRA models: %s", model, lora_models)

-        if len(lora_models) > 0:
-            # blend and load text encoder
-            blended_text_encoder = merge_lora(
-                path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder"
-            )
-            (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(
-                blended_text_encoder
-            )
-            text_encoder_names, text_encoder_values = zip(*text_encoder_data)
-            text_encoder_opts = SessionOptions()
-            text_encoder_opts.add_external_initializers(
-                list(text_encoder_names), list(text_encoder_values)
-            )
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder_model.SerializeToString(),
+        # blend and load text encoder
+        blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder", lora_weights=lora_weights)
+        (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
+        text_encoder_opts = SessionOptions()
+        text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder_model.SerializeToString(),
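One behavior change worth flagging in this hunk: the old `if len(lora_models) > 0:` guard is gone, and tuple unpacking over `zip(*loras)` fails on an empty list. A standalone sketch of the pattern and its edge case (the example values are made up):

    # The (name, weight) pairs parsed from the prompt unpack into two
    # parallel tuples; the values here are hypothetical.
    loras = [("arcane", 0.5), ("pastel", 0.8)]
    lora_names, lora_weights = zip(*loras)
    print(lora_names)    # ('arcane', 'pastel')
    print(lora_weights)  # (0.5, 0.8)

    # Edge case: zip(*[]) yields nothing to unpack, so an empty loras
    # list raises "ValueError: not enough values to unpack"; a guard
    # like the removed len() check would avoid it.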
@@ -273,9 +247,7 @@ def load_pipeline(
-            )
-
-            # blend and load unet
-            blended_unet = merge_lora(
-                path.join(model, "unet", "model.onnx"), lora_models, "unet"
-            )
-            (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
-            unet_names, unet_values = zip(*unet_data)
-            unet_opts = SessionOptions()
+        )
+
+        # blend and load unet
+        blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
+        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
+        unet_names, unet_values = zip(*unet_data)
+        unet_opts = SessionOptions()
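For context, the SessionOptions work in these hunks exists because the blended initializers live only in memory, not in external data files on disk. A minimal standalone sketch of the same onnxruntime mechanism, assuming a serialized model whose large tensors were externalized (the function name and tensor dict are hypothetical):

    import numpy as np
    from onnxruntime import InferenceSession, OrtValue, SessionOptions

    def load_with_buffered_tensors(model_bytes: bytes, tensors: dict) -> InferenceSession:
        # tensors maps external-initializer names to in-memory numpy arrays
        names = list(tensors.keys())
        values = [OrtValue.ortvalue_from_numpy(arr) for arr in tensors.values()]

        opts = SessionOptions()
        # register the buffered tensors so the session can resolve them
        # without reading external data files from disk
        opts.add_external_initializers(names, values)

        return InferenceSession(
            model_bytes, sess_options=opts, providers=["CPUExecutionProvider"]
        )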

View File

@@ -14,7 +14,8 @@ from ..server import ServerContext
 from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from ..worker import WorkerContext
-from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline
+from .load import get_latents_from_seed, load_pipeline
+from .utils import get_loras_from_prompt

 logger = getLogger(__name__)

View File

@@ -1,7 +1,7 @@
 from logging import getLogger
 from math import ceil
 from re import compile
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import numpy as np
 from diffusers import OnnxStableDiffusionPipeline
@@ -128,3 +128,20 @@ def expand_prompt(
     logger.debug("expanded prompt shape: %s", prompt_embeds.shape)
     return prompt_embeds
+
+
+def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+    remaining_prompt = prompt
+    lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")
+    loras = []
+
+    next_match = lora_expr.search(remaining_prompt)
+    while next_match is not None:
+        logger.debug("found LoRA token in prompt: %s", next_match)
+        name, weight = next_match.groups()
+        loras.append((name, float(weight)))
+
+        # remove this match and look for another
+        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        next_match = lora_expr.search(remaining_prompt)
+
+    return (remaining_prompt, loras)
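To see the parser end to end, a standalone run of the same logic (the prompt text is made up; the regex is copied from the function above). Note the character class `[\.|\d]` also admits a literal `|`, which would later make `float()` raise, so weights are expected to be plain decimals:

    from re import compile

    lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")

    prompt = "a portrait, <lora:arcane:0.5> detailed, <lora:pastel:0.8>"
    remaining = prompt
    loras = []

    match = lora_expr.search(remaining)
    while match is not None:
        name, weight = match.groups()
        loras.append((name, float(weight)))
        # cut the matched token out and search the shortened prompt
        remaining = remaining[:match.start()] + remaining[match.end():]
        match = lora_expr.search(remaining)

    print(loras)      # [('arcane', 0.5), ('pastel', 0.8)]
    print(remaining)  # 'a portrait,  detailed, '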