feat(api): load Textual Inversions from prompt
This commit is contained in:
parent
829cedc934
commit
a2e21f427f
|
@ -156,7 +156,7 @@ def load_pipeline(
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
lpw: bool,
|
lpw: bool,
|
||||||
inversion: Optional[str],
|
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||||
loras: Optional[List[Tuple[str, float]]] = None,
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
):
|
):
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
|
@ -166,7 +166,7 @@ def load_pipeline(
|
||||||
device.device,
|
device.device,
|
||||||
device.provider,
|
device.provider,
|
||||||
lpw,
|
lpw,
|
||||||
inversion,
|
inversions,
|
||||||
loras,
|
loras,
|
||||||
)
|
)
|
||||||
scheduler_key = (scheduler_name, model)
|
scheduler_key = (scheduler_name, model)
|
||||||
|
@ -215,19 +215,21 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if inversion is not None:
|
if inversions is not None and len(inversions) > 0:
|
||||||
logger.debug("loading text encoder from %s", inversion)
|
inversion, inversion_weight = inversions[0]
|
||||||
|
logger.debug("loading Textual Inversion from %s", inversion)
|
||||||
|
# TODO: blend the inversion models
|
||||||
components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
|
components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
|
||||||
path.join(inversion, "text_encoder"),
|
path.join(server.model_path, inversion, "text_encoder"),
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
components["tokenizer"] = CLIPTokenizer.from_pretrained(
|
components["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||||
path.join(inversion, "tokenizer"),
|
path.join(server.model_path, inversion, "tokenizer"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# test LoRA blending
|
# test LoRA blending
|
||||||
if len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
|
lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||||
|
|
|
@ -15,7 +15,7 @@ from ..upscale import run_upscale_correction
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .load import get_latents_from_seed, load_pipeline
|
from .load import get_latents_from_seed, load_pipeline
|
||||||
from .utils import get_loras_from_prompt
|
from .utils import get_inversions_from_prompt, get_loras_from_prompt
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ def run_txt2img_pipeline(
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
|
|
||||||
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||||
|
(prompt, inversions) = get_inversions_from_prompt(prompt)
|
||||||
params.prompt = prompt
|
params.prompt = prompt
|
||||||
|
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
|
@ -40,7 +41,7 @@ def run_txt2img_pipeline(
|
||||||
params.scheduler,
|
params.scheduler,
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
params.lpw,
|
params.lpw,
|
||||||
params.inversion,
|
inversions,
|
||||||
loras,
|
loras,
|
||||||
)
|
)
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
|
@ -106,6 +107,7 @@ def run_img2img_pipeline(
|
||||||
strength: float,
|
strength: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||||
|
(prompt, inversions) = get_inversions_from_prompt(prompt)
|
||||||
params.prompt = prompt
|
params.prompt = prompt
|
||||||
|
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
|
@ -115,7 +117,7 @@ def run_img2img_pipeline(
|
||||||
params.scheduler,
|
params.scheduler,
|
||||||
job.get_device(),
|
job.get_device(),
|
||||||
params.lpw,
|
params.lpw,
|
||||||
params.inversion,
|
inversions,
|
||||||
loras,
|
loras,
|
||||||
)
|
)
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from re import compile
|
from re import compile, Pattern
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -9,8 +9,10 @@ from diffusers import OnnxStableDiffusionPipeline
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
INVERSION_TOKEN = compile(r"\<inversion:(\w+):([.\d]+)\>")
|
||||||
|
LORA_TOKEN = compile(r"\<lora:(\w+):([.\d]+)\>")
|
||||||
MAX_TOKENS_PER_GROUP = 77
|
MAX_TOKENS_PER_GROUP = 77
|
||||||
PATTERN_RANGE = compile("(\\w+)-{(\\d+),(\\d+)(?:,(\\d+))?}")
|
PATTERN_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
|
|
||||||
|
|
||||||
def expand_prompt_ranges(prompt: str) -> str:
|
def expand_prompt_ranges(prompt: str) -> str:
|
||||||
|
@ -130,18 +132,28 @@ def expand_prompt(
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
|
def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
|
||||||
|
"""
|
||||||
|
TODO: replace with Arpeggio
|
||||||
|
"""
|
||||||
remaining_prompt = prompt
|
remaining_prompt = prompt
|
||||||
lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")
|
|
||||||
|
|
||||||
loras = []
|
tokens = []
|
||||||
next_match = lora_expr.search(remaining_prompt)
|
next_match = pattern.search(remaining_prompt)
|
||||||
while next_match is not None:
|
while next_match is not None:
|
||||||
logger.debug("found LoRA token in prompt: %s", next_match)
|
logger.debug("found token in prompt: %s", next_match)
|
||||||
name, weight = next_match.groups()
|
name, weight = next_match.groups()
|
||||||
loras.append((name, float(weight)))
|
tokens.append((name, float(weight)))
|
||||||
# remove this match and look for another
|
# remove this match and look for another
|
||||||
remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
|
remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
|
||||||
next_match = lora_expr.search(remaining_prompt)
|
next_match = pattern.search(remaining_prompt)
|
||||||
|
|
||||||
return (remaining_prompt, loras)
|
return (remaining_prompt, tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
|
||||||
|
return get_tokens_from_prompt(prompt, LORA_TOKEN)
|
||||||
|
|
||||||
|
|
||||||
|
def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
|
||||||
|
return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
|
||||||
|
|
Loading…
Reference in New Issue