diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py
index 9fc4bd9a..bb644233 100644
--- a/api/onnx_web/chain/stages.py
+++ b/api/onnx_web/chain/stages.py
@@ -20,6 +20,7 @@ from .source_noise import SourceNoiseStage
 from .source_s3 import SourceS3Stage
 from .source_txt2img import SourceTxt2ImgStage
 from .source_url import SourceURLStage
+from .text_prompt import TextPromptStage
 from .upscale_bsrgan import UpscaleBSRGANStage
 from .upscale_highres import UpscaleHighresStage
 from .upscale_outpaint import UpscaleOutpaintStage
@@ -52,6 +53,7 @@ CHAIN_STAGES = {
     "source-s3": SourceS3Stage,
     "source-txt2img": SourceTxt2ImgStage,
     "source-url": SourceURLStage,
+    "text-prompt": TextPromptStage,
     "upscale-bsrgan": UpscaleBSRGANStage,
     "upscale-highres": UpscaleHighresStage,
     "upscale-outpaint": UpscaleOutpaintStage,
diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py
new file mode 100644
index 00000000..6ede0eba
--- /dev/null
+++ b/api/onnx_web/chain/text_prompt.py
@@ -0,0 +1,56 @@
+from logging import getLogger
+from random import randint
+from typing import Optional
+
+import torch
+from transformers import pipeline
+
+from ..params import ImageParams, SizeChart, StageParams
+from ..server import ServerContext
+from ..worker import ProgressCallback, WorkerContext
+from .base import BaseStage
+from .result import StageResult
+
+logger = getLogger(__name__)
+
+
+class TextPromptStage(BaseStage):
+    max_tile = SizeChart.max
+
+    def run(
+        self,
+        worker: WorkerContext,
+        server: ServerContext,
+        stage: StageParams,
+        params: ImageParams,
+        sources: StageResult,
+        *,
+        callback: Optional[ProgressCallback] = None,
+        prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
+        **kwargs,
+    ) -> StageResult:
+        # load the prompt-expansion model, preferring CUDA when it is available
+        device = 0 if torch.cuda.is_available() else -1
+        gpt2_pipe = pipeline(
+            "text-generation",
+            model=prompt_model,
+            tokenizer="gpt2",
+            device=device,
+        )
+
+        # generate up to the input length plus a random offset, to vary the results
+        input_prompt = params.prompt
+        max_length = len(input_prompt) + randint(60, 90)
+        logger.debug(
+            "generating new prompt with max length of %d from input prompt: %s",
+            max_length,
+            input_prompt,
+        )
+
+        result = gpt2_pipe(input_prompt, max_length=max_length, num_return_sequences=1)
+        prompt = result[0]["generated_text"].strip()
+        logger.debug("replacing prompt with: %s", prompt)
+
+        # replace the prompt for later stages; the image sources pass through unchanged
+        params.prompt = prompt
+        return sources
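
For reference, a minimal standalone sketch of the prompt-expansion technique the new stage wraps. The model name, tokenizer, and generation parameters are taken from the diff above; the seed prompt is a hypothetical example, and running this assumes the `transformers` and `torch` packages are installed (the model is downloaded on first use):

```python
# Minimal sketch of the prompt-expansion step, outside the chain machinery.
from random import randint

from transformers import pipeline

# same model and tokenizer as TextPromptStage's defaults
pipe = pipeline(
    "text-generation",
    model="Gustavosta/MagicPrompt-Stable-Diffusion",
    tokenizer="gpt2",
)

seed_prompt = "a castle on a hill"  # hypothetical input prompt
max_length = len(seed_prompt) + randint(60, 90)
result = pipe(seed_prompt, max_length=max_length, num_return_sequences=1)
print(result[0]["generated_text"].strip())
```

Note the design choice visible in `run()`: the stage mutates `params.prompt` and returns `sources` untouched, so it only rewrites the prompt used by later chain stages rather than producing any images itself.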