# onnx-web/api/scripts/parse-prompts.py
#
# Repository page residue kept for reference (not part of the script):
# 1 / 0 / Fork 0 / 67 lines / 1.6 KiB / Python

from argparse import ArgumentParser
from collections import Counter
from json import dumps, loads
from sys import argv
from typing import List
def parse_args(args: List[str]):
    """Parse the command-line arguments of the prompt parser.

    Args:
        args: Argument strings to parse, without the program name
            (``main`` passes ``argv[1:]``).

    Returns:
        The namespace produced by ``ArgumentParser.parse_args``; its ``file``
        attribute is a list holding one or more prompt file paths.

    Raises:
        SystemExit: Raised by argparse when no file is given or the
            arguments cannot be parsed.
    """
    parser = ArgumentParser(
        prog="onnx-web prompt parser",
        description="count phrase frequency in prompt books",
    )
    # nargs="+" requires at least one path and collects all of them in a list.
    parser.add_argument("file", nargs="+", help="prompt files to parse")
    return parser.parse_args(args)
def load_duck(file: str):
    """Load prompts from a file by querying it through DuckDB.

    Runs ``SELECT *`` against ``file`` and keeps the first column of each row.

    Args:
        file: Path of the file to query (``main`` routes ``.parquet`` and
            ``.duckdb`` paths here).

    Returns:
        A list with the first-column value of every row returned by the query.
    """
    # Imported here rather than at module level, so duckdb is only needed
    # when this loader is actually called.
    import duckdb

    # The path is embedded in the SQL text as a string literal; doubling
    # single quotes keeps a path containing "'" from ending the literal early.
    quoted = file.replace("'", "''")

    cursor = duckdb.connect()
    try:
        rows = cursor.sql(f"SELECT * FROM '{quoted}'").fetchall()
    finally:
        # Release the connection even if the query fails.
        cursor.close()

    return [row[0] for row in rows]
def load_json(file: str):
    """Extract the prompt from a JSON file.

    Looks inside the top-level ``"params"`` object and returns the first
    non-empty value among ``"input_prompt"`` and ``"prompt"``, in that order.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The prompt value, or ``""`` when the document has no usable
        ``"params"`` object or neither key holds a truthy value.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid JSON (or not valid UTF-8).
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(file, "r", encoding="utf-8") as f:
        data = loads(f.read())

    # Only a JSON object can carry "params"; arrays, strings, numbers etc.
    # are treated the same as a document without a prompt.
    params = data.get("params", None) if isinstance(data, dict) else None
    if not isinstance(params, dict):
        return ""

    # "input_prompt" takes precedence over "prompt".
    for key in ("input_prompt", "prompt"):
        prompt = params.get(key, None)
        if prompt:
            return prompt

    return ""
def load_text(file: str):
    """Read a plain-text prompt file as a list of lines.

    Args:
        file: Path of the text file to read.

    Returns:
        The lines of the file as returned by ``readlines()``, i.e. with their
        trailing newline characters still attached.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(file, "r", encoding="utf-8") as f:
        return f.readlines()
def main():
    """Count comma-separated phrases across the prompt files given on the CLI.

    Each file is loaded according to its extension, every resulting line is
    split on commas, and the phrases are lower-cased and stripped before
    being counted. The counts are printed to stdout as a single JSON object,
    ordered from most to least common.
    """
    args = parse_args(argv[1:])

    lines: List[str] = []
    for file in args.file:
        if file.endswith((".parquet", ".duckdb")):
            lines.extend(load_duck(file))
        elif file.endswith(".json"):
            # json only contains a single prompt
            lines.append(load_json(file))
        else:
            lines.extend(load_text(file))

    count = Counter()
    for line in lines:
        # Skip entries with nothing to split, such as None or the "" that
        # load_json returns when it finds no prompt.
        if not line:
            continue
        for piece in line.split(","):
            phrase = piece.lower().strip()
            # Blank lines and trailing commas yield empty pieces; those are
            # not phrases and would otherwise be counted under the key "".
            if phrase:
                count[phrase] += 1

    print(dumps(dict(count.most_common())))
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()