feat(api): add prompt filter stage using GPT2 models
This commit is contained in:
parent
acd9168b32
commit
bc2eeb8503
|
@ -20,6 +20,7 @@ from .source_noise import SourceNoiseStage
|
|||
from .source_s3 import SourceS3Stage
|
||||
from .source_txt2img import SourceTxt2ImgStage
|
||||
from .source_url import SourceURLStage
|
||||
from .text_prompt import TextPromptStage
|
||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||
from .upscale_highres import UpscaleHighresStage
|
||||
from .upscale_outpaint import UpscaleOutpaintStage
|
||||
|
@ -52,6 +53,7 @@ CHAIN_STAGES = {
|
|||
"source-s3": SourceS3Stage,
|
||||
"source-txt2img": SourceTxt2ImgStage,
|
||||
"source-url": SourceURLStage,
|
||||
"text-prompt": TextPromptStage,
|
||||
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||
"upscale-highres": UpscaleHighresStage,
|
||||
"upscale-outpaint": UpscaleOutpaintStage,
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
from logging import getLogger
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .result import StageResult
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class TextPromptStage(BaseStage):
    """Chain stage that expands the image prompt using a GPT-2 text-generation model.

    Runs a Hugging Face text-generation pipeline on ``params.prompt`` and
    replaces it in place with the generated text. The image sources are
    passed through unchanged.
    """

    # This stage operates on the prompt text, not on image tiles, so it
    # never needs to be tiled.
    max_tile = SizeChart.max

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        sources: StageResult,
        *,
        callback: Optional[ProgressCallback] = None,
        prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
        **kwargs,
    ) -> StageResult:
        """Generate an expanded prompt from ``params.prompt``.

        Args:
            prompt_model: Hugging Face model id used for the
                text-generation pipeline.

        Returns:
            The incoming ``sources`` unchanged; only ``params.prompt``
            is mutated as a side effect.
        """
        # Pipeline objects do not implement `.to(...)` (the original
        # `gpt2_pipe.to("cuda")` raises AttributeError); the device must be
        # passed to `pipeline()` itself. Probe for CUDA and fall back to CPU
        # (-1) instead of crashing on CPU-only hosts.
        try:
            import torch

            device = 0 if torch.cuda.is_available() else -1
        except ImportError:
            device = -1

        gpt2_pipe = pipeline(
            "text-generation",
            model=prompt_model,
            tokenizer="gpt2",
            device=device,
        )

        # `prompt_input` rather than `input`: avoid shadowing the builtin.
        prompt_input = params.prompt
        # Random headroom keeps the generated prompts varied between runs.
        max_length = len(prompt_input) + randint(60, 90)
        logger.debug(
            "generating new prompt with max length of %d from input prompt: %s",
            max_length,
            prompt_input,
        )

        result = gpt2_pipe(prompt_input, max_length=max_length, num_return_sequences=1)
        prompt = result[0]["generated_text"].strip()
        logger.debug("replacing prompt with: %s", prompt)

        # Mutate the shared params so downstream stages see the new prompt.
        params.prompt = prompt
        return sources
|
Loading…
Reference in New Issue