
feat(api): load Textual Inversions from prompt

Sean Sube 2023-03-15 08:51:12 -05:00
parent 829cedc934
commit a2e21f427f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 36 additions and 20 deletions
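
Based on the INVERSION_TOKEN and LORA_TOKEN patterns added in the utils diff below, prompts can now carry <inversion:name:weight> tokens alongside the existing <lora:name:weight> tokens. A minimal, hypothetical example prompt (the model names here are made up, not real models):

    # hypothetical prompt; "night_sky" and "detail_tweaker" are stand-in model names
    prompt = "a city skyline at dusk <lora:detail_tweaker:0.8> <inversion:night_sky:1.0>"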

View File

@@ -156,7 +156,7 @@ def load_pipeline(
     scheduler_name: str,
     device: DeviceParams,
     lpw: bool,
-    inversion: Optional[str],
+    inversions: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
 ):
     loras = loras or []
@@ -166,7 +166,7 @@ def load_pipeline(
         device.device,
         device.provider,
         lpw,
-        inversion,
+        inversions,
         loras,
     )
     scheduler_key = (scheduler_name, model)
@@ -215,19 +215,21 @@ def load_pipeline(
             )
         }

-        if inversion is not None:
-            logger.debug("loading text encoder from %s", inversion)
+        if inversions is not None and len(inversions) > 0:
+            inversion, _weight = inversions[0]
+            logger.debug("loading Textual Inversion from %s", inversion)
+            # TODO: blend the inversion models
             components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
-                path.join(inversion, "text_encoder"),
+                path.join(server.model_path, inversion, "text_encoder"),
                 provider=device.ort_provider(),
                 sess_options=device.sess_options(),
             )
             components["tokenizer"] = CLIPTokenizer.from_pretrained(
-                path.join(inversion, "tokenizer"),
+                path.join(server.model_path, inversion, "tokenizer"),
             )

         # test LoRA blending
-        if len(loras) > 0:
+        if loras is not None and len(loras) > 0:
             lora_names, lora_weights = zip(*loras)
             lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
             logger.info("blending base model %s with LoRA models: %s", model, lora_models)
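
A minimal sketch of how the updated load_pipeline resolves Textual Inversion paths once the (name, weight) pairs arrive from the prompt. Everything here is a hypothetical stand-in (model_path for server.model_path, night_sky for a real inversion model); only the path handling from the hunk above is reproduced:

    from os import path

    model_path = "/models"             # assumption: stand-in for server.model_path
    inversions = [("night_sky", 1.0)]  # (name, weight) pairs parsed from the prompt

    if inversions is not None and len(inversions) > 0:
        # only the first inversion is applied for now; blending is still a TODO
        inversion, _weight = inversions[0]
        print(path.join(model_path, inversion, "text_encoder"))  # /models/night_sky/text_encoder
        print(path.join(model_path, inversion, "tokenizer"))     # /models/night_sky/tokenizer

The text_encoder folder is loaded with OnnxRuntimeModel and the tokenizer folder with CLIPTokenizer, both relative to the server's model path rather than to the bare inversion name.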

View File

@@ -15,7 +15,7 @@ from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .load import get_latents_from_seed, load_pipeline
-from .utils import get_loras_from_prompt
+from .utils import get_inversions_from_prompt, get_loras_from_prompt

 logger = getLogger(__name__)
@@ -31,6 +31,7 @@ def run_txt2img_pipeline(
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     (prompt, loras) = get_loras_from_prompt(params.prompt)
+    (prompt, inversions) = get_inversions_from_prompt(prompt)
     params.prompt = prompt

     pipe = load_pipeline(
@@ -40,7 +41,7 @@ def run_txt2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        params.inversion,
+        inversions,
         loras,
     )
     progress = job.get_progress_callback()
@@ -106,6 +107,7 @@ def run_img2img_pipeline(
     strength: float,
 ) -> None:
     (prompt, loras) = get_loras_from_prompt(params.prompt)
+    (prompt, inversions) = get_inversions_from_prompt(prompt)
     params.prompt = prompt

     pipe = load_pipeline(
@@ -115,7 +117,7 @@ def run_img2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        params.inversion,
+        inversions,
         loras,
     )
     progress = job.get_progress_callback()

View File

@@ -1,6 +1,6 @@
 from logging import getLogger
 from math import ceil
-from re import compile
+from re import compile, Pattern
 from typing import List, Optional, Tuple

 import numpy as np
@@ -9,8 +9,10 @@ from diffusers import OnnxStableDiffusionPipeline
 logger = getLogger(__name__)

+INVERSION_TOKEN = compile(r"\<inversion:(\w+):([\.|\d]+)\>")
+LORA_TOKEN = compile(r"\<lora:(\w+):([\.|\d]+)\>")
 MAX_TOKENS_PER_GROUP = 77
-PATTERN_RANGE = compile("(\\w+)-{(\\d+),(\\d+)(?:,(\\d+))?}")
+PATTERN_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")


 def expand_prompt_ranges(prompt: str) -> str:
@@ -130,18 +132,28 @@ def expand_prompt(
     return prompt_embeds


-def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
     """
     TODO: replace with Arpeggio
     """
     remaining_prompt = prompt
-    lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")
-    loras = []
-    next_match = lora_expr.search(remaining_prompt)
+    tokens = []
+    next_match = pattern.search(remaining_prompt)
     while next_match is not None:
-        logger.debug("found LoRA token in prompt: %s", next_match)
+        logger.debug("found token in prompt: %s", next_match)
         name, weight = next_match.groups()
-        loras.append((name, float(weight)))
+        tokens.append((name, float(weight)))

         # remove this match and look for another
         remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
-        next_match = lora_expr.search(remaining_prompt)
+        next_match = pattern.search(remaining_prompt)

-    return (remaining_prompt, loras)
+    return (remaining_prompt, tokens)
+
+
+def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+    return get_tokens_from_prompt(prompt, LORA_TOKEN)
+
+
+def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+    return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
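
A self-contained sketch of the new prompt parsing helpers, re-declaring the same patterns and loop from the diff above so it runs on its own; the prompt and model names are hypothetical:

    from re import compile, Pattern
    from typing import List, Tuple

    # same token patterns as the diff above
    INVERSION_TOKEN = compile(r"\<inversion:(\w+):([\.|\d]+)\>")
    LORA_TOKEN = compile(r"\<lora:(\w+):([\.|\d]+)\>")

    def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
        # strip each matching token from the prompt and collect (name, weight) pairs
        remaining_prompt = prompt
        tokens = []
        next_match = pattern.search(remaining_prompt)
        while next_match is not None:
            name, weight = next_match.groups()
            tokens.append((name, float(weight)))
            remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
            next_match = pattern.search(remaining_prompt)
        return (remaining_prompt, tokens)

    # hypothetical prompt with both token types
    prompt = "a city skyline at dusk <lora:detail_tweaker:0.8> <inversion:night_sky:1.0>"
    prompt, loras = get_tokens_from_prompt(prompt, LORA_TOKEN)
    prompt, inversions = get_tokens_from_prompt(prompt, INVERSION_TOKEN)

    print(loras)       # [('detail_tweaker', 0.8)]
    print(inversions)  # [('night_sky', 1.0)]
    print(prompt)      # the prompt with both tokens stripped out (extra spaces remain)

LoRA tokens are stripped first and inversion tokens are stripped from the remainder, matching the call order in the txt2img and img2img runners above, so params.prompt reaches the pipeline with both token types removed.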