# Kun Zhou Implemented
from data.simulation.simulation import SimulationManager
from preprocess.dataset import knowledge_info
from preprocess.stat_info_functions import stat_info_collection, convert_stat_info_to_text
from algorithm.filter import Filter
from algorithm.program import Programming
from algorithm.rerank import Reranker
from postprocess.judge import Judge
from postprocess.visualization import Visualization
from preprocess.eda_generation import EDA
from postprocess.report_generation import Report_generation
from global_setting.Initialize_state import global_state_initialization, load_data
import json
import argparse
import pandas as pd
import os
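
# Pipeline overview (see main() below): load and preprocess the dataset, collect
# statistics and domain knowledge, run EDA, select and execute a causal discovery
# algorithm (Filter -> Reranker -> Programming), judge and revise the resulting
# graph, visualize it, and write a PDF report to --output-report-dir.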


def parse_args():
    parser = argparse.ArgumentParser(description='Causal Learning Tool for Data Analysis')

    # Input data file
    parser.add_argument(
        '--data-file',
        type=str,
        default="dataset/Abalone/Abalone.csv",
        help='Path to the input dataset file (e.g., CSV format or directory location)'
    )
    # Output directory for the report
    parser.add_argument(
        '--output-report-dir',
        type=str,
        default='dataset/Abalone/output_report',
        help='Directory to save the output report'
    )
    # Output directory for graphs
    parser.add_argument(
        '--output-graph-dir',
        type=str,
        default='dataset/Abalone/output_graph',
        help='Directory to save the output graph'
    )
    # OpenAI settings
    parser.add_argument(
        '--organization',
        type=str,
        default="org-5NION61XDUXh0ib0JZpcppqS",
        help='Organization ID'
    )
    parser.add_argument(
        '--project',
        type=str,
        default="proj_Ry1rvoznXAMj8R2bujIIkhQN",
        help='Project ID'
    )
    parser.add_argument(
        '--apikey',
        type=str,
        default="",
        help='API Key'
    )
    parser.add_argument(
        '--simulation_mode',
        type=str,
        default="offline",
        help='Simulation mode: online or offline'
    )
    parser.add_argument(
        '--data_mode',
        type=str,
        default="real",
        help='Data mode: real or simulated'
    )
    parser.add_argument(
        '--debug',
        action='store_true',
        default=False,
        help='Enable debugging mode'
    )
    parser.add_argument(
        '--initial_query',
        type=str,
        default="selected algorithm: PC",
        help='Initial query for the algorithm'
    )
    # Boolean flags use store_true: with argparse, type=bool would treat any
    # non-empty string (including "False") as True.
    parser.add_argument(
        '--parallel',
        action='store_true',
        default=False,
        help='Parallel computing for bootstrapping.'
    )
    parser.add_argument(
        '--demo_mode',
        action='store_true',
        default=False,
        help='Demo mode'
    )

    args = parser.parse_args()
    return args
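
# Example invocation (illustrative; paths are the argparse defaults above, and an
# OpenAI key would normally be supplied via --apikey):
#   python main.py --data-file dataset/Abalone/Abalone.csv \
#       --output-report-dir dataset/Abalone/output_report \
#       --output-graph-dir dataset/Abalone/output_graph \
#       --data_mode real --initial_query "selected algorithm: PC"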


def load_real_world_data(file_path):
    # Baseline code
    # Check the file format and load accordingly; currently CSV and JSON are supported
    if file_path.endswith('.csv'):
        data = pd.read_csv(file_path)
    elif file_path.endswith('.json'):
        with open(file_path, 'r') as f:
            data = pd.DataFrame(json.load(f))
    else:
        raise ValueError(f"Unsupported file format for {file_path}")

    print("Real-world data loaded successfully.")
    return data
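
# Example usage (illustrative; the JSON path is hypothetical):
#   df = load_real_world_data("dataset/Abalone/Abalone.csv")   # CSV -> DataFrame
#   df = load_real_world_data("records.json")                  # list-of-records JSON -> DataFrame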


def process_user_query(query, data):
    # Baseline code
    query_dict = {}
    for part in query.split(';'):
        key, value = part.strip().split(':')
        query_dict[key.strip()] = value.strip()

    if 'filter' in query_dict and query_dict['filter'] == 'continuous':
        # Keep only continuous (numeric) columns
        data = data.select_dtypes(include=['float64', 'int64'])

    if 'selected_algorithm' in query_dict:
        selected_algorithm = query_dict['selected_algorithm']
        print(f"Algorithm selected: {selected_algorithm}")

    print("User query processed.")
    return data
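
# Example query string accepted by process_user_query (illustrative):
#   "filter: continuous; selected_algorithm: PC"
# i.e. semicolon-separated "key: value" pairs; only the 'filter' and
# 'selected_algorithm' keys are acted on above.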


def main(args):
    global_state = global_state_initialization(args)
    global_state = load_data(global_state, args)

    if args.data_mode == 'real':
        global_state.user_data.raw_data = load_real_world_data(args.data_file)
    global_state.user_data.processed_data = process_user_query(args.initial_query, global_state.user_data.raw_data)

    # Show the extracted global state
    print(global_state)

    # background info collection
    # print("Original Data: ", global_state.user_data.raw_data)
    if args.debug:
        # Fake statistics for debugging
        global_state.statistics.missingness = False
        global_state.statistics.data_type = "Continuous"
        global_state.statistics.linearity = True
        global_state.statistics.gaussian_error = True
        global_state.statistics.stationary = "non time-series"
        global_state.user_data.processed_data = global_state.user_data.raw_data
        global_state.user_data.knowledge_docs = "This is fake domain knowledge for debugging purposes."
    else:
        global_state = stat_info_collection(global_state)
        global_state = knowledge_info(args, global_state)

    # Convert statistics to text
    global_state.statistics.description = convert_stat_info_to_text(global_state.statistics)
    print("Preprocessed Data: ", global_state.user_data.processed_data)
    print("Statistics Info: ", global_state.statistics.description)
    print("Knowledge Info: ", global_state.user_data.knowledge_docs)

    #############EDA###################
    my_eda = EDA(global_state)
    my_eda.generate_eda()

    # Algorithm selection and deliberation
    filter = Filter(args)
    global_state = filter.forward(global_state)

    reranker = Reranker(args)
    global_state = reranker.forward(global_state)

    programmer = Programming(args)
    global_state = programmer.forward(global_state)

    #############Visualization for Initial Graph###################
    my_visual_initial = Visualization(global_state)
    # Get the position of the nodes
    pos_est = my_visual_initial.get_pos(global_state.results.raw_result)

    # Plot True Graph
    if global_state.user_data.ground_truth is not None:
        _ = my_visual_initial.plot_pdag(global_state.user_data.ground_truth, 'true_graph.pdf', pos=pos_est)
    # Plot Initial Graph
    _ = my_visual_initial.plot_pdag(global_state.results.raw_result, 'initial_graph.pdf', pos=pos_est)

    my_report = Report_generation(global_state, args)
    global_state.results.raw_edges = my_visual_initial.convert_to_edges(global_state.results.raw_result)
    global_state.logging.graph_conversion['initial_graph_analysis'] = my_report.graph_effect_prompts()

    judge = Judge(global_state, args)
    if global_state.user_data.ground_truth is not None:
        print("Original Graph: ", global_state.results.converted_graph)
        print("Mat Ground Truth: ", global_state.user_data.ground_truth)
        global_state.results.metrics = judge.evaluation(global_state)
        print(global_state.results.metrics)
    global_state = judge.forward(global_state)

    # ##############################
    # from postprocess.judge_functions import llm_direction_evaluation
    # llm_direction_evaluation(global_state)
    # if global_state.user_data.ground_truth is not None:
    #     print("Revised Graph: ", global_state.results.revised_graph)
    #     print("Mat Ground Truth: ", global_state.user_data.ground_truth)
    #     global_state.results.revised_metrics = judge.evaluation(global_state)
    #     print(global_state.results.revised_metrics)
    # ################################

    #############Visualization for Revised Graph###################
    # Plot Revised Graph
    my_visual_revise = Visualization(global_state)
    pos_new = my_visual_revise.plot_pdag(global_state.results.revised_graph, 'revised_graph.pdf', pos=pos_est)
    global_state.results.revised_edges = my_visual_revise.convert_to_edges(global_state.results.revised_graph)
    # Plot Bootstrap Heatmap
    boot_heatmap_path = my_visual_revise.boot_heatmap_plot()

    # algorithm selection process
    '''
    round = 0
    flag = False
    while round < args.max_iterations and flag == False:
        code, results = programmer.forward(preprocessed_data, algorithm, hyper_suggest)
        flag, algorithm_setup = judge(preprocessed_data, code, results, statistics_dict, algorithm_setup, knowledge_docs)
    '''

    #############Report Generation###################
    try_num = 1
    my_report = Report_generation(global_state, args)
    report = my_report.generation()
    my_report.save_report(report)

    # Retry report generation up to three times if the PDF was not produced
    report_path = os.path.join(global_state.user_data.output_report_dir, 'report.pdf')
    while not os.path.isfile(report_path) and try_num <= 3:
        try_num += 1
        print('An error occurred during report generation, trying again.')
        report_gen = Report_generation(global_state, args)
        report = report_gen.generation(debug=False)
        report_gen.save_report(report)
    if not os.path.isfile(report_path):
        print('Report generation failed after three retries, stopping.')
    ################################

    return report, global_state


if __name__ == '__main__':
    args = parse_args()
    main(args)