Skip to content

Commit 62b4bbb

Browse files
authored
[Fix] Enable mini-batch rgcn for CPU (dmlc#2345)
1 parent 77968e3 commit 62b4bbb

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/pytorch/rgcn/entity_classify_mp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
185185

186186
@thread_wrapped_func
187187
def run(proc_id, n_gpus, args, devices, dataset, split, queue=None):
188-
dev_id = devices[proc_id]
188+
dev_id = devices[proc_id] if devices[proc_id] != 'cpu' else -1
189189
g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \
190190
train_idx, val_idx, test_idx, labels = dataset
191191
if split is not None:

examples/pytorch/rgcn/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self,
7878
sparse_emb=False,
7979
embed_name='embed'):
8080
super(RelGraphEmbedLayer, self).__init__()
81-
self.dev_id = dev_id
81+
self.dev_id = th.device(dev_id if dev_id >= 0 else 'cpu')
8282
self.embed_size = embed_size
8383
self.embed_name = embed_name
8484
self.num_nodes = num_nodes

0 commit comments

Comments
 (0)