feat(api): add txt2txt endpoint
This commit is contained in:
parent
fd013c88ef
commit
44393e3770
|
@ -64,6 +64,7 @@ from .params import (
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
from .server import DevicePoolExecutor, ServerContext, apply_patches
|
from .server import DevicePoolExecutor, ServerContext, apply_patches
|
||||||
|
from .transformers import run_txt2txt_pipeline
|
||||||
from .utils import (
|
from .utils import (
|
||||||
base_join,
|
base_join,
|
||||||
get_and_clamp_float,
|
get_and_clamp_float,
|
||||||
|
@ -809,6 +810,26 @@ def blend():
|
||||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/txt2txt", methods=["POST"])
def txt2txt():
    """Queue a txt2txt (text-generation) job and return its queued parameters.

    Reads the pipeline parameters from the current request, reserves an
    output name for the job, submits ``run_txt2txt_pipeline`` to the device
    pool executor, and returns the job parameters as JSON so the client can
    poll for the result.
    """
    device, params, size = pipeline_from_request()

    # BUG FIX: this previously used "upscale" for both the output name and
    # the log message — a copy-paste from the upscale endpoint. The job is
    # a txt2txt job and must be named/logged as one.
    output = make_output_name(context, "txt2txt", params, size)
    logger.info("txt2txt job queued for: %s", output)

    executor.submit(
        output,
        run_txt2txt_pipeline,
        context,
        params,
        size,
        output,
        needs_device=device,
    )

    return jsonify(json_params(output, params, size))
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/cancel", methods=["PUT"])
|
@app.route("/api/cancel", methods=["PUT"])
|
||||||
def cancel():
|
def cancel():
|
||||||
output_file = request.args.get("output", None)
|
output_file = request.args.get("output", None)
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from .params import ImageParams, Size
|
||||||
|
from .server import JobContext, ServerContext
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def run_txt2txt_pipeline(
    job: JobContext,
    _server: ServerContext,
    params: ImageParams,
    _size: Size,
    output: str,
) -> None:
    """Run a GPT-J text-generation job for the prompt in ``params``.

    Args:
        job: job context; supplies the torch device assigned to this job.
        _server: server context (currently unused).
        params: pipeline parameters; only ``params.prompt`` is used here.
        _size: image size (unused — txt2txt produces text, not images).
        output: the job's output name, used for completion logging.
    """
    # Heavy import deferred so the server does not require transformers
    # to be installed unless this pipeline is actually invoked.
    from transformers import AutoTokenizer, GPTJForCausalLM

    # tested with "EleutherAI/gpt-j-6B"
    model_name = "EleutherAI/gpt-j-6B"
    max_tokens = 1024

    device = job.get_device()

    model = GPTJForCausalLM.from_pretrained(model_name).to(device.torch_device())
    # BUG FIX: the tokenizer was previously loaded *after* the `model` name
    # had been rebound to the loaded model object, so from_pretrained
    # received a GPTJForCausalLM instance instead of the model name string.
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    input_ids = tokenizer.encode(params.prompt, return_tensors="pt").to(
        device.torch_device()
    )
    # BUG FIX: the generated tensor previously shadowed the `output` str
    # parameter, so the "finished" log line below printed a tensor rather
    # than the job's output name.
    generated = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_tokens,
        temperature=0.8,
    )
    result = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Use the module logger rather than a bare print so the result goes
    # through the server's logging configuration.
    logger.info("Server says: %s", result)

    logger.info("finished txt2txt job: %s", output)
|
Loading…
Reference in New Issue