From 143904fc519c4f3aeeb1336fce1b2fca3d4d42e8 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 14 Mar 2023 22:28:18 -0500 Subject: [PATCH] feat(api): parse LoRA names from prompt --- api/onnx_web/diffusers/load.py | 21 ++++++++++++++++----- api/onnx_web/diffusers/run.py | 9 +++++++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 06136a1c..ee1719c5 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -1,5 +1,6 @@ from logging import getLogger from os import path +from re import compile from typing import Any, List, Optional, Tuple import numpy as np @@ -106,11 +107,21 @@ def get_tile_latents( return full_latents[:, :, y:yt, x:xt] -def get_loras_from_prompt(prompt: str) -> List[str]: - return [ - "arch", - "glass", - ] +def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]: + remaining_prompt = prompt + lora_expr = compile(r"\") + + loras = [] + next_match = lora_expr.search(remaining_prompt) + while next_match is not None: + logger.debug("found LoRA token in prompt: %s", next_match) + name, weight = next_match.groups() + loras.append(name) + # remove this match and look for another + remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():] + next_match = lora_expr.search(remaining_prompt) + + return (remaining_prompt, loras) def optimize_pipeline( diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 95c3380c..3d5cb935 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -28,7 +28,10 @@ def run_txt2img_pipeline( upscale: UpscaleParams, ) -> None: latents = get_latents_from_seed(params.seed, size, batch=params.batch) - loras = get_loras_from_prompt(params.prompt) + + (prompt, loras) = get_loras_from_prompt(params.prompt) + params.prompt = prompt + pipe = load_pipeline( server, OnnxStableDiffusionPipeline, @@ -101,7 +104,9 @@ def run_img2img_pipeline( source: Image.Image, strength: float, ) -> None: - loras = get_loras_from_prompt(params.prompt) + (prompt, loras) = get_loras_from_prompt(params.prompt) + params.prompt = prompt + pipe = load_pipeline( server, OnnxStableDiffusionImg2ImgPipeline,