1
0
Fork 0
onnx-web/api/onnx_web/chain/text_prompt.py

96 lines
2.9 KiB
Python
Raw Permalink Normal View History

from logging import getLogger
from random import randint
from re import match, search, sub
from typing import Optional

from transformers import pipeline

from ..diffusers.utils import split_prompt
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
LENGTH_MARGIN = 15
RETRY_LIMIT = 5
class TextPromptStage(BaseStage):
    """Rewrite the image prompt by extending each part with a text-generation model.

    Each ``||``-delimited part of ``params.prompt`` is fed to a Hugging Face
    text-generation pipeline until it reaches ``min_length`` characters (or the
    retry limit is hit), optionally stripping excluded tokens and appending a
    literal suffix. The combined result replaces ``params.prompt`` in place;
    the image sources pass through unchanged.
    """

    max_tile = SizeChart.max

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        sources: StageResult,
        *,
        callback: Optional[ProgressCallback] = None,
        prompt_filter: str,
        remove_tokens: Optional[str] = None,
        add_suffix: Optional[str] = None,
        min_length: int = 80,
        **kwargs,
    ) -> StageResult:
        """Extend ``params.prompt`` using the ``prompt_filter`` model.

        Args:
            prompt_filter: model name or path for the text-generation pipeline.
            remove_tokens: optional regex of tokens to strip from generated text.
            add_suffix: optional literal text appended to every prompt part.
            min_length: keep generating until each part is at least this long.

        Returns:
            The unmodified ``sources``; this stage's effect is the in-place
            mutation of ``params.prompt``.
        """
        device = worker.device.torch_str()
        text_pipe = pipeline(
            "text-generation",
            model=prompt_filter,
            device=device,
            framework="pt",
        )

        prompt_parts = split_prompt(params.prompt)
        prompt_results = []
        for prompt in prompt_parts:
            retries = 0
            while len(prompt) < min_length and retries < RETRY_LIMIT:
                # Pick a random target length somewhat beyond min_length so
                # the extended prompts vary between runs.
                max_length = len(prompt) + randint(
                    min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
                )
                logger.debug(
                    "extending input prompt to max length of %d from %s: %s",
                    max_length,
                    len(prompt),
                    prompt,
                )

                result = text_pipe(
                    prompt, max_length=max_length, num_return_sequences=1
                )
                prompt = result[0]["generated_text"].strip()

                if remove_tokens:
                    logger.debug(
                        "removing excluded tokens from prompt: %s", remove_tokens
                    )
                    # Bounded so a pattern that keeps matching its own
                    # replacement output cannot loop forever.
                    remove_limit = 3
                    # FIX: use search rather than match so excluded tokens
                    # anywhere in the prompt (not only at the very start)
                    # trigger removal; sub already replaces all occurrences.
                    while remove_limit > 0 and search(remove_tokens, prompt):
                        prompt = sub(remove_tokens, "", prompt)
                        remove_limit -= 1

                # FIX: count this attempt. Previously retries was never
                # incremented, so the RETRY_LIMIT guard and the warning below
                # were unreachable and a model that never lengthened the
                # prompt caused an infinite loop.
                retries += 1

            if retries >= RETRY_LIMIT:
                logger.warning(
                    "failed to extend input prompt to min length of %d, ended up with %d: %s",
                    min_length,
                    len(prompt),
                    prompt,
                )

            if add_suffix:
                prompt = f"{prompt}, {add_suffix}"
                # NOTE: logger.trace is a custom level added elsewhere in
                # this project, not a stdlib logging method.
                logger.trace("adding suffix to prompt: %s", prompt)

            prompt_results.append(prompt)

        complete_prompt = " || ".join(prompt_results)
        logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt)
        params.prompt = complete_prompt
        return sources