From 61272b96202817e8974143b94119fdf3a5d1cb65 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 11 Feb 2024 19:04:06 -0600
Subject: [PATCH] support multi-stage prompts in prompt filter

---
 api/onnx_web/chain/text_prompt.py | 60 +++++++++++++++++++++++++-------
 api/onnx_web/diffusers/utils.py   |  9 ++++-
 2 files changed, 55 insertions(+), 14 deletions(-)

diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py
index 0def7c83..3c3f8733 100644
--- a/api/onnx_web/chain/text_prompt.py
+++ b/api/onnx_web/chain/text_prompt.py
@@ -1,9 +1,11 @@
 from logging import getLogger
 from random import randint
+from re import sub
 from typing import Optional
 
 from transformers import pipeline
 
+from ..diffusers.utils import split_prompt
 from ..params import ImageParams, SizeChart, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
@@ -13,6 +15,10 @@ from .result import StageResult
 logger = getLogger(__name__)
 
 
+LENGTH_MARGIN = 15
+RETRY_LIMIT = 5
+
+
 class TextPromptStage(BaseStage):
     max_tile = SizeChart.max
 
@@ -26,28 +32,56 @@ class TextPromptStage(BaseStage):
         *,
         callback: Optional[ProgressCallback] = None,
         prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
+        exclude_tokens: Optional[str] = None,
+        min_length: int = 75,
         **kwargs,
     ) -> StageResult:
         device = worker.device.torch_str()
-        gpt2_pipe = pipeline(
+        text_pipe = pipeline(
             "text-generation",
             model=prompt_model,
-            tokenizer="gpt2",
             device=device,
             framework="pt",
         )
 
-        input = params.prompt
-        max_length = len(input) + randint(60, 90)
-        logger.debug(
-            "generating new prompt with max length of %d from input prompt: %s",
-            max_length,
-            input,
-        )
+        prompt_parts = split_prompt(params.prompt)
+        prompt_results = []
+        for prompt in prompt_parts:
+            retries = 0
+            while len(prompt) < min_length and retries < RETRY_LIMIT:
+                max_length = len(prompt) + randint(
+                    min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN
+                )
+                logger.debug(
+                    "extending input prompt to max length of %d from %s: %s",
+                    max_length,
+                    len(prompt),
+                    prompt,
+                )
 
-        result = gpt2_pipe(input, max_length=max_length, num_return_sequences=1)
-        prompt = result[0]["generated_text"].strip()
-        logger.debug("replacing prompt with: %s", prompt)
+                result = text_pipe(
+                    prompt, max_length=max_length, num_return_sequences=1
+                )
+                prompt = result[0]["generated_text"].strip()
+                retries += 1  # count attempts so the retry limit can take effect
 
-        params.prompt = prompt
+                if exclude_tokens:
+                    logger.debug(
+                        "removing excluded tokens from prompt: %s", exclude_tokens
+                    )
+                    prompt = sub(exclude_tokens, "", prompt)
+
+            if retries >= RETRY_LIMIT:
+                logger.warning(
+                    "failed to extend input prompt to min length of %d, ended up with %d: %s",
+                    min_length,
+                    len(prompt),
+                    prompt,
+                )
+
+            prompt_results.append(prompt)
+
+        complete_prompt = " || ".join(prompt_results)
+        logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt)
+        params.prompt = complete_prompt
         return sources
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index ecf01373..c4fe500d 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -475,9 +475,16 @@ def repair_nan(tile: np.ndarray) -> np.ndarray:
     return tile
 
 
+def split_prompt(prompt: str) -> List[str]:
+    if "||" in prompt:
+        return prompt.split("||")
+
+    return [prompt]
+
+
 def slice_prompt(prompt: str, slice: int) -> str:
     if "||" in prompt:
-        parts = prompt.split("||")
+        parts = split_prompt(prompt)
         return parts[min(slice, len(parts) - 1)]
     else:
         return prompt
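
For reference, a minimal standalone sketch of the "||" delimiter handling introduced above. split_prompt and slice_prompt are copied from the patched api/onnx_web/diffusers/utils.py; the demo prompt and the __main__ block are invented for the example, so it runs without onnx-web installed.

from typing import List


def split_prompt(prompt: str) -> List[str]:
    # "||" separates the stages of a multi-stage prompt
    if "||" in prompt:
        return prompt.split("||")

    return [prompt]


def slice_prompt(prompt: str, slice: int) -> str:
    # pick a single stage, clamping the index to the last stage
    if "||" in prompt:
        parts = split_prompt(prompt)
        return parts[min(slice, len(parts) - 1)]
    else:
        return prompt


if __name__ == "__main__":
    prompt = "a castle on a hill || a dragon flying overhead"

    # TextPromptStage expands each stage separately...
    parts = split_prompt(prompt)
    print(parts)  # ['a castle on a hill ', ' a dragon flying overhead']

    # ...then rejoins the expanded stages with the same delimiter, as
    # complete_prompt = " || ".join(prompt_results) does in the stage.
    print(" || ".join(part.strip() for part in parts))

    # slice_prompt still selects one stage for a given index.
    print(slice_prompt(prompt, 0))  # first stage
    print(slice_prompt(prompt, 9))  # out-of-range index clamps to the last stage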