Skip to content

Commit 692912e

Browse files
committed
fix acc drop bug
1 parent 2537811 commit 692912e

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

examples/graphbolt/rgcn/hetero_rgcn.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def extract_embed(node_embed, input_nodes):
163163

164164
def extract_node_features(name, block, data, node_embed, device):
165165
"""Extract the node features from embedding layer or raw features."""
166-
if name == "ogbn-mag" or "igb-het" in name:
166+
if name == "ogbn-mag":
167167
input_nodes = {
168168
k: v.to(device) for k, v in block.srcdata[dgl.NID].items()
169169
}
@@ -288,13 +288,6 @@ def __init__(
288288
}
289289
)
290290

291-
self.loop_weights = nn.ModuleDict(
292-
{
293-
ntype: nn.Linear(in_size, out_size, bias=True)
294-
for ntype in self.ntypes
295-
}
296-
)
297-
298291
self.dropout = nn.Dropout(dropout)
299292
# Initialize parameters of the model.
300293
self.reset_parameters()
@@ -677,6 +670,8 @@ def main(args):
677670
"igb-het-small",
678671
"igb-het-medium",
679672
"igb-het-large",
673+
"igb-het",
674+
"igb-het-MLPerf"
680675
],
681676
help="Dataset name. Possible values: ogbn-mag, ogb-lsc-mag240m, "
682677
" igb-het-[tiny|small|medium|large].",

0 commit comments

Comments
 (0)