
support multi-stage prompts in prompt filter

Sean Sube 2024-02-11 19:04:06 -06:00
parent bfdaae952f
commit 61272b9620
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 54 additions and 14 deletions
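
With this change, a prompt that contains the "||" stage separator is split into parts, each part is extended by the text-generation model until it reaches min_length (bounded by RETRY_LIMIT), optional exclude_tokens are stripped out, and the parts are re-joined with " || ". A minimal sketch of that flow, using a stub generator in place of the transformers pipeline (fake_generate, extend_prompt, and the sample prompt are illustrative only, not part of the commit):

from random import randint
from typing import List

LENGTH_MARGIN = 15
RETRY_LIMIT = 5


def split_prompt(prompt: str) -> List[str]:
    # mirrors the helper this commit adds to diffusers/utils.py
    return prompt.split("||") if "||" in prompt else [prompt]


def fake_generate(prompt: str, max_length: int) -> str:
    # stand-in for pipeline("text-generation", ...)[0]["generated_text"]
    while len(prompt) < max_length:
        prompt += ", highly detailed, sharp focus"
    return prompt[:max_length]


def extend_prompt(prompt: str, min_length: int = 75) -> str:
    # keep extending one part until it is long enough or retries run out
    retries = 0
    while len(prompt) < min_length and retries < RETRY_LIMIT:
        max_length = len(prompt) + randint(
            min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
        )
        prompt = fake_generate(prompt, max_length).strip()
        retries += 1
    return prompt


parts = [extend_prompt(part.strip()) for part in split_prompt("a cat || a dog")]
print(" || ".join(parts))  # each stage extended on its own, then re-joined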

View File

@@ -1,9 +1,11 @@
 from logging import getLogger
 from random import randint
+from re import sub
 from typing import Optional
 
 from transformers import pipeline
 
+from ..diffusers.utils import split_prompt
 from ..params import ImageParams, SizeChart, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
@@ -13,6 +15,10 @@ from .result import StageResult
 
 logger = getLogger(__name__)
 
+LENGTH_MARGIN = 15
+RETRY_LIMIT = 5
+
+
 class TextPromptStage(BaseStage):
     max_tile = SizeChart.max
@@ -26,28 +32,55 @@ class TextPromptStage(BaseStage):
         *,
         callback: Optional[ProgressCallback] = None,
         prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
+        exclude_tokens: Optional[str] = None,
+        min_length: int = 75,
         **kwargs,
     ) -> StageResult:
         device = worker.device.torch_str()
-        gpt2_pipe = pipeline(
+        text_pipe = pipeline(
             "text-generation",
             model=prompt_model,
-            tokenizer="gpt2",
             device=device,
             framework="pt",
         )
 
-        input = params.prompt
-        max_length = len(input) + randint(60, 90)
-        logger.debug(
-            "generating new prompt with max length of %d from input prompt: %s",
-            max_length,
-            input,
-        )
+        prompt_parts = split_prompt(params.prompt)
+        prompt_results = []
+        for prompt in prompt_parts:
+            retries = 0
+            while len(prompt) < min_length and retries < RETRY_LIMIT:
+                max_length = len(prompt) + randint(
+                    min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
+                )
+                logger.debug(
+                    "extending input prompt to max length of %d from %s: %s",
+                    max_length,
+                    len(prompt),
+                    prompt,
+                )
 
-        result = gpt2_pipe(input, max_length=max_length, num_return_sequences=1)
-        prompt = result[0]["generated_text"].strip()
-        logger.debug("replacing prompt with: %s", prompt)
-        params.prompt = prompt
+                result = text_pipe(
+                    prompt, max_length=max_length, num_return_sequences=1
+                )
+                prompt = result[0]["generated_text"].strip()
+
+            if exclude_tokens:
+                logger.debug(
+                    "removing excluded tokens from prompt: %s", exclude_tokens
+                )
+                prompt = sub(exclude_tokens, "", prompt)
+
+            if retries >= RETRY_LIMIT:
+                logger.warning(
+                    "failed to extend input prompt to min length of %d, ended up with %d: %s",
+                    min_length,
+                    len(prompt),
+                    prompt,
+                )
+
+            prompt_results.append(prompt)
+
+        complete_prompt = " || ".join(prompt_results)
+        logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt)
+        params.prompt = complete_prompt
 
         return sources
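
Note that the new exclude_tokens argument is passed straight into re.sub, so it is interpreted as a regular expression rather than a literal token list. A quick illustration (the pattern and prompt below are made up):

from re import sub

generated = "a cat, trending on artstation, nsfw, 8k"
cleaned = sub(r"\s*nsfw,?", "", generated)
print(cleaned)  # "a cat, trending on artstation, 8k"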

View File

@@ -475,9 +475,16 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
     return tile
 
 
+def split_prompt(prompt: str) -> List[str]:
+    if "||" in prompt:
+        return prompt.split("||")
+
+    return [prompt]
+
+
 def slice_prompt(prompt: str, slice: int) -> str:
     if "||" in prompt:
-        parts = prompt.split("||")
+        parts = split_prompt(prompt)
         return parts[min(slice, len(parts) - 1)]
     else:
         return prompt
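
For reference, how the two helpers behave on a multi-stage prompt (values are illustrative; note that whitespace around the "||" separator is preserved by the split):

print(split_prompt("a cat || a dog"))     # ['a cat ', ' a dog']
print(split_prompt("a cat"))              # ['a cat']
print(slice_prompt("a cat || a dog", 0))  # 'a cat '
print(slice_prompt("a cat || a dog", 9))  # ' a dog' (index clamps to the last part)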