feat(api): add prompt filter stage using GPT2 models
This commit is contained in:
parent
acd9168b32
commit
bc2eeb8503
|
@ -20,6 +20,7 @@ from .source_noise import SourceNoiseStage
|
||||||
from .source_s3 import SourceS3Stage
|
from .source_s3 import SourceS3Stage
|
||||||
from .source_txt2img import SourceTxt2ImgStage
|
from .source_txt2img import SourceTxt2ImgStage
|
||||||
from .source_url import SourceURLStage
|
from .source_url import SourceURLStage
|
||||||
|
from .text_prompt import TextPromptStage
|
||||||
from .upscale_bsrgan import UpscaleBSRGANStage
|
from .upscale_bsrgan import UpscaleBSRGANStage
|
||||||
from .upscale_highres import UpscaleHighresStage
|
from .upscale_highres import UpscaleHighresStage
|
||||||
from .upscale_outpaint import UpscaleOutpaintStage
|
from .upscale_outpaint import UpscaleOutpaintStage
|
||||||
|
@ -52,6 +53,7 @@ CHAIN_STAGES = {
|
||||||
"source-s3": SourceS3Stage,
|
"source-s3": SourceS3Stage,
|
||||||
"source-txt2img": SourceTxt2ImgStage,
|
"source-txt2img": SourceTxt2ImgStage,
|
||||||
"source-url": SourceURLStage,
|
"source-url": SourceURLStage,
|
||||||
|
"text-prompt": TextPromptStage,
|
||||||
"upscale-bsrgan": UpscaleBSRGANStage,
|
"upscale-bsrgan": UpscaleBSRGANStage,
|
||||||
"upscale-highres": UpscaleHighresStage,
|
"upscale-highres": UpscaleHighresStage,
|
||||||
"upscale-outpaint": UpscaleOutpaintStage,
|
"upscale-outpaint": UpscaleOutpaintStage,
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from random import randint
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
from ..params import ImageParams, SizeChart, StageParams
|
||||||
|
from ..server import ServerContext
|
||||||
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
|
from .base import BaseStage
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TextPromptStage(BaseStage):
    """
    Chain stage that rewrites the image prompt using a GPT-2 text-generation
    model (MagicPrompt for Stable Diffusion by default).

    Mutates ``params.prompt`` in place and returns ``sources`` unchanged, so
    the image data simply passes through this stage.
    """

    # prompt generation does not operate on image tiles, so allow the maximum
    max_tile = SizeChart.max

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        sources: StageResult,
        *,
        callback: Optional[ProgressCallback] = None,
        prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
        **kwargs,
    ) -> StageResult:
        """
        Generate an expanded prompt from ``params.prompt`` and replace it.

        :param prompt_model: Hugging Face model ID used for text generation.
        :return: the incoming ``sources``, unmodified.
        """
        # deferred import: only needed when this stage actually runs
        import torch

        # transformers pipelines select their device via the ``device``
        # argument (0 = first CUDA device, -1 = CPU); Pipeline objects have
        # no ``.to()`` method, and hard-coding CUDA would crash CPU-only
        # servers, so probe for CUDA availability first
        device = 0 if torch.cuda.is_available() else -1
        gpt2_pipe = pipeline(
            "text-generation",
            model=prompt_model,
            tokenizer="gpt2",
            device=device,
        )

        # named prompt_input to avoid shadowing the builtin ``input``
        prompt_input = params.prompt
        # randomize the target length so repeated runs vary the output
        max_length = len(prompt_input) + randint(60, 90)
        logger.debug(
            "generating new prompt with max length of %d from input prompt: %s",
            max_length,
            prompt_input,
        )

        result = gpt2_pipe(prompt_input, max_length=max_length, num_return_sequences=1)
        prompt = result[0]["generated_text"].strip()
        logger.debug("replacing prompt with: %s", prompt)

        # mutate params in place; downstream stages see the new prompt
        params.prompt = prompt
        return sources
|
Loading…
Reference in New Issue