1
0
Fork 0
onnx-web/api/scripts/parse-prompts.py

67 lines
1.6 KiB
Python
Raw Permalink Normal View History

from argparse import ArgumentParser
from collections import Counter
from json import dumps, loads
from sys import argv
from typing import List
def parse_args(args: List[str]):
    """Parse the command-line arguments for the prompt parser.

    Args:
        args: Argument strings, without the program name.

    Returns:
        The argparse namespace; its ``file`` attribute is the list of one or
        more prompt files given on the command line.
    """
    cli = ArgumentParser(
        description="count phrase frequency in prompt books",
        prog="onnx-web prompt parser",
    )
    # nargs="+" makes argparse reject an empty file list.
    cli.add_argument(
        "file",
        help="prompt files to parse",
        nargs="+",
    )
    namespace = cli.parse_args(args)
    return namespace
def load_duck(file: str):
    """Load values from the first column of a file that DuckDB can query.

    Args:
        file: Path handed to DuckDB as the ``FROM`` target of the query.

    Returns:
        A list with the first column of every row, skipping NULL values.
    """
    # Imported lazily so duckdb is only required when such a file is parsed.
    import duckdb

    # Double single quotes so a path containing "'" cannot end the SQL string
    # literal early and break (or alter) the query.
    quoted = file.replace("'", "''")

    connection = duckdb.connect()
    try:
        rows = connection.sql(f"SELECT * FROM '{quoted}'").fetchall()
    finally:
        # Release the in-memory database even when the query fails.
        connection.close()

    # NULL cells come back as None, which has no .split() for the phrase
    # counting in main(), so they are dropped here.
    return [row[0] for row in rows if row[0] is not None]
def load_json(file: str) -> str:
    """Read the prompt stored in a JSON file.

    The prompt is looked up under the top-level ``params`` object, preferring
    ``input_prompt`` and falling back to ``prompt``.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The first non-empty prompt value found, or ``""`` when the document
        has no usable ``params`` object or neither key holds a value.
    """
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding, which differs between platforms.
    with open(file, "r", encoding="utf-8") as f:
        data = loads(f.read())

    # Guard against documents whose top level (or params) is not an object,
    # which would otherwise raise AttributeError on .get().
    params = data.get("params") if isinstance(data, dict) else None
    if not isinstance(params, dict):
        return ""

    for key in ("input_prompt", "prompt"):
        prompt = params.get(key)
        if prompt:
            return prompt

    return ""
def load_text(file: str) -> List[str]:
    """Read every line of a plain-text prompt file.

    Args:
        file: Path of the text file to read.

    Returns:
        The lines of the file in order, each keeping its trailing newline.
    """
    # Decode as UTF-8 explicitly so results do not depend on the locale's
    # default encoding.
    with open(file, "r", encoding="utf-8") as f:
        return f.readlines()
def main():
    """Count comma-separated phrases across the prompt files given in argv.

    Each file is loaded according to its extension, every prompt line is split
    on commas, and the phrases are lower-cased and stripped before counting.
    The result is printed as a JSON object mapping phrase to count, ordered
    from most to least common.
    """
    args = parse_args(argv[1:])
    lines: List[str] = []

    for file in args.file:
        if file.endswith((".parquet", ".duckdb")):
            lines.extend(load_duck(file))
        elif file.endswith(".json"):
            # json only contains a single prompt
            lines.append(load_json(file))
        else:
            lines.extend(load_text(file))

    # Blank lines, trailing or doubled commas, and JSON files without a prompt
    # all yield empty strings; skip them so "" is not reported as a phrase.
    count = Counter(
        phrase
        for line in lines
        for phrase in (part.lower().strip() for part in line.split(","))
        if phrase
    )
    print(dumps(dict(count.most_common())))
# Run the phrase counter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()