
feat(api): load Textual Inversions from prompt

commit a2e21f427f
parent 829cedc934
Sean Sube, 2023-03-15 08:51:12 -05:00
Signed by: ssube (GPG Key ID: 3EED7B957D362AF1)
3 changed files with 36 additions and 20 deletions
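
This change lets Textual Inversion models be selected directly from the prompt, using the same inline token syntax already used for LoRA networks. A minimal sketch of the syntax the new parser accepts, where the model names are hypothetical placeholders:

# hypothetical prompt using the tokens parsed by this commit;
# "cubex" and "analog" are placeholder model names, the weights are floats
prompt = "a castle on a hill, <inversion:cubex:1.0> <lora:analog:0.8>"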

File 1 of 3

@@ -156,7 +156,7 @@ def load_pipeline(
     scheduler_name: str,
     device: DeviceParams,
     lpw: bool,
-    inversion: Optional[str],
+    inversions: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
 ):
     loras = loras or []
@@ -166,7 +166,7 @@ def load_pipeline(
         device.device,
         device.provider,
         lpw,
-        inversion,
+        inversions,
         loras,
     )
     scheduler_key = (scheduler_name, model)
@@ -215,19 +215,21 @@ def load_pipeline(
             )
         }

-        if inversion is not None:
-            logger.debug("loading text encoder from %s", inversion)
+        if inversions is not None and len(inversions) > 0:
+            inversion, _weight = inversions[0]
+            logger.debug("loading Textual Inversion from %s", inversion)
+            # TODO: blend the inversion models
             components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
-                path.join(inversion, "text_encoder"),
+                path.join(server.model_path, inversion, "text_encoder"),
                 provider=device.ort_provider(),
                 sess_options=device.sess_options(),
             )
             components["tokenizer"] = CLIPTokenizer.from_pretrained(
-                path.join(inversion, "tokenizer"),
+                path.join(server.model_path, inversion, "tokenizer"),
             )

         # test LoRA blending
-        if len(loras) > 0:
+        if loras is not None and len(loras) > 0:
             lora_names, lora_weights = zip(*loras)
             lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
             logger.info("blending base model %s with LoRA models: %s", model, lora_models)

File 2 of 3

@@ -15,7 +15,7 @@ from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .load import get_latents_from_seed, load_pipeline
-from .utils import get_loras_from_prompt
+from .utils import get_inversions_from_prompt, get_loras_from_prompt

 logger = getLogger(__name__)
@@ -31,6 +31,7 @@ def run_txt2img_pipeline(
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     (prompt, loras) = get_loras_from_prompt(params.prompt)
+    (prompt, inversions) = get_inversions_from_prompt(prompt)
     params.prompt = prompt

     pipe = load_pipeline(
@@ -40,7 +41,7 @@ def run_txt2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        params.inversion,
+        inversions,
         loras,
     )
     progress = job.get_progress_callback()
@@ -106,6 +107,7 @@ def run_img2img_pipeline(
     strength: float,
 ) -> None:
     (prompt, loras) = get_loras_from_prompt(params.prompt)
+    (prompt, inversions) = get_inversions_from_prompt(prompt)
     params.prompt = prompt

     pipe = load_pipeline(
@@ -115,7 +117,7 @@ def run_img2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        params.inversion,
+        inversions,
         loras,
     )
     progress = job.get_progress_callback()
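
Both pipelines now strip prompt tokens in the same order: LoRA tokens come out of the raw prompt first, then inversion tokens out of the already-stripped text, and the cleaned string is written back to params.prompt so the markup never reaches the tokenizer. A short sketch of that flow with a hypothetical prompt, using the helpers defined in the utils module (the next file):

# get_loras_from_prompt and get_inversions_from_prompt come from the utils module below
prompt = "a castle on a hill, <lora:analog:0.8> <inversion:cubex:1.0>"

prompt, loras = get_loras_from_prompt(prompt)
prompt, inversions = get_inversions_from_prompt(prompt)

print(loras)       # [('analog', 0.8)]
print(inversions)  # [('cubex', 1.0)]
print(prompt)      # 'a castle on a hill,  ' (tokens removed, surrounding spaces remain)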

File 3 of 3

@@ -1,6 +1,6 @@
 from logging import getLogger
 from math import ceil
-from re import compile
+from re import compile, Pattern
 from typing import List, Optional, Tuple

 import numpy as np
@@ -9,8 +9,10 @@ from diffusers import OnnxStableDiffusionPipeline

 logger = getLogger(__name__)

+INVERSION_TOKEN = compile(r"\<inversion:(\w+):([\d.]+)\>")
+LORA_TOKEN = compile(r"\<lora:(\w+):([\d.]+)\>")
 MAX_TOKENS_PER_GROUP = 77
-PATTERN_RANGE = compile("(\\w+)-{(\\d+),(\\d+)(?:,(\\d+))?}")
+PATTERN_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")


 def expand_prompt_ranges(prompt: str) -> str:
def expand_prompt_ranges(prompt: str) -> str: def expand_prompt_ranges(prompt: str) -> str:
@@ -130,18 +132,28 @@
     return prompt_embeds


-def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
+    """
+    TODO: replace with Arpeggio
+    """
     remaining_prompt = prompt
-    lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")

-    loras = []
-    next_match = lora_expr.search(remaining_prompt)
+    tokens = []
+    next_match = pattern.search(remaining_prompt)
     while next_match is not None:
-        logger.debug("found LoRA token in prompt: %s", next_match)
+        logger.debug("found token in prompt: %s", next_match)
         name, weight = next_match.groups()
-        loras.append((name, float(weight)))
+        tokens.append((name, float(weight)))

         # remove this match and look for another
         remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
-        next_match = lora_expr.search(remaining_prompt)
+        next_match = pattern.search(remaining_prompt)

-    return (remaining_prompt, loras)
+    return (remaining_prompt, tokens)
+
+
+def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+    return get_tokens_from_prompt(prompt, LORA_TOKEN)
+
+
+def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+    return get_tokens_from_prompt(prompt, INVERSION_TOKEN)