
support multi-stage prompts in prompt filter

Sean Sube 2024-02-11 19:04:06 -06:00
parent bfdaae952f
commit 61272b9620
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 54 additions and 14 deletions


@@ -1,9 +1,11 @@
 from logging import getLogger
 from random import randint
+from re import sub
 from typing import Optional
 
 from transformers import pipeline
 
+from ..diffusers.utils import split_prompt
 from ..params import ImageParams, SizeChart, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
@@ -13,6 +15,10 @@ from .result import StageResult
 
 logger = getLogger(__name__)
 
+LENGTH_MARGIN = 15
+RETRY_LIMIT = 5
+
+
 class TextPromptStage(BaseStage):
     max_tile = SizeChart.max
@@ -26,28 +32,55 @@ class TextPromptStage(BaseStage):
         *,
         callback: Optional[ProgressCallback] = None,
         prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
+        exclude_tokens: Optional[str] = None,
+        min_length: int = 75,
         **kwargs,
     ) -> StageResult:
         device = worker.device.torch_str()
-        gpt2_pipe = pipeline(
+        text_pipe = pipeline(
             "text-generation",
             model=prompt_model,
             tokenizer="gpt2",
             device=device,
             framework="pt",
         )
 
-        input = params.prompt
-        max_length = len(input) + randint(60, 90)
-        logger.debug(
-            "generating new prompt with max length of %d from input prompt: %s",
-            max_length,
-            input,
-        )
+        prompt_parts = split_prompt(params.prompt)
+        prompt_results = []
+        for prompt in prompt_parts:
+            retries = 0
+            while len(prompt) < min_length and retries < RETRY_LIMIT:
+                max_length = len(prompt) + randint(
+                    min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
+                )
+                logger.debug(
+                    "extending input prompt to max length of %d from %s: %s",
+                    max_length,
+                    len(prompt),
+                    prompt,
+                )
 
-        result = gpt2_pipe(input, max_length=max_length, num_return_sequences=1)
-        prompt = result[0]["generated_text"].strip()
-        logger.debug("replacing prompt with: %s", prompt)
+                result = text_pipe(
+                    prompt, max_length=max_length, num_return_sequences=1
+                )
+                prompt = result[0]["generated_text"].strip()
 
-        params.prompt = prompt
+                if exclude_tokens:
+                    logger.debug(
+                        "removing excluded tokens from prompt: %s", exclude_tokens
+                    )
+                    prompt = sub(exclude_tokens, "", prompt)
+
+                retries += 1  # count this attempt toward RETRY_LIMIT
+
+            if retries >= RETRY_LIMIT:
+                logger.warning(
+                    "failed to extend input prompt to min length of %d, ended up with %d: %s",
+                    min_length,
+                    len(prompt),
+                    prompt,
+                )
+
+            prompt_results.append(prompt)
+
+        complete_prompt = " || ".join(prompt_results)
+        logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt)
+        params.prompt = complete_prompt
 
         return sources
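
Taken together, the hunks above change the stage from a single-prompt extension into a per-part loop: the prompt is split on ||, each part is extended by the text-generation pipeline until it reaches min_length or the retry limit, and the parts are rejoined. A minimal standalone sketch of that flow, with the transformers pipeline replaced by a stub (fake_pipe and extend_parts are illustrative names, not part of this commit):

from random import randint
from re import sub

LENGTH_MARGIN = 15
RETRY_LIMIT = 5


def fake_pipe(prompt, max_length=90, num_return_sequences=1):
    # stand-in for the gpt2 text-generation pipeline used by the stage
    return [{"generated_text": prompt + ", highly detailed, sharp focus"}]


def split_prompt(prompt):
    # mirrors the helper added in the second file below
    return prompt.split("||") if "||" in prompt else [prompt]


def extend_parts(prompt, min_length=75, exclude_tokens=None):
    results = []
    for part in split_prompt(prompt):
        retries = 0
        while len(part) < min_length and retries < RETRY_LIMIT:
            # aim a little above or below min_length, like the stage does
            max_length = len(part) + randint(
                min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
            )
            result = fake_pipe(part, max_length=max_length)
            part = result[0]["generated_text"].strip()
            if exclude_tokens:
                part = sub(exclude_tokens, "", part)
            retries += 1
        results.append(part)
    return " || ".join(results)


print(extend_parts("a cat || a dog"))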


@@ -475,9 +475,16 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
     return tile
 
 
+def split_prompt(prompt: str) -> List[str]:
+    if "||" in prompt:
+        return prompt.split("||")
+
+    return [prompt]
+
+
 def slice_prompt(prompt: str, slice: int) -> str:
     if "||" in prompt:
-        parts = prompt.split("||")
+        parts = split_prompt(prompt)
         return parts[min(slice, len(parts) - 1)]
     else:
         return prompt
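
For reference, split_prompt returns the per-stage parts of a ||-delimited prompt, and slice_prompt picks the part for a given stage index, clamping to the last part when the index runs past the end. A quick illustration with made-up prompts (note that splitting keeps the whitespace around ||):

split_prompt("a cat || a dog")     # ["a cat ", " a dog"]
split_prompt("a cat")              # ["a cat"]

slice_prompt("a cat || a dog", 0)  # "a cat "
slice_prompt("a cat || a dog", 9)  # " a dog" (clamped to the last part)
slice_prompt("a cat", 9)           # "a cat"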