Skip to content

Commit 1724a29

Browse files
committed
Update format goldens to create gemini CSV dataset
1 parent 5f775d2 commit 1724a29

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

ai/src/format_goldens.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import argparse
6+
import csv
67
import json
78
import os
89
from datetime import datetime
@@ -17,7 +18,9 @@
1718
FINE_TUNING_CUTOFF = datetime(2024, 8, 2, 0, 0)
1819

1920

20-
def process_goldens(skip_fine_tuned_goldens: bool = False):
21+
def process_goldens(
22+
skip_fine_tuned_goldens: bool = False, gemini_format: bool = False
23+
):
2124
dataset: list[dict[str, Any]] = []
2225
outputs_dir = "ft/goldens"
2326

@@ -31,7 +34,7 @@ def process_goldens(skip_fine_tuned_goldens: bool = False):
3134

3235
_, timestamp = os.path.basename(dir_path).rsplit("_", 1)
3336
creation_date = datetime.strptime(timestamp, "%Y%m%d%H%M")
34-
print(creation_date)
37+
3538
if skip_fine_tuned_goldens and creation_date < FINE_TUNING_CUTOFF:
3639
continue
3740

@@ -54,7 +57,7 @@ def process_goldens(skip_fine_tuned_goldens: bool = False):
5457
else:
5558
code = ""
5659

57-
if skip_fine_tuned_goldens:
60+
if skip_fine_tuned_goldens or gemini_format:
5861
formatter = MakeMessageFormatterShorterUserMsg()
5962
else:
6063
formatter = MakeDefaultMessageFormatter()
@@ -81,22 +84,46 @@ def process_goldens(skip_fine_tuned_goldens: bool = False):
8184
action="store_true",
8285
help="Generates a formatted dataset with goldens that have not been fine tuned.",
8386
)
87+
parser.add_argument(
88+
"--gemini_format",
89+
action="store_true",
90+
help="Generates a Gemini formatted dataset.",
91+
)
8492
args = parser.parse_args()
8593

86-
formatted_dataset = process_goldens(args.skip_fine_tuned_goldens)
94+
formatted_dataset = process_goldens(
95+
args.skip_fine_tuned_goldens, args.gemini_format
96+
)
8797
print(f"Processed {len(formatted_dataset)} samples.")
8898
# create gen dir if it doesn't exist
8999
os.makedirs("ft/gen", exist_ok=True)
90100

91101
if args.skip_fine_tuned_goldens:
92-
full_path = os.path.join("ft/gen/formatted_dataset_for_prompting.jsonl")
102+
if args.gemini_format:
103+
full_path = os.path.join(
104+
"ft/gen/gemini_formatted_dataset_for_prompting.csv"
105+
)
106+
else:
107+
full_path = os.path.join("ft/gen/formatted_dataset_for_prompting.jsonl")
93108
else:
94-
full_path = os.path.join("ft/gen/formatted_dataset.jsonl")
95-
96-
# Append each sample as a JSON object on a separate line to a file
97-
with open(full_path, "w") as f:
98-
for sample in formatted_dataset:
99-
f.write(json.dumps(sample) + "\n")
109+
if args.gemini_format:
110+
full_path = os.path.join("ft/gen/gemini_formatted_dataset.csv")
111+
else:
112+
full_path = os.path.join("ft/gen/formatted_dataset.jsonl")
113+
114+
if args.gemini_format:
115+
with open(full_path, "w") as f:
116+
writer = csv.writer(f)
117+
writer.writerow(["input:", "output:"])
118+
for sample in formatted_dataset:
119+
writer.writerow(
120+
[sample["messages"][1]["content"], sample["messages"][2]["content"]]
121+
)
122+
else:
123+
# Write each sample as a JSON object on its own line (JSON Lines); mode "w" overwrites any existing file
124+
with open(full_path, "w") as f:
125+
for sample in formatted_dataset:
126+
f.write(json.dumps(sample) + "\n")
100127

101128
# Print the path of the created file (relative to the current working directory)
102129
print(f"File created at: {full_path}")

0 commit comments

Comments
 (0)