1
0
Fork 0

feat(api): add prompt filter stage using GPT2 models

This commit is contained in:
Sean Sube 2024-02-11 16:23:10 -06:00
parent acd9168b32
commit bc2eeb8503
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 49 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from .source_noise import SourceNoiseStage
from .source_s3 import SourceS3Stage
from .source_txt2img import SourceTxt2ImgStage
from .source_url import SourceURLStage
from .text_prompt import TextPromptStage
from .upscale_bsrgan import UpscaleBSRGANStage
from .upscale_highres import UpscaleHighresStage
from .upscale_outpaint import UpscaleOutpaintStage
@ -52,6 +53,7 @@ CHAIN_STAGES = {
"source-s3": SourceS3Stage,
"source-txt2img": SourceTxt2ImgStage,
"source-url": SourceURLStage,
"text-prompt": TextPromptStage,
"upscale-bsrgan": UpscaleBSRGANStage,
"upscale-highres": UpscaleHighresStage,
"upscale-outpaint": UpscaleOutpaintStage,

View File

@ -0,0 +1,47 @@
from logging import getLogger
from random import randint
from typing import Optional
from transformers import pipeline
from ..params import ImageParams, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
logger = getLogger(__name__)
class TextPromptStage(BaseStage):
max_tile = SizeChart.max
def run(
self,
worker: WorkerContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
sources: StageResult,
*,
callback: Optional[ProgressCallback] = None,
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
**kwargs,
) -> StageResult:
gpt2_pipe = pipeline("text-generation", model=prompt_model, tokenizer="gpt2")
gpt2_pipe = gpt2_pipe.to("cuda")
input = params.prompt
max_length = len(input) + randint(60, 90)
logger.debug(
"generating new prompt with max length of %d from input prompt: %s",
max_length,
input,
)
result = gpt2_pipe(input, max_length=max_length, num_return_sequences=1)
prompt = result[0]["generated_text"].strip()
logger.debug("replacing prompt with: %s", prompt)
params.prompt = prompt
return sources