From 2121c7aa5d5ffecf240145822ce188e82f8159a1 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 6 Jul 2023 08:11:43 -0500 Subject: [PATCH] feat(scripts): add json and parquet support to prompt book parser --- api/scripts/parse-prompts.py | 43 +++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/api/scripts/parse-prompts.py b/api/scripts/parse-prompts.py index 2c977015..954b4e01 100644 --- a/api/scripts/parse-prompts.py +++ b/api/scripts/parse-prompts.py @@ -1,8 +1,8 @@ -from typing import List from argparse import ArgumentParser -from sys import argv from collections import Counter -from json import dumps +from json import dumps, loads +from sys import argv +from typing import List def parse_args(args: List[str]): @@ -14,13 +14,46 @@ def parse_args(args: List[str]): return parser.parse_args(args) +def load_duck(file: str): + import duckdb + + cursor = duckdb.connect() + return [p[0] for p in cursor.sql(f"SELECT * FROM '{file}'").fetchall()] + + +def load_json(file: str): + with open(file, "r") as f: + data = loads(f.read()) + params = data.get("params", None) + if params: + prompt = params.get("input_prompt", None) + if prompt: + return prompt + + prompt = params.get("prompt", None) + if prompt: + return prompt + + return "" + + +def load_text(file: str): + with open(file, "r") as f: + return f.readlines() + + def main(): args = parse_args(argv[1:]) lines: List[str] = [] for file in args.file: - with open(file, "r") as f: - lines.extend(f.readlines()) + if file.endswith(".parquet") or file.endswith(".duckdb"): + lines.extend(load_duck(file)) + elif file.endswith(".json"): + # json only contains a single prompt + lines.append(load_json(file)) + else: + lines.extend(load_text(file)) phrases = [] for line in lines: