feat(api): parse LoRA names from prompt
This commit is contained in:
parent
03f4e1b922
commit
143904fc51
|
@ -1,5 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
from re import compile
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -106,11 +107,21 @@ def get_tile_latents(
|
||||||
return full_latents[:, :, y:yt, x:xt]
|
return full_latents[:, :, y:yt, x:xt]
|
||||||
|
|
||||||
|
|
||||||
def get_loras_from_prompt(prompt: str) -> List[str]:
|
def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
|
||||||
return [
|
remaining_prompt = prompt
|
||||||
"arch",
|
lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")
|
||||||
"glass",
|
|
||||||
]
|
loras = []
|
||||||
|
next_match = lora_expr.search(remaining_prompt)
|
||||||
|
while next_match is not None:
|
||||||
|
logger.debug("found LoRA token in prompt: %s", next_match)
|
||||||
|
name, weight = next_match.groups()
|
||||||
|
loras.append(name)
|
||||||
|
# remove this match and look for another
|
||||||
|
remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
|
||||||
|
next_match = lora_expr.search(remaining_prompt)
|
||||||
|
|
||||||
|
return (remaining_prompt, loras)
|
||||||
|
|
||||||
|
|
||||||
def optimize_pipeline(
|
def optimize_pipeline(
|
||||||
|
|
|
@ -28,7 +28,10 @@ def run_txt2img_pipeline(
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
loras = get_loras_from_prompt(params.prompt)
|
|
||||||
|
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||||
|
params.prompt = prompt
|
||||||
|
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
server,
|
server,
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
|
@ -101,7 +104,9 @@ def run_img2img_pipeline(
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
strength: float,
|
strength: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
loras = get_loras_from_prompt(params.prompt)
|
(prompt, loras) = get_loras_from_prompt(params.prompt)
|
||||||
|
params.prompt = prompt
|
||||||
|
|
||||||
pipe = load_pipeline(
|
pipe = load_pipeline(
|
||||||
server,
|
server,
|
||||||
OnnxStableDiffusionImg2ImgPipeline,
|
OnnxStableDiffusionImg2ImgPipeline,
|
||||||
|
|
Loading…
Reference in New Issue