1
0
Fork 0

feat(api): parse LoRA names from prompt

This commit is contained in:
Sean Sube 2023-03-14 22:28:18 -05:00
parent 03f4e1b922
commit 143904fc51
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 23 additions and 7 deletions

View File

@ -1,5 +1,6 @@
from logging import getLogger
from os import path
from re import compile
from typing import Any, List, Optional, Tuple
import numpy as np
@ -106,11 +107,21 @@ def get_tile_latents(
return full_latents[:, :, y:yt, x:xt]
def get_loras_from_prompt(prompt: str) -> List[str]:
    """Return the list of LoRA names to apply.

    Stub implementation: ignores *prompt* and always returns the same
    fixed pair of names.
    """
    names = [
        "arch",
        "glass",
    ]
    return names
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
    """Strip ``<lora:name:weight>`` tokens from *prompt*.

    Returns a tuple of ``(remaining_prompt, loras)``: the prompt with every
    LoRA token removed, and the LoRA names in order of appearance. The weight
    component is matched but not used yet.
    """
    # BUG FIX: the original class was [\.|\d]+, which also matched a literal
    # "|" — inside a character class the pipe is not alternation. Only digits
    # and "." are valid weight characters.
    lora_expr = compile(r"\<lora:(\w+):([\d.]+)\>")
    loras: List[str] = []
    remaining_prompt = prompt

    next_match = lora_expr.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found LoRA token in prompt: %s", next_match)
        name, _weight = next_match.groups()  # weight is intentionally unused
        loras.append(name)
        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = lora_expr.search(remaining_prompt)

    return (remaining_prompt, loras)
def optimize_pipeline(

View File

@ -28,7 +28,10 @@ def run_txt2img_pipeline(
upscale: UpscaleParams,
) -> None:
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
loras = get_loras_from_prompt(params.prompt)
(prompt, loras) = get_loras_from_prompt(params.prompt)
params.prompt = prompt
pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
@ -101,7 +104,9 @@ def run_img2img_pipeline(
source: Image.Image,
strength: float,
) -> None:
loras = get_loras_from_prompt(params.prompt)
(prompt, loras) = get_loras_from_prompt(params.prompt)
params.prompt = prompt
pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,