1
0
Fork 0

feat(api): parse LoRA names from prompt

This commit is contained in:
Sean Sube 2023-03-14 22:28:18 -05:00
parent 03f4e1b922
commit 143904fc51
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 23 additions and 7 deletions

View File

@@ -1,5 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from re import compile
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
import numpy as np import numpy as np
@@ -106,11 +107,21 @@ def get_tile_latents(
return full_latents[:, :, y:yt, x:xt] return full_latents[:, :, y:yt, x:xt]
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
    """Extract LoRA names from ``<lora:name:weight>`` tokens in a prompt.

    Each token is stripped from the prompt as it is found, so the text on
    either side of a token is joined directly (no whitespace normalization).

    :param prompt: the raw user prompt, possibly containing LoRA tokens
    :return: a tuple of (prompt with all LoRA tokens removed, list of LoRA
        names in the order they appeared)
    """
    # \w+ captures the LoRA name; [\d.]+ captures a decimal weight like "1.0".
    # NOTE: the previous class [\.|\d]+ also matched a literal "|" — inside a
    # character class "|" has no alternation meaning — which accepted garbage
    # weights such as "|1|". The "<"/">" escapes were redundant and dropped.
    lora_expr = compile(r"<lora:(\w+):([\d.]+)>")

    remaining_prompt = prompt
    loras: List[str] = []

    next_match = lora_expr.search(remaining_prompt)
    while next_match is not None:
        logger.debug("found LoRA token in prompt: %s", next_match)
        # weight is parsed but not used yet; kept as _weight for when LoRA
        # blending is implemented
        name, _weight = next_match.groups()
        loras.append(name)
        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = lora_expr.search(remaining_prompt)

    return (remaining_prompt, loras)
def optimize_pipeline( def optimize_pipeline(

View File

@@ -28,7 +28,10 @@ def run_txt2img_pipeline(
upscale: UpscaleParams, upscale: UpscaleParams,
) -> None: ) -> None:
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
loras = get_loras_from_prompt(params.prompt)
(prompt, loras) = get_loras_from_prompt(params.prompt)
params.prompt = prompt
pipe = load_pipeline( pipe = load_pipeline(
server, server,
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
@@ -101,7 +104,9 @@ def run_img2img_pipeline(
source: Image.Image, source: Image.Image,
strength: float, strength: float,
) -> None: ) -> None:
loras = get_loras_from_prompt(params.prompt) (prompt, loras) = get_loras_from_prompt(params.prompt)
params.prompt = prompt
pipe = load_pipeline( pipe = load_pipeline(
server, server,
OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionImg2ImgPipeline,