feat(api): add txt2txt endpoint
This commit is contained in:
parent
fd013c88ef
commit
44393e3770
|
@ -64,6 +64,7 @@ from .params import (
|
|||
UpscaleParams,
|
||||
)
|
||||
from .server import DevicePoolExecutor, ServerContext, apply_patches
|
||||
from .transformers import run_txt2txt_pipeline
|
||||
from .utils import (
|
||||
base_join,
|
||||
get_and_clamp_float,
|
||||
|
@ -809,6 +810,26 @@ def blend():
|
|||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
||||
@app.route("/api/txt2txt", methods=["POST"])
def txt2txt():
    """Queue a text-to-text generation job and return its JSON descriptor.

    Reads device/params/size from the incoming request, registers an output
    name for the job, submits `run_txt2txt_pipeline` to the executor, and
    responds immediately with the job's JSON parameters (the client polls
    for completion).
    """
    device, params, size = pipeline_from_request()

    # Use this endpoint's own name for the output file and log line; the
    # original said "upscale", a copy-paste from the upscale endpoint.
    output = make_output_name(context, "txt2txt", params, size)
    logger.info("txt2txt job queued for: %s", output)

    executor.submit(
        output,
        run_txt2txt_pipeline,
        context,
        params,
        size,
        output,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size))
||||
@app.route("/api/cancel", methods=["PUT"])
|
||||
def cancel():
|
||||
output_file = request.args.get("output", None)
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
from logging import getLogger
|
||||
|
||||
from .params import ImageParams, Size
|
||||
from .server import JobContext, ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def run_txt2txt_pipeline(
    job: JobContext,
    _server: ServerContext,
    params: ImageParams,
    _size: Size,
    output: str,
) -> None:
    """Generate text from ``params.prompt`` with GPT-J on the job's device.

    Args:
        job: execution context providing the target device.
        _server: server context (unused).
        params: image params; only ``prompt`` is used here.
        _size: output size (unused for text).
        output: the queued job's output name, used for logging completion.
    """
    # Imported lazily so the server can start without transformers installed.
    from transformers import AutoTokenizer, GPTJForCausalLM

    # tested with "EleutherAI/gpt-j-6B"
    model_name = "EleutherAI/gpt-j-6B"
    max_tokens = 1024

    device = job.get_device()

    # Load the tokenizer from the model NAME. The original rebound `model`
    # to the loaded GPTJForCausalLM instance and then passed that object to
    # AutoTokenizer.from_pretrained, which expects a name/path string.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = GPTJForCausalLM.from_pretrained(model_name).to(device.torch_device())

    input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
        device.torch_device()
    )
    # Keep the generated token tensor in its own name; the original shadowed
    # the `output` parameter, so the completion log printed a tensor instead
    # of the job's output name.
    generated = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_tokens,
        temperature=0.8,
    )
    result = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Use the module logger rather than a bare print on the server process.
    logger.info("Server says: %s", result)

    logger.info("finished txt2txt job: %s", output)
|
Loading…
Reference in New Issue