diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 06136a1c..ee1719c5 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -1,5 +1,6 @@
 from logging import getLogger
 from os import path
+from re import compile
 from typing import Any, List, Optional, Tuple
 
 import numpy as np
@@ -106,11 +107,21 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
-def get_loras_from_prompt(prompt: str) -> List[str]:
-    return [
-        "arch",
-        "glass",
-    ]
+def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
+    remaining_prompt = prompt
+    lora_expr = compile(r"\<lora:([-\w]+):([\.|\d]+)\>")
+
+    loras = []
+    next_match = lora_expr.search(remaining_prompt)
+    while next_match is not None:
+        logger.debug("found LoRA token in prompt: %s", next_match)
+        name, weight = next_match.groups()
+        loras.append(name)
+        # remove this match and look for another
+        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        next_match = lora_expr.search(remaining_prompt)
+
+    return (remaining_prompt, loras)
 
 
 def optimize_pipeline(
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 95c3380c..3d5cb935 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -28,7 +28,10 @@ def run_txt2img_pipeline(
     upscale: UpscaleParams,
 ) -> None:
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
-    loras = get_loras_from_prompt(params.prompt)
+
+    (prompt, loras) = get_loras_from_prompt(params.prompt)
+    params.prompt = prompt
+
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionPipeline,
@@ -101,7 +104,9 @@ def run_img2img_pipeline(
     source: Image.Image,
     strength: float,
 ) -> None:
-    loras = get_loras_from_prompt(params.prompt)
+    (prompt, loras) = get_loras_from_prompt(params.prompt)
+    params.prompt = prompt
+
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionImg2ImgPipeline,