96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
from logging import getLogger
|
|
from random import randint
|
|
from re import match, sub
|
|
from typing import Optional
|
|
|
|
from transformers import pipeline
|
|
|
|
from ..diffusers.utils import split_prompt
|
|
from ..params import ImageParams, SizeChart, StageParams
|
|
from ..server import ServerContext
|
|
from ..worker import ProgressCallback, WorkerContext
|
|
from .base import BaseStage
|
|
from .result import StageResult
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
LENGTH_MARGIN = 15
|
|
RETRY_LIMIT = 5
|
|
|
|
|
|
class TextPromptStage(BaseStage):
|
|
max_tile = SizeChart.max
|
|
|
|
def run(
|
|
self,
|
|
worker: WorkerContext,
|
|
server: ServerContext,
|
|
stage: StageParams,
|
|
params: ImageParams,
|
|
sources: StageResult,
|
|
*,
|
|
callback: Optional[ProgressCallback] = None,
|
|
prompt_filter: str,
|
|
remove_tokens: Optional[str] = None,
|
|
add_suffix: Optional[str] = None,
|
|
min_length: int = 80,
|
|
**kwargs,
|
|
) -> StageResult:
|
|
device = worker.device.torch_str()
|
|
text_pipe = pipeline(
|
|
"text-generation",
|
|
model=prompt_filter,
|
|
device=device,
|
|
framework="pt",
|
|
)
|
|
|
|
prompt_parts = split_prompt(params.prompt)
|
|
prompt_results = []
|
|
for prompt in prompt_parts:
|
|
retries = 0
|
|
while len(prompt) < min_length and retries < RETRY_LIMIT:
|
|
max_length = len(prompt) + randint(
|
|
min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
|
|
)
|
|
logger.debug(
|
|
"extending input prompt to max length of %d from %s: %s",
|
|
max_length,
|
|
len(prompt),
|
|
prompt,
|
|
)
|
|
|
|
result = text_pipe(
|
|
prompt, max_length=max_length, num_return_sequences=1
|
|
)
|
|
prompt = result[0]["generated_text"].strip()
|
|
|
|
if remove_tokens:
|
|
logger.debug(
|
|
"removing excluded tokens from prompt: %s", remove_tokens
|
|
)
|
|
|
|
remove_limit = 3
|
|
while remove_limit > 0 and match(remove_tokens, prompt):
|
|
prompt = sub(remove_tokens, "", prompt)
|
|
remove_limit -= 1
|
|
|
|
if retries >= RETRY_LIMIT:
|
|
logger.warning(
|
|
"failed to extend input prompt to min length of %d, ended up with %d: %s",
|
|
min_length,
|
|
len(prompt),
|
|
prompt,
|
|
)
|
|
|
|
if add_suffix:
|
|
prompt = f"{prompt}, {add_suffix}"
|
|
logger.trace("adding suffix to prompt: %s", prompt)
|
|
|
|
prompt_results.append(prompt)
|
|
|
|
complete_prompt = " || ".join(prompt_results)
|
|
logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt)
|
|
params.prompt = complete_prompt
|
|
return sources
|