3
3
"""
4
4
5
5
import argparse
6
+ import csv
6
7
import json
7
8
import os
8
9
from datetime import datetime
17
18
FINE_TUNING_CUTOFF = datetime (2024 , 8 , 2 , 0 , 0 )
18
19
19
20
20
- def process_goldens (skip_fine_tuned_goldens : bool = False ):
21
+ def process_goldens (
22
+ skip_fine_tuned_goldens : bool = False , gemini_format : bool = False
23
+ ):
21
24
dataset : list [dict [str , Any ]] = []
22
25
outputs_dir = "ft/goldens"
23
26
@@ -31,7 +34,7 @@ def process_goldens(skip_fine_tuned_goldens: bool = False):
31
34
32
35
_ , timestamp = os .path .basename (dir_path ).rsplit ("_" , 1 )
33
36
creation_date = datetime .strptime (timestamp , "%Y%m%d%H%M" )
34
- print ( creation_date )
37
+
35
38
if skip_fine_tuned_goldens and creation_date < FINE_TUNING_CUTOFF :
36
39
continue
37
40
@@ -54,7 +57,7 @@ def process_goldens(skip_fine_tuned_goldens: bool = False):
54
57
else :
55
58
code = ""
56
59
57
- if skip_fine_tuned_goldens :
60
+ if skip_fine_tuned_goldens or gemini_format :
58
61
formatter = MakeMessageFormatterShorterUserMsg ()
59
62
else :
60
63
formatter = MakeDefaultMessageFormatter ()
@@ -81,22 +84,46 @@ def process_goldens(skip_fine_tuned_goldens: bool = False):
81
84
action = "store_true" ,
82
85
help = "Generates a formatted dataset with goldens that have not been fine tuned." ,
83
86
)
87
+ parser .add_argument (
88
+ "--gemini_format" ,
89
+ action = "store_true" ,
90
+ help = "Generates a Gemini formatted dataset." ,
91
+ )
84
92
args = parser .parse_args ()
85
93
86
- formatted_dataset = process_goldens (args .skip_fine_tuned_goldens )
94
+ formatted_dataset = process_goldens (
95
+ args .skip_fine_tuned_goldens , args .gemini_format
96
+ )
87
97
print (f"Processed { len (formatted_dataset )} samples." )
88
98
# create gen dir if it doesn't exist
89
99
os .makedirs ("ft/gen" , exist_ok = True )
90
100
91
101
if args .skip_fine_tuned_goldens :
92
- full_path = os .path .join ("ft/gen/formatted_dataset_for_prompting.jsonl" )
102
+ if args .gemini_format :
103
+ full_path = os .path .join (
104
+ "ft/gen/gemini_formatted_dataset_for_prompting.csv"
105
+ )
106
+ else :
107
+ full_path = os .path .join ("ft/gen/formatted_dataset_for_prompting.jsonl" )
93
108
else :
94
- full_path = os .path .join ("ft/gen/formatted_dataset.jsonl" )
95
-
96
- # Append each sample as a JSON object on a separate line to a file
97
- with open (full_path , "w" ) as f :
98
- for sample in formatted_dataset :
99
- f .write (json .dumps (sample ) + "\n " )
109
+ if args .gemini_format :
110
+ full_path = os .path .join ("ft/gen/gemini_formatted_dataset.csv" )
111
+ else :
112
+ full_path = os .path .join ("ft/gen/formatted_dataset.jsonl" )
113
+
114
+ if args .gemini_format :
115
+ with open (full_path , "w" ) as f :
116
+ writer = csv .writer (f )
117
+ writer .writerow (["input:" , "output:" ])
118
+ for sample in formatted_dataset :
119
+ writer .writerow (
120
+ [sample ["messages" ][1 ]["content" ], sample ["messages" ][2 ]["content" ]]
121
+ )
122
+ else :
123
+ # Append each sample as a JSON object on a separate line to a file
124
+ with open (full_path , "w" ) as f :
125
+ for sample in formatted_dataset :
126
+ f .write (json .dumps (sample ) + "\n " )
100
127
101
128
# Print absolute path of file
102
129
print (f"File created at: { full_path } " )
0 commit comments