feat(api): load Textual Inversions from prompt
This commit is contained in:
parent
829cedc934
commit
a2e21f427f
|
@ -156,7 +156,7 @@ def load_pipeline(
|
|||
scheduler_name: str,
|
||||
device: DeviceParams,
|
||||
lpw: bool,
|
||||
inversion: Optional[str],
|
||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
):
|
||||
loras = loras or []
|
||||
|
@ -166,7 +166,7 @@ def load_pipeline(
|
|||
device.device,
|
||||
device.provider,
|
||||
lpw,
|
||||
inversion,
|
||||
inversions,
|
||||
loras,
|
||||
)
|
||||
scheduler_key = (scheduler_name, model)
|
||||
|
@ -215,19 +215,21 @@ def load_pipeline(
|
|||
)
|
||||
}
|
||||
|
||||
if inversion is not None:
|
||||
logger.debug("loading text encoder from %s", inversion)
|
||||
if inversions is not None and len(inversions) > 0:
|
||||
inversion = inversions[0]
|
||||
logger.debug("loading Textual Inversion from %s", inversion)
|
||||
# TODO: blend the inversion models
|
||||
components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
|
||||
path.join(inversion, "text_encoder"),
|
||||
path.join(server.model_path, inversion, "text_encoder"),
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
components["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
path.join(inversion, "tokenizer"),
|
||||
path.join(server.model_path, inversions, "tokenizer"),
|
||||
)
|
||||
|
||||
# test LoRA blending
|
||||
if len(loras) > 0:
|
||||
if loras is not None and len(loras) > 0:
|
||||
lora_names, lora_weights = zip(*loras)
|
||||
lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
|
||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||
|
|
|
@ -15,7 +15,7 @@ from ..upscale import run_upscale_correction
|
|||
from ..utils import run_gc
|
||||
from ..worker import WorkerContext
|
||||
from .load import get_latents_from_seed, load_pipeline
|
||||
from .utils import get_loras_from_prompt
|
||||
from .utils import get_inversions_from_prompt, get_loras_from_prompt
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -31,6 +31,7 @@ def run_txt2img_pipeline(
|
|||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
|
||||
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||
(prompt, inversions) = get_inversions_from_prompt(prompt)
|
||||
params.prompt = prompt
|
||||
|
||||
pipe = load_pipeline(
|
||||
|
@ -40,7 +41,7 @@ def run_txt2img_pipeline(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
inversions,
|
||||
loras,
|
||||
)
|
||||
progress = job.get_progress_callback()
|
||||
|
@ -106,6 +107,7 @@ def run_img2img_pipeline(
|
|||
strength: float,
|
||||
) -> None:
|
||||
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||
(prompt, inversions) = get_inversions_from_prompt(prompt)
|
||||
params.prompt = prompt
|
||||
|
||||
pipe = load_pipeline(
|
||||
|
@ -115,7 +117,7 @@ def run_img2img_pipeline(
|
|||
params.scheduler,
|
||||
job.get_device(),
|
||||
params.lpw,
|
||||
params.inversion,
|
||||
inversions,
|
||||
loras,
|
||||
)
|
||||
progress = job.get_progress_callback()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from math import ceil
|
||||
from re import compile
|
||||
from re import compile, Pattern
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
@ -9,8 +9,10 @@ from diffusers import OnnxStableDiffusionPipeline
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
INVERSION_TOKEN = compile(r"\<inversion:(\w+):([\.|\d]+)\>")
|
||||
LORA_TOKEN = compile(r"\<lora:(\w+):([\.|\d]+)\>")
|
||||
MAX_TOKENS_PER_GROUP = 77
|
||||
PATTERN_RANGE = compile("(\\w+)-{(\\d+),(\\d+)(?:,(\\d+))?}")
|
||||
PATTERN_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||
|
||||
|
||||
def expand_prompt_ranges(prompt: str) -> str:
|
||||
|
@ -130,18 +132,28 @@ def expand_prompt(
|
|||
return prompt_embeds
|
||||
|
||||
|
||||
def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
    """
    Extract weighted tokens like <lora:name:1.0> from a prompt.

    Collects every match of `pattern` as a (name, weight) pair, in order of
    appearance, and returns the prompt with all matches removed.

    TODO: replace with Arpeggio
    """
    tokens: List[Tuple[str, float]] = []
    for match in pattern.finditer(prompt):
        logger.debug("found token in prompt: %s", match)
        name, weight = match.groups()
        tokens.append((name, float(weight)))

    # strip every match in a single pass; avoids the O(n^2) repeated slicing
    # of the remove-and-rescan loop and never re-matches text synthesized by
    # joining the fragments around a removed token
    remaining_prompt = pattern.sub("", prompt)

    return (remaining_prompt, tokens)
||||
|
||||
|
||||
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    """Remove <lora:name:weight> tokens from the prompt and return them as (name, weight) pairs."""
    remaining_prompt, loras = get_tokens_from_prompt(prompt, LORA_TOKEN)
    return (remaining_prompt, loras)
|
||||
|
||||
|
||||
def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    """Remove <inversion:name:weight> tokens from the prompt and return them as (name, weight) pairs."""
    remaining_prompt, inversions = get_tokens_from_prompt(prompt, INVERSION_TOKEN)
    return (remaining_prompt, inversions)
|
||||
|
|
Loading…
Reference in New Issue