1
0
Fork 0

feat(api): add txt2txt endpoint

This commit is contained in:
Sean Sube 2023-02-14 07:40:06 -06:00
parent fd013c88ef
commit 44393e3770
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 61 additions and 0 deletions

View File

@@ -64,6 +64,7 @@ from .params import (
UpscaleParams,
)
from .server import DevicePoolExecutor, ServerContext, apply_patches
from .transformers import run_txt2txt_pipeline
from .utils import (
base_join,
get_and_clamp_float,
@@ -809,6 +810,26 @@ def blend():
return jsonify(json_params(output, params, size, upscale=upscale))
@app.route("/api/txt2txt", methods=["POST"])
def txt2txt():
    """Queue a text-to-text generation job and return its parameters.

    Reads the pipeline parameters from the request, submits the job to the
    executor, and responds with the JSON-encoded job parameters including
    the output name clients can poll for.
    """
    device, params, size = pipeline_from_request()

    # name the output after this endpoint, not "upscale" (copy-paste from the
    # upscale route), so output files and logs identify the correct job type
    output = make_output_name(context, "txt2txt", params, size)
    logger.info("txt2txt job queued for: %s", output)

    executor.submit(
        output,
        run_txt2txt_pipeline,
        context,
        params,
        size,
        output,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size))
@app.route("/api/cancel", methods=["PUT"])
def cancel():
output_file = request.args.get("output", None)

View File

@@ -0,0 +1,40 @@
from logging import getLogger
from .params import ImageParams, Size
from .server import JobContext, ServerContext
logger = getLogger(__name__)
def run_txt2txt_pipeline(
    job: JobContext,
    _server: ServerContext,
    params: ImageParams,
    _size: Size,
    output: str,
) -> None:
    """Generate text from the prompt in params using GPT-J and log the result.

    Args:
        job: worker job context; provides the torch device to run on.
        _server: server context (unused by this pipeline).
        params: pipeline params; only ``params.prompt`` is used as the prompt.
        _size: output image size (unused; text pipelines have no image size).
        output: the queued job's output name, used for logging only.
    """
    # lazy import: transformers is only needed when this pipeline runs
    from transformers import AutoTokenizer, GPTJForCausalLM

    # tested with "EleutherAI/gpt-j-6B"
    model_name = "EleutherAI/gpt-j-6B"
    max_tokens = 1024

    device = job.get_device()

    # keep the model id separate from the loaded model so the tokenizer is
    # loaded from the model name, not from the in-memory model object
    model = GPTJForCausalLM.from_pretrained(model_name).to(device.torch_device())
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
        device.torch_device()
    )
    # do not reuse the `output` name here: shadowing the job's output name
    # with the generated tensor would make the final log line meaningless
    generated = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_tokens,
        temperature=0.8,
    )
    result = tokenizer.decode(generated[0], skip_special_tokens=True)

    # log through the module logger rather than print, matching the file style
    logger.info("result for txt2txt job %s: %s", output, result)
    logger.info("finished txt2txt job: %s", output)