From a2e21f427f70f010a7fff40b7db4b724eda66df8 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Wed, 15 Mar 2023 08:51:12 -0500
Subject: [PATCH] feat(api): load Textual Inversions from prompt

---
 api/onnx_web/diffusers/load.py  | 16 +++++++++-------
 api/onnx_web/diffusers/run.py   |  8 +++++---
 api/onnx_web/diffusers/utils.py | 32 ++++++++++++++++++++++----------
 3 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index f9cd218b..cf0462a9 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -156,7 +156,7 @@ def load_pipeline(
     scheduler_name: str,
     device: DeviceParams,
     lpw: bool,
-    inversion: Optional[str],
+    inversions: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
 ):
     loras = loras or []
@@ -166,7 +166,7 @@ def load_pipeline(
         device.device,
         device.provider,
         lpw,
-        inversion,
+        inversions,
         loras,
     )
     scheduler_key = (scheduler_name, model)
@@ -215,19 +215,21 @@ def load_pipeline(
         )
     }

-    if inversion is not None:
-        logger.debug("loading text encoder from %s", inversion)
+    if inversions is not None and len(inversions) > 0:
+        inversion, _weight = inversions[0]
+        logger.debug("loading Textual Inversion from %s", inversion)
+        # TODO: blend the inversion models; only the first one is loaded for now
         components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
-            path.join(inversion, "text_encoder"),
+            path.join(server.model_path, inversion, "text_encoder"),
             provider=device.ort_provider(),
             sess_options=device.sess_options(),
         )
         components["tokenizer"] = CLIPTokenizer.from_pretrained(
-            path.join(inversion, "tokenizer"),
+            path.join(server.model_path, inversion, "tokenizer"),
         )

     # test LoRA blending
-    if len(loras) > 0:
+    if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
         lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
         logger.info("blending base model %s with LoRA models: %s", model, lora_models)
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index dd7320ee..5a14173b 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -15,7 +15,7 @@
 from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from ..worker import WorkerContext
 from .load import get_latents_from_seed, load_pipeline
-from .utils import get_loras_from_prompt
+from .utils import get_inversions_from_prompt, get_loras_from_prompt

 logger = getLogger(__name__)
@@ -31,6 +31,7 @@ def run_txt2img_pipeline(
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)

     (prompt, loras) = get_loras_from_prompt(params.prompt)
+    (prompt, inversions) = get_inversions_from_prompt(prompt)
     params.prompt = prompt

     pipe = load_pipeline(
@@ -40,7 +41,7 @@ def run_txt2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        params.inversion,
+        inversions,
         loras,
     )
     progress = job.get_progress_callback()
@@ -106,6 +107,7 @@ def run_img2img_pipeline(
     strength: float,
 ) -> None:
     (prompt, loras) = get_loras_from_prompt(params.prompt)
+    (prompt, inversions) = get_inversions_from_prompt(prompt)
     params.prompt = prompt

     pipe = load_pipeline(
@@ -115,7 +117,7 @@ def run_img2img_pipeline(
         params.scheduler,
         job.get_device(),
         params.lpw,
-        params.inversion,
+        inversions,
         loras,
     )
     progress = job.get_progress_callback()
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index c122225a..e1d34303 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -1,6 +1,6 @@
 from logging import getLogger
 from math import ceil
-from re import compile
+from re import compile, Pattern
 from typing import List, Optional, Tuple

 import numpy as np
@@ -9,8 +9,10 @@
 from diffusers import OnnxStableDiffusionPipeline

 logger = getLogger(__name__)

+INVERSION_TOKEN = compile(r"\<inversion:(\w+):([\d.]+)\>")
+LORA_TOKEN = compile(r"\<lora:(\w+):([\d.]+)\>")
 MAX_TOKENS_PER_GROUP = 77
-PATTERN_RANGE = compile("(\\w+)-{(\\d+),(\\d+)(?:,(\\d+))?}")
+PATTERN_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")

 def expand_prompt_ranges(prompt: str) -> str:
@@ -130,18 +132,28 @@
     return prompt_embeds


-def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
+    """
+    TODO: replace with Arpeggio
+    """
     remaining_prompt = prompt
-    lora_expr = compile(r"\<lora:(\w+):([\d.]+)\>")
-    loras = []
-    next_match = lora_expr.search(remaining_prompt)
+    tokens = []
+    next_match = pattern.search(remaining_prompt)
     while next_match is not None:
-        logger.debug("found LoRA token in prompt: %s", next_match)
+        logger.debug("found token in prompt: %s", next_match)
         name, weight = next_match.groups()
-        loras.append((name, float(weight)))
+        tokens.append((name, float(weight)))

         # remove this match and look for another
         remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
-        next_match = lora_expr.search(remaining_prompt)
+        next_match = pattern.search(remaining_prompt)

-    return (remaining_prompt, loras)
+    return (remaining_prompt, tokens)
+
+
+def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+    return get_tokens_from_prompt(prompt, LORA_TOKEN)
+
+
+def get_inversions_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
+    return get_tokens_from_prompt(prompt, INVERSION_TOKEN)
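
A minimal usage sketch of the new prompt parsing, assuming the
<lora:name:weight> and <inversion:name:weight> token syntax matched by
the patterns above; the prompt, model names, and weights here are
illustrative, not taken from the patch:

    from onnx_web.diffusers.utils import (
        get_inversions_from_prompt,
        get_loras_from_prompt,
    )

    prompt = "a castle on a hill <lora:detail:0.8> <inversion:gothic:1.0>"

    # strip LoRA tokens first, then inversion tokens, mirroring run.py
    prompt, loras = get_loras_from_prompt(prompt)
    prompt, inversions = get_inversions_from_prompt(prompt)

    print(loras)       # [('detail', 0.8)]
    print(inversions)  # [('gothic', 1.0)]
    print(prompt)      # "a castle on a hill  ", with both tokens removed

Each parser returns the prompt with its tokens stripped, so the text
sent to the text encoder never contains the angle-bracket syntax, while
the (name, weight) pairs are passed to load_pipeline separately.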