support multi-stage prompts in prompt filter
This commit is contained in:
parent
bfdaae952f
commit
61272b9620
|
@ -1,9 +1,11 @@
|
|||
from logging import getLogger
|
||||
from random import randint
|
||||
from re import sub
|
||||
from typing import Optional
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
from ..diffusers.utils import split_prompt
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
|
@ -13,6 +15,10 @@ from .result import StageResult
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
LENGTH_MARGIN = 15
|
||||
RETRY_LIMIT = 5
|
||||
|
||||
|
||||
class TextPromptStage(BaseStage):
|
||||
max_tile = SizeChart.max
|
||||
|
||||
|
@ -26,28 +32,55 @@ class TextPromptStage(BaseStage):
|
|||
*,
|
||||
callback: Optional[ProgressCallback] = None,
|
||||
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
||||
exclude_tokens: Optional[str] = None,
|
||||
min_length: int = 75,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
device = worker.device.torch_str()
|
||||
gpt2_pipe = pipeline(
|
||||
text_pipe = pipeline(
|
||||
"text-generation",
|
||||
model=prompt_model,
|
||||
tokenizer="gpt2",
|
||||
device=device,
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
input = params.prompt
|
||||
max_length = len(input) + randint(60, 90)
|
||||
prompt_parts = split_prompt(params.prompt)
|
||||
prompt_results = []
|
||||
for prompt in prompt_parts:
|
||||
retries = 0
|
||||
while len(prompt) < min_length and retries < RETRY_LIMIT:
|
||||
max_length = len(prompt) + randint(
|
||||
min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
|
||||
)
|
||||
logger.debug(
|
||||
"generating new prompt with max length of %d from input prompt: %s",
|
||||
"extending input prompt to max length of %d from %s: %s",
|
||||
max_length,
|
||||
input,
|
||||
len(prompt),
|
||||
prompt,
|
||||
)
|
||||
|
||||
result = gpt2_pipe(input, max_length=max_length, num_return_sequences=1)
|
||||
result = text_pipe(
|
||||
prompt, max_length=max_length, num_return_sequences=1
|
||||
)
|
||||
prompt = result[0]["generated_text"].strip()
|
||||
logger.debug("replacing prompt with: %s", prompt)
|
||||
|
||||
params.prompt = prompt
|
||||
if exclude_tokens:
|
||||
logger.debug(
|
||||
"removing excluded tokens from prompt: %s", exclude_tokens
|
||||
)
|
||||
prompt = sub(exclude_tokens, "", prompt)
|
||||
|
||||
if retries >= RETRY_LIMIT:
|
||||
logger.warning(
|
||||
"failed to extend input prompt to min length of %d, ended up with %d: %s",
|
||||
min_length,
|
||||
len(prompt),
|
||||
prompt,
|
||||
)
|
||||
|
||||
prompt_results.append(prompt)
|
||||
|
||||
complete_prompt = " || ".join(prompt_results)
|
||||
logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt)
|
||||
params.prompt = complete_prompt
|
||||
return sources
|
||||
|
|
|
@ -475,9 +475,16 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
|
|||
return tile
|
||||
|
||||
|
||||
def split_prompt(prompt: str) -> List[str]:
|
||||
if "||" in prompt:
|
||||
return prompt.split("||")
|
||||
|
||||
return [prompt]
|
||||
|
||||
|
||||
def slice_prompt(prompt: str, slice: int) -> str:
|
||||
if "||" in prompt:
|
||||
parts = prompt.split("||")
|
||||
parts = split_prompt(prompt)
|
||||
return parts[min(slice, len(parts) - 1)]
|
||||
else:
|
||||
return prompt
|
||||
|
|
Loading…
Reference in New Issue