1
0
Fork 0

use preferred Torch device for prompt filters

This commit is contained in:
Sean Sube 2024-02-11 16:56:35 -06:00
parent 3fea317a43
commit 2d18ac1377
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 8 additions and 2 deletions

View File

@@ -28,8 +28,14 @@ class TextPromptStage(BaseStage):
         prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
         **kwargs,
     ) -> StageResult:
-        gpt2_pipe = pipeline("text-generation", model=prompt_model, tokenizer="gpt2")
-        gpt2_pipe = gpt2_pipe.to("cuda")
+        device = worker.device.torch_str()
+        gpt2_pipe = pipeline(
+            "text-generation",
+            model=prompt_model,
+            tokenizer="gpt2",
+            device=device,
+            framework="pt",
+        )
         input = params.prompt
         max_length = len(input) + randint(60, 90)