1
0
Fork 0
onnx-web/api/onnx_web/transformers/run.py

42 lines
1.0 KiB
Python
Raw Normal View History

2023-02-14 13:40:06 +00:00
from logging import getLogger
2023-03-05 04:12:35 +00:00
from ..params import ImageParams, Size
from ..server import ServerContext
from ..worker import WorkerContext
2023-02-14 13:40:06 +00:00
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
def run_txt2txt_pipeline(
    job: WorkerContext,
    _server: ServerContext,
    params: ImageParams,
    _size: Size,
    output: str,
) -> None:
    """Generate a sampled text continuation of ``params.prompt`` and print it.

    The GPT-J causal language model and its tokenizer are loaded from the
    hard-coded ``EleutherAI/gpt-j-6B`` checkpoint on every call, moved to the
    device reported by ``job.get_device()``, and used to sample a sequence
    with ``max_length`` of 1024 and a temperature of 0.8. The first returned
    sequence is decoded (special tokens skipped) and written to stdout.

    Args:
        job: Worker context; only ``get_device()`` is used here.
        _server: Not used by this function.
        params: Only ``params.prompt`` is read, as the text to continue.
        _size: Not used by this function.
        output: Only appears in the final log message; nothing is written
            to it by this function.
    """
    # transformers is imported inside the function rather than at module level
    from transformers import AutoTokenizer, GPTJForCausalLM

    # tested with "EleutherAI/gpt-j-6B"
    checkpoint = "EleutherAI/gpt-j-6B"
    max_tokens = 1024

    target = job.get_device()

    # load the model first, then place it on the worker's device
    causal_lm = GPTJForCausalLM.from_pretrained(checkpoint)
    causal_lm = causal_lm.to(target.torch_str())

    # encode the prompt as a PyTorch tensor on the same device as the model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    prompt_ids = tokenizer.encode(params.prompt, return_tensors="pt")
    prompt_ids = prompt_ids.to(target.torch_str())

    sampling_options = {
        "do_sample": True,
        "max_length": max_tokens,
        "temperature": 0.8,
    }
    sequences = causal_lm.generate(prompt_ids, **sampling_options)

    # only the first generated sequence is decoded
    first_sequence = sequences[0]
    result_text = tokenizer.decode(first_sequence, skip_special_tokens=True)

    print("Server says: %s" % result_text)

    logger.info("finished txt2txt job: %s", output)