support multi-stage prompts in prompt filter
This commit is contained in:
parent
bfdaae952f
commit
61272b9620
|
@ -1,9 +1,11 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from random import randint
|
from random import randint
|
||||||
|
from re import sub
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
|
||||||
|
from ..diffusers.utils import split_prompt
|
||||||
from ..params import ImageParams, SizeChart, StageParams
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
@ -13,6 +15,10 @@ from .result import StageResult
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
LENGTH_MARGIN = 15
|
||||||
|
RETRY_LIMIT = 5
|
||||||
|
|
||||||
|
|
||||||
class TextPromptStage(BaseStage):
|
class TextPromptStage(BaseStage):
|
||||||
max_tile = SizeChart.max
|
max_tile = SizeChart.max
|
||||||
|
|
||||||
|
@ -26,28 +32,55 @@ class TextPromptStage(BaseStage):
|
||||||
*,
|
*,
|
||||||
callback: Optional[ProgressCallback] = None,
|
callback: Optional[ProgressCallback] = None,
|
||||||
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
||||||
|
exclude_tokens: Optional[str] = None,
|
||||||
|
min_length: int = 75,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
device = worker.device.torch_str()
|
device = worker.device.torch_str()
|
||||||
gpt2_pipe = pipeline(
|
text_pipe = pipeline(
|
||||||
"text-generation",
|
"text-generation",
|
||||||
model=prompt_model,
|
model=prompt_model,
|
||||||
tokenizer="gpt2",
|
|
||||||
device=device,
|
device=device,
|
||||||
framework="pt",
|
framework="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
input = params.prompt
|
prompt_parts = split_prompt(params.prompt)
|
||||||
max_length = len(input) + randint(60, 90)
|
prompt_results = []
|
||||||
|
for prompt in prompt_parts:
|
||||||
|
retries = 0
|
||||||
|
while len(prompt) < min_length and retries < RETRY_LIMIT:
|
||||||
|
max_length = len(prompt) + randint(
|
||||||
|
min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
|
||||||
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"generating new prompt with max length of %d from input prompt: %s",
|
"extending input prompt to max length of %d from %s: %s",
|
||||||
max_length,
|
max_length,
|
||||||
input,
|
len(prompt),
|
||||||
|
prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = gpt2_pipe(input, max_length=max_length, num_return_sequences=1)
|
result = text_pipe(
|
||||||
|
prompt, max_length=max_length, num_return_sequences=1
|
||||||
|
)
|
||||||
prompt = result[0]["generated_text"].strip()
|
prompt = result[0]["generated_text"].strip()
|
||||||
logger.debug("replacing prompt with: %s", prompt)
|
|
||||||
|
|
||||||
params.prompt = prompt
|
if exclude_tokens:
|
||||||
|
logger.debug(
|
||||||
|
"removing excluded tokens from prompt: %s", exclude_tokens
|
||||||
|
)
|
||||||
|
prompt = sub(exclude_tokens, "", prompt)
|
||||||
|
|
||||||
|
if retries >= RETRY_LIMIT:
|
||||||
|
logger.warning(
|
||||||
|
"failed to extend input prompt to min length of %d, ended up with %d: %s",
|
||||||
|
min_length,
|
||||||
|
len(prompt),
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_results.append(prompt)
|
||||||
|
|
||||||
|
complete_prompt = " || ".join(prompt_results)
|
||||||
|
logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt)
|
||||||
|
params.prompt = complete_prompt
|
||||||
return sources
|
return sources
|
||||||
|
|
|
@ -475,9 +475,16 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
|
||||||
return tile
|
return tile
|
||||||
|
|
||||||
|
|
||||||
|
def split_prompt(prompt: str) -> List[str]:
|
||||||
|
if "||" in prompt:
|
||||||
|
return prompt.split("||")
|
||||||
|
|
||||||
|
return [prompt]
|
||||||
|
|
||||||
|
|
||||||
def slice_prompt(prompt: str, slice: int) -> str:
|
def slice_prompt(prompt: str, slice: int) -> str:
|
||||||
if "||" in prompt:
|
if "||" in prompt:
|
||||||
parts = prompt.split("||")
|
parts = split_prompt(prompt)
|
||||||
return parts[min(slice, len(parts) - 1)]
|
return parts[min(slice, len(parts) - 1)]
|
||||||
else:
|
else:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
Loading…
Reference in New Issue