diff --git a/api/scripts/parse-prompts.py b/api/scripts/parse-prompts.py index 2c977015..954b4e01 100644 --- a/api/scripts/parse-prompts.py +++ b/api/scripts/parse-prompts.py @@ -1,8 +1,8 @@ -from typing import List from argparse import ArgumentParser -from sys import argv from collections import Counter -from json import dumps +from json import dumps, loads +from sys import argv +from typing import List def parse_args(args: List[str]): @@ -14,13 +14,46 @@ def parse_args(args: List[str]): return parser.parse_args(args) +def load_duck(file: str): + import duckdb + + cursor = duckdb.connect() + return [p[0] for p in cursor.sql(f"SELECT * FROM '{file}'").fetchall()] + + +def load_json(file: str): + with open(file, "r") as f: + data = loads(f.read()) + params = data.get("params", None) + if params: + prompt = params.get("input_prompt", None) + if prompt: + return prompt + + prompt = params.get("prompt", None) + if prompt: + return prompt + + return "" + + +def load_text(file: str): + with open(file, "r") as f: + return f.readlines() + + def main(): args = parse_args(argv[1:]) lines: List[str] = [] for file in args.file: - with open(file, "r") as f: - lines.extend(f.readlines()) + if file.endswith(".parquet") or file.endswith(".duckdb"): + lines.extend(load_duck(file)) + elif file.endswith(".json"): + # json only contains a single prompt + lines.append(load_json(file)) + else: + lines.extend(load_text(file)) phrases = [] for line in lines: