From 2d18ac1377e8c06b2417116219b6376a8459250c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 11 Feb 2024 16:56:35 -0600 Subject: [PATCH] use preferred Torch device for prompt filters --- api/onnx_web/chain/text_prompt.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py index 6ede0eba..0def7c83 100644 --- a/api/onnx_web/chain/text_prompt.py +++ b/api/onnx_web/chain/text_prompt.py @@ -28,8 +28,14 @@ class TextPromptStage(BaseStage): prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion", **kwargs, ) -> StageResult: - gpt2_pipe = pipeline("text-generation", model=prompt_model, tokenizer="gpt2") - gpt2_pipe = gpt2_pipe.to("cuda") + device = worker.device.torch_str() + gpt2_pipe = pipeline( + "text-generation", + model=prompt_model, + tokenizer="gpt2", + device=device, + framework="pt", + ) input = params.prompt max_length = len(input) + randint(60, 90)