from logging import getLogger
from ..params import ImageParams, Size
from ..server import ServerContext
from ..worker import WorkerContext
logger = getLogger(__name__)


def run_txt2txt_pipeline(
worker: WorkerContext,
_server: ServerContext,
params: ImageParams,
_size: Size,
) -> None:
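    """Generate text from params.prompt with GPT-J and log the result."""
    # import transformers lazily, so it is only loaded when this pipeline runs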
from transformers import AutoTokenizer, GPTJForCausalLM
# tested with "EleutherAI/gpt-j-6B"
model = "EleutherAI/gpt-j-6B"
tokens = 1024
device = worker.get_device()
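
    # load the GPT-J weights and move the model to the worker's assigned device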
pipe = GPTJForCausalLM.from_pretrained(model).to(device.torch_str())
tokenizer = AutoTokenizer.from_pretrained(model)
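
    # encode the prompt and move the tensor to the same device as the model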
input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
device.torch_str()
)
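    # sample up to the token limit; temperature 0.8 adds some variety to the output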
results = pipe.generate(
input_ids,
do_sample=True,
max_length=tokens,
temperature=0.8,
)
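    # decode the first generated sequence back into text, dropping special tokens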
result_text = tokenizer.decode(results[0], skip_special_tokens=True)
print("Server says: %s" % result_text)
logger.info("finished txt2txt job: %s", worker.job)