use preferred Torch device for prompt filters
parent 3fea317a43
commit 2d18ac1377
@@ -28,8 +28,14 @@ class TextPromptStage(BaseStage):
         prompt_model: str = "Gustavosta/MagicPrompt-Stable-Diffusion",
         **kwargs,
     ) -> StageResult:
-        gpt2_pipe = pipeline("text-generation", model=prompt_model, tokenizer="gpt2")
-        gpt2_pipe = gpt2_pipe.to("cuda")
+        device = worker.device.torch_str()
+        gpt2_pipe = pipeline(
+            "text-generation",
+            model=prompt_model,
+            tokenizer="gpt2",
+            device=device,
+            framework="pt",
+        )
 
         input = params.prompt
         max_length = len(input) + randint(60, 90)
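
For context, a minimal sketch (not the project's code) of passing a preferred torch device string to a transformers text-generation pipeline, as this commit does. The preferred_torch_device helper and its CUDA/CPU fallback are illustrative stand-ins for worker.device.torch_str(), and passing a string to the device argument assumes a reasonably recent transformers release:

    # Sketch only: pick a torch device string and hand it to pipeline(),
    # instead of calling .to("cuda") on the pipeline after construction.
    import torch
    from transformers import pipeline

    def preferred_torch_device() -> str:
        # Illustrative fallback logic; the real project derives this
        # from the worker's configured device.
        return "cuda" if torch.cuda.is_available() else "cpu"

    gpt2_pipe = pipeline(
        "text-generation",
        model="Gustavosta/MagicPrompt-Stable-Diffusion",
        tokenizer="gpt2",
        device=preferred_torch_device(),
        framework="pt",
    )

Letting the pipeline constructor place the model avoids hard-coding CUDA and works unchanged on CPU-only workers.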