use preferred Torch device for prompt filters
This commit is contained in:
parent
3fea317a43
commit
2d18ac1377
|
@ -28,8 +28,14 @@ class TextPromptStage(BaseStage):
|
|||
prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
gpt2_pipe = pipeline("text-generation", model=prompt_model, tokenizer="gpt2")
|
||||
gpt2_pipe = gpt2_pipe.to("cuda")
|
||||
device = worker.device.torch_str()
|
||||
gpt2_pipe = pipeline(
|
||||
"text-generation",
|
||||
model=prompt_model,
|
||||
tokenizer="gpt2",
|
||||
device=device,
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
input = params.prompt
|
||||
max_length = len(input) + randint(60, 90)
|
||||
|
|
Loading…
Reference in New Issue