14
14
15
15
namespace graphbolt {
16
16
namespace sampling {
17
- std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact (
17
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
18
+ UniqueAndCompact (
18
19
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
19
20
const torch::Tensor unique_dst_ids, const int64_t rank,
20
21
const int64_t world_size) {
@@ -31,16 +32,20 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
31
32
" Cooperative Minibatching (arXiv:2310.12403) is supported only on GPUs." );
32
33
auto num_dst = unique_dst_ids.size (0 );
33
34
torch::Tensor ids = torch::cat ({unique_dst_ids, src_ids});
34
- return AT_DISPATCH_INDEX_TYPES (
35
+ auto [unique_ids, compacted_src, compacted_dst] = AT_DISPATCH_INDEX_TYPES (
35
36
ids.scalar_type (), " unique_and_compact" , ([&] {
36
37
ConcurrentIdHashMap<index_t > id_map (ids, num_dst);
37
38
return std::make_tuple (
38
39
id_map.GetUniqueIds (), id_map.MapIds (src_ids),
39
40
id_map.MapIds (dst_ids));
40
41
}));
42
+ auto offsets = torch::zeros (2 , c10::TensorOptions ().dtype (torch::kInt64 ));
43
+ offsets.data_ptr <int64_t >()[1 ] = unique_ids.size (0 );
44
+ return {unique_ids, compacted_src, compacted_dst, offsets};
41
45
}
42
46
43
- std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
47
+ std::vector<
48
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
44
49
UniqueAndCompactBatched (
45
50
const std::vector<torch::Tensor>& src_ids,
46
51
const std::vector<torch::Tensor>& dst_ids,
@@ -64,7 +69,9 @@ UniqueAndCompactBatched(
64
69
src_ids, dst_ids, unique_dst_ids, rank, world_size);
65
70
});
66
71
}
67
- std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> results;
72
+ std::vector<
73
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
74
+ results;
68
75
results.reserve (src_ids.size ());
69
76
for (std::size_t i = 0 ; i < src_ids.size (); i++) {
70
77
results.emplace_back (UniqueAndCompact (
@@ -73,8 +80,8 @@ UniqueAndCompactBatched(
73
80
return results;
74
81
}
75
82
76
- c10::intrusive_ptr<Future<
77
- std::vector<std::tuple< torch::Tensor, torch::Tensor, torch::Tensor>>>>
83
+ c10::intrusive_ptr<Future<std::vector<
84
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>>
78
85
UniqueAndCompactBatchedAsync (
79
86
const std::vector<torch::Tensor>& src_ids,
80
87
const std::vector<torch::Tensor>& dst_ids,
0 commit comments