Skip to content

Commit 93cb70f

Browse files
author
Ubuntu
committed
added 2983 class task
1 parent 97c1735 commit 93cb70f

File tree

1 file changed

+96
-17
lines changed

1 file changed

+96
-17
lines changed

examples/graphbolt/rgcn/download.py

+96-17
Original file line numberDiff line numberDiff line change
@@ -56,33 +56,33 @@ def build_yaml_helper(path, dataset_size, in_memory=True):
5656
"data": [
5757
{
5858
"in_memory": in_memory,
59-
"path": "set/validation_indices.npy",
59+
"path": "set/validation_indices_19.npy",
6060
"name": "seeds",
6161
"format": "numpy",
6262
},
6363
{
6464
"in_memory": in_memory,
65-
"path": "set/validation_labels.npy",
65+
"path": "set/validation_labels_19.npy",
6666
"name": "labels",
6767
"format": "numpy",
6868
},
6969
],
7070
"type": "paper",
7171
}
7272
],
73-
"name": "node_classification",
73+
"name": "node_classification_19",
7474
"train_set": [
7575
{
7676
"data": [
7777
{
7878
"in_memory": in_memory,
79-
"path": "set/train_indices.npy",
79+
"path": "set/train_indices_19.npy",
8080
"name": "seeds",
8181
"format": "numpy",
8282
},
8383
{
8484
"in_memory": in_memory,
85-
"path": "set/train_labels.npy",
85+
"path": "set/train_labels_19.npy",
8686
"name": "labels",
8787
"format": "numpy",
8888
},
@@ -95,21 +95,82 @@ def build_yaml_helper(path, dataset_size, in_memory=True):
9595
"data": [
9696
{
9797
"in_memory": in_memory,
98-
"path": "set/test_indices.npy",
98+
"path": "set/test_indices_19.npy",
9999
"name": "seeds",
100100
"format": "numpy",
101101
},
102102
{
103103
"in_memory": in_memory,
104-
"path": "set/test_labels.npy",
104+
"path": "set/test_labels_19.npy",
105105
"name": "labels",
106106
"format": "numpy",
107107
},
108108
],
109109
"type": "paper",
110110
}
111111
],
112-
}
112+
},
113+
{
114+
"num_classes": 2983,
115+
"validation_set": [
116+
{
117+
"data": [
118+
{
119+
"in_memory": in_memory,
120+
"path": "set/validation_indices_2983.npy",
121+
"name": "seeds",
122+
"format": "numpy",
123+
},
124+
{
125+
"in_memory": in_memory,
126+
"path": "set/validation_labels_2983.npy",
127+
"name": "labels",
128+
"format": "numpy",
129+
},
130+
],
131+
"type": "paper",
132+
}
133+
],
134+
"name": "node_classification_2K",
135+
"train_set": [
136+
{
137+
"data": [
138+
{
139+
"in_memory": in_memory,
140+
"path": "set/train_indices_2983.npy",
141+
"name": "seeds",
142+
"format": "numpy",
143+
},
144+
{
145+
"in_memory": in_memory,
146+
"path": "set/train_labels_2983.npy",
147+
"name": "labels",
148+
"format": "numpy",
149+
},
150+
],
151+
"type": "paper",
152+
}
153+
],
154+
"test_set": [
155+
{
156+
"data": [
157+
{
158+
"in_memory": in_memory,
159+
"path": "set/test_indices_2983.npy",
160+
"name": "seeds",
161+
"format": "numpy",
162+
},
163+
{
164+
"in_memory": in_memory,
165+
"path": "set/test_labels_2983.npy",
166+
"name": "labels",
167+
"format": "numpy",
168+
},
169+
],
170+
"type": "paper",
171+
}
172+
],
173+
},
113174
],
114175
"feature_data": [
115176
{
@@ -390,7 +451,7 @@ def download_dataset(path, dataset_type, dataset_size):
390451
}
391452

392453

393-
def split_data(label_path, set_dir, dataset_size):
454+
def split_data(label_path, set_dir, dataset_size, class_num):
394455
"""This is for splitting the labels into three sets: train, validation, and test sets."""
395456
# labels = np.memmap(label_path, dtype='int32', mode='r', shape=(num_nodes[dataset_size]["paper"], 1))
396457
labels = np.load(label_path)
@@ -415,14 +476,24 @@ def split_data(label_path, set_dir, dataset_size):
415476
print(validation_labels, len(validation_labels))
416477
print(test_labels, len(test_labels))
417478

418-
gb.numpy_save_aligned(f"{set_dir}/train_indices.npy", train_indices)
419479
gb.numpy_save_aligned(
420-
f"{set_dir}/validation_indices.npy", validation_indices
480+
f"{set_dir}/train_indices_{class_num}.npy", train_indices
481+
)
482+
gb.numpy_save_aligned(
483+
f"{set_dir}/validation_indices_{class_num}.npy", validation_indices
484+
)
485+
gb.numpy_save_aligned(
486+
f"{set_dir}/test_indices_{class_num}.npy", test_indices
487+
)
488+
gb.numpy_save_aligned(
489+
f"{set_dir}/train_labels_{class_num}.npy", train_labels
490+
)
491+
gb.numpy_save_aligned(
492+
f"{set_dir}/validation_labels_{class_num}.npy", validation_labels
493+
)
494+
gb.numpy_save_aligned(
495+
f"{set_dir}/test_labels_{class_num}.npy", test_labels
421496
)
422-
gb.numpy_save_aligned(f"{set_dir}/test_indices.npy", test_indices)
423-
gb.numpy_save_aligned(f"{set_dir}/train_labels.npy", train_labels)
424-
gb.numpy_save_aligned(f"{set_dir}/validation_labels.npy", validation_labels)
425-
gb.numpy_save_aligned(f"{set_dir}/test_labels.npy", test_labels)
426497

427498

428499
def add_edges(edges, source, dest, dataset_size):
@@ -480,7 +551,6 @@ def process_label(file_path, num_class, dataset_size):
480551
assert new_array.shape[0] == 227130858
481552
assert np.array_equal(array, new_array)
482553
else:
483-
assert num_class == 19
484554
# new_array = np.memmap(file_path, dtype='int32', mode='r', shape=(num_nodes[dataset_size]["paper"], 1))
485555
new_array = np.load(file_path)
486556
assert new_array.shape[0] == num_nodes[dataset_size]["paper"]
@@ -547,7 +617,16 @@ def process_dataset(path, dataset_size):
547617
set_dir = processed_dir + "/" + "set"
548618
os.makedirs(name=set_dir, exist_ok=True)
549619
split_data(
550-
label_path=label_file_19, set_dir=set_dir, dataset_size=dataset_size
620+
label_path=label_file_19,
621+
set_dir=set_dir,
622+
dataset_size=dataset_size,
623+
class_num=19,
624+
)
625+
split_data(
626+
label_path=label_file_2K,
627+
set_dir=set_dir,
628+
dataset_size=dataset_size,
629+
class_num=2983,
551630
)
552631

553632
# Step 3: Move edge files

0 commit comments

Comments
 (0)