From 44393e37706494dbc82dd54a16a9575986efa705 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Tue, 14 Feb 2023 07:40:06 -0600
Subject: [PATCH] feat(api): add txt2txt endpoint

---
 api/onnx_web/serve.py        | 21 +++++++++++++++++++
 api/onnx_web/transformers.py | 40 ++++++++++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+)
 create mode 100644 api/onnx_web/transformers.py

diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index 3f2dc96b..ee60ef84 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -64,6 +64,7 @@ from .params import (
     UpscaleParams,
 )
 from .server import DevicePoolExecutor, ServerContext, apply_patches
+from .transformers import run_txt2txt_pipeline
 from .utils import (
     base_join,
     get_and_clamp_float,
@@ -809,6 +810,26 @@ def blend():
     return jsonify(json_params(output, params, size, upscale=upscale))
 
 
+@app.route("/api/txt2txt", methods=["POST"])
+def txt2txt():
+    device, params, size = pipeline_from_request()
+
+    output = make_output_name(context, "txt2txt", params, size)
+    logger.info("txt2txt job queued for: %s", output)
+
+    executor.submit(
+        output,
+        run_txt2txt_pipeline,
+        context,
+        params,
+        size,
+        output,
+        needs_device=device,
+    )
+
+    return jsonify(json_params(output, params, size))
+
+
 @app.route("/api/cancel", methods=["PUT"])
 def cancel():
     output_file = request.args.get("output", None)
diff --git a/api/onnx_web/transformers.py b/api/onnx_web/transformers.py
new file mode 100644
index 00000000..31c066e4
--- /dev/null
+++ b/api/onnx_web/transformers.py
@@ -0,0 +1,40 @@
+from logging import getLogger
+
+from .params import ImageParams, Size
+from .server import JobContext, ServerContext
+
+logger = getLogger(__name__)
+
+
+def run_txt2txt_pipeline(
+    job: JobContext,
+    _server: ServerContext,
+    params: ImageParams,
+    _size: Size,
+    output: str,
+) -> None:
+    from transformers import AutoTokenizer, GPTJForCausalLM
+
+    # tested with "EleutherAI/gpt-j-6B"
+    model_name = "EleutherAI/gpt-j-6B"
+    tokens = 1024
+
+    device = job.get_device()
+
+    model = GPTJForCausalLM.from_pretrained(model_name).to(device.torch_device())
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
+        device.torch_device()
+    )
+    results = model.generate(
+        input_ids,
+        do_sample=True,
+        max_length=tokens,
+        temperature=0.8,
+    )
+    result = tokenizer.decode(results[0], skip_special_tokens=True)
+
+    logger.info("txt2txt result: %s", result)
+
+    logger.info("finished txt2txt job: %s", output)
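
---

A minimal client sketch for exercising the new endpoint. It assumes the API is
running on Flask's default host and port, and that pipeline_from_request()
reads the prompt from a query parameter named "prompt"; neither detail is
shown in this diff, so treat both as assumptions:

    import requests

    # queue a txt2txt job; the JSON response echoes the job parameters,
    # including the output name that identifies the queued job
    resp = requests.post(
        "http://127.0.0.1:5000/api/txt2txt",  # assumed host and port
        params={"prompt": "An old man sat by the sea and"},  # assumed parameter name
    )
    resp.raise_for_status()
    print(resp.json())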