Skip to content

Commit c3e8e89

Browse files
committed
extend test coverage.
1 parent 20a35a7 commit c3e8e89

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
3737
assert_equal(offsets1, offsets2)
3838
assert offsets1.is_pinned() and offsets2.is_pinned()
3939

40-
res3 = torch.ops.graphbolt.rank_sort(nodes_list1, rank, WORLD_SIZE)
40+
# Test with the reverse order of ntypes. See if results are equivalent.
41+
res3 = torch.ops.graphbolt.rank_sort(nodes_list1[::-1], rank, WORLD_SIZE)
4142

4243
# This function is deterministic. Call with identical arguments and check.
43-
for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(res1, res3):
44+
for (nodes1, idx1, offsets1), (nodes3, idx3, offsets3) in zip(res1, reversed(res3)):
4445
assert_equal(nodes1, nodes3)
4546
assert_equal(idx1, idx3)
46-
assert_equal(offsets1, offsets3)
47+
assert_equal(offsets1.diff(), offsets3.diff())
4748

4849
# The dependency on the rank argument is simply a permutation.
4950
res4 = torch.ops.graphbolt.rank_sort(nodes_list1, 0, WORLD_SIZE)

0 commit comments

Comments
 (0)