
Commit 03e5569

mfbalin authored and lijialin03 committed

[GraphBolt][CUDA] Overlap original edge ids fetch. (dmlc#7714)

1 parent b03e4ce · commit 03e5569

7 files changed: +231 −91 lines changed (two shown below)

graphbolt/src/cuda/extension/gpu_graph_cache.cu (+64 −37)

@@ -115,14 +115,16 @@ constexpr int kIntBlockSize = 512;
 
 c10::intrusive_ptr<GpuGraphCache> GpuGraphCache::Create(
     const int64_t num_edges, const int64_t threshold,
-    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes) {
+    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,
+    bool has_original_edge_ids) {
   return c10::make_intrusive<GpuGraphCache>(
-      num_edges, threshold, indptr_dtype, dtypes);
+      num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids);
 }
 
 GpuGraphCache::GpuGraphCache(
     const int64_t num_edges, const int64_t threshold,
-    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes) {
+    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,
+    bool has_original_edge_ids) {
   const int64_t initial_node_capacity = 1024;
   AT_DISPATCH_INDEX_TYPES(
       dtypes.at(0), "GpuGraphCache::GpuGraphCache", ([&] {
@@ -149,7 +151,9 @@ GpuGraphCache::GpuGraphCache(
         num_edges_ = 0;
         indptr_ =
             torch::zeros(initial_node_capacity + 1, options.dtype(indptr_dtype));
-        offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options());
+        if (!has_original_edge_ids) {
+          offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options());
+        }
         for (auto dtype : dtypes) {
           cached_edge_tensors_.push_back(
               torch::empty(num_edges, options.dtype(dtype)));
@@ -249,8 +253,9 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
     torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
     int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
     std::vector<torch::Tensor> edge_tensors) {
+  const auto with_edge_ids = offset_.has_value();
   // The last element of edge_tensors has the edge ids.
-  const auto num_tensors = edge_tensors.size() - 1;
+  const auto num_tensors = edge_tensors.size() - with_edge_ids;
   TORCH_CHECK(
       num_tensors == cached_edge_tensors_.size(),
       "Same number of tensors need to be passed!");
@@ -312,21 +317,28 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
            auto input = allocator.AllocateStorage<std::byte*>(num_buffers);
            auto input_size =
                allocator.AllocateStorage<size_t>(num_buffers + 1);
-           auto edge_id_offsets = torch::empty(
-               num_nodes, seeds.options().dtype(offset_.scalar_type()));
+           torch::optional<torch::Tensor> edge_id_offsets;
+           if (with_edge_ids) {
+             edge_id_offsets = torch::empty(
+                 num_nodes,
+                 seeds.options().dtype(offset_.value().scalar_type()));
+           }
            const auto cache_missing_dtype_dev_ptr =
                cache_missing_dtype_dev.get();
            const auto indices_ptr = indices.data_ptr<indices_t>();
            const auto positions_ptr = positions.data_ptr<indices_t>();
            const auto input_ptr = input.get();
            const auto input_size_ptr = input_size.get();
            const auto edge_id_offsets_ptr =
-               edge_id_offsets.data_ptr<indptr_t>();
+               edge_id_offsets ? edge_id_offsets->data_ptr<indptr_t>()
+                               : nullptr;
            const auto cache_indptr = indptr_.data_ptr<indptr_t>();
            const auto missing_indptr = indptr.data_ptr<indptr_t>();
-           const auto cache_offset = offset_.data_ptr<indptr_t>();
+           const auto cache_offset =
+               offset_ ? offset_->data_ptr<indptr_t>() : nullptr;
            const auto missing_edge_ids =
-               edge_tensors.back().data_ptr<indptr_t>();
+               edge_id_offsets ? edge_tensors.back().data_ptr<indptr_t>()
+                               : nullptr;
            CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) {
              const auto tensor_idx = i / num_nodes;
              const auto idx = i % num_nodes;
@@ -340,14 +352,14 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
              const auto offset_end = is_cached
                                          ? cache_indptr[pos + 1]
                                          : missing_indptr[idx - num_hit + 1];
-             const auto edge_id =
-                 is_cached ? cache_offset[pos] : missing_edge_ids[offset];
              const auto out_idx = tensor_idx * num_nodes + original_idx;
 
              input_ptr[out_idx] =
                  (is_cached ? cache_ptr : missing_ptr) + offset * size;
              input_size_ptr[out_idx] = size * (offset_end - offset);
-             if (i < num_nodes) {
+             if (edge_id_offsets_ptr && i < num_nodes) {
+               const auto edge_id =
+                   is_cached ? cache_offset[pos] : missing_edge_ids[offset];
                edge_id_offsets_ptr[out_idx] = edge_id;
              }
            });
@@ -390,10 +402,12 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
                  indptr_.size(0) * kIntGrowthFactor, indptr_.options());
              new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
              indptr_ = new_indptr;
-             auto new_offset =
-                 torch::empty(indptr_.size(0) - 1, offset_.options());
-             new_offset.slice(0, 0, offset_.size(0)) = offset_;
-             offset_ = new_offset;
+             if (offset_) {
+               auto new_offset =
+                   torch::empty(indptr_.size(0) - 1, offset_->options());
+               new_offset.slice(0, 0, offset_->size(0)) = *offset_;
+               offset_ = new_offset;
+             }
            }
            torch::Tensor sindptr;
            bool enough_space;
@@ -415,22 +429,32 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
            }
            if (enough_space) {
              auto num_edges = num_edges_;
-             auto transform_input_it = thrust::make_zip_iterator(
-                 sindptr.data_ptr<indptr_t>() + 1,
-                 sliced_indptr.data_ptr<indptr_t>());
-             auto transform_output_it = thrust::make_zip_iterator(
-                 indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
-                 offset_.data_ptr<indptr_t>() + num_nodes_);
-             THRUST_CALL(
-                 transform, transform_input_it,
-                 transform_input_it + sindptr.size(0) - 1,
-                 transform_output_it,
-                 [=] __host__ __device__(
-                     const thrust::tuple<indptr_t, indptr_t>& x) {
-                   return thrust::make_tuple(
-                       thrust::get<0>(x) + num_edges,
-                       missing_edge_ids[thrust::get<1>(x)]);
-                 });
+             if (offset_) {
+               auto transform_input_it = thrust::make_zip_iterator(
+                   sindptr.data_ptr<indptr_t>() + 1,
+                   sliced_indptr.data_ptr<indptr_t>());
+               auto transform_output_it = thrust::make_zip_iterator(
+                   indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
+                   offset_->data_ptr<indptr_t>() + num_nodes_);
+               THRUST_CALL(
+                   transform, transform_input_it,
+                   transform_input_it + sindptr.size(0) - 1,
+                   transform_output_it,
+                   [=] __host__ __device__(
+                       const thrust::tuple<indptr_t, indptr_t>& x) {
+                     return thrust::make_tuple(
+                         thrust::get<0>(x) + num_edges,
+                         missing_edge_ids[thrust::get<1>(x)]);
+                   });
+             } else {
+               THRUST_CALL(
+                   transform, sindptr.data_ptr<indptr_t>() + 1,
+                   sindptr.data_ptr<indptr_t>() + sindptr.size(0),
+                   indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
+                   [=] __host__ __device__(const indptr_t& x) {
+                     return x + num_edges;
+                   });
+             }
              auto map = reinterpret_cast<map_t<indices_t>*>(map_);
              const dim3 block(kIntBlockSize);
              const dim3 grid(
@@ -467,10 +491,13 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
                        .view(edge_tensors[i].scalar_type())
                        .slice(0, 0, static_cast<indptr_t>(output_size)));
              }
-             // Append the edge ids as the last element of the output.
-             output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl(
-                 output_indptr, output_indptr.scalar_type(), edge_id_offsets,
-                 static_cast<int64_t>(static_cast<indptr_t>(output_size))));
+             if (edge_id_offsets) {
+               // Append the edge ids as the last element of the output.
+               output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl(
+                   output_indptr, output_indptr.scalar_type(),
+                   *edge_id_offsets,
+                   static_cast<int64_t>(static_cast<indptr_t>(output_size))));
+             }
 
              {
                thrust::counting_iterator<int64_t> iota{0};
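
The pattern repeated throughout the hunks above is the same: offset_ becomes a torch::optional<torch::Tensor> that is allocated only when the cached graph lacks original edge ids, and every raw-pointer access and grow-and-copy step is guarded on its presence. The following standalone libtorch sketch (not GraphBolt code; TinyCache, needs_offsets, Grow, and OffsetPtr are hypothetical names) illustrates that optional-member guard pattern in isolation.

#include <torch/torch.h>

#include <iostream>

// Hypothetical miniature of the optional-offsets idea; not the GraphBolt class.
struct TinyCache {
  torch::Tensor indptr_;
  torch::optional<torch::Tensor> offset_;  // Absent when offsets are not tracked.

  TinyCache(int64_t num_nodes, bool needs_offsets) {
    indptr_ = torch::zeros(num_nodes + 1, torch::kInt64);
    if (needs_offsets) {
      offset_ = torch::empty(num_nodes, indptr_.options());
    }
  }

  // Mirrors the growth hunk above: the offsets tensor is grown only if it exists.
  void Grow(int64_t new_num_nodes) {
    auto new_indptr = torch::zeros(new_num_nodes + 1, indptr_.options());
    new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
    indptr_ = new_indptr;
    if (offset_) {
      auto new_offset = torch::empty(new_num_nodes, offset_->options());
      new_offset.slice(0, 0, offset_->size(0)) = *offset_;
      offset_ = new_offset;
    }
  }

  // Mirrors `offset_ ? offset_->data_ptr<indptr_t>() : nullptr` in the diff.
  const int64_t* OffsetPtr() const {
    return offset_ ? offset_->data_ptr<int64_t>() : nullptr;
  }
};

int main() {
  TinyCache tracked(/*num_nodes=*/4, /*needs_offsets=*/true);
  TinyCache untracked(/*num_nodes=*/4, /*needs_offsets=*/false);
  tracked.Grow(8);
  untracked.Grow(8);
  std::cout << "tracked has offsets: " << (tracked.OffsetPtr() != nullptr)
            << ", untracked has offsets: " << (untracked.OffsetPtr() != nullptr)
            << std::endl;
  return 0;
}

As in the Replace hunk, checking offset_.has_value() once up front (with_edge_ids) and capturing only a possibly-null raw pointer keeps the per-element device code to a single pointer check rather than repeated optional handling.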

graphbolt/src/cuda/extension/gpu_graph_cache.h (+8 −3)

@@ -47,10 +47,13 @@ class GpuGraphCache : public torch::CustomClassHolder {
    * @param indptr_dtype The node id datatype.
    * @param dtypes The dtypes of the edge tensors to be cached. dtypes[0] is
    * reserved for the indices edge tensor holding node ids.
+   * @param has_original_edge_ids Whether the graph to be cached has original
+   * edge ids.
    */
   GpuGraphCache(
       const int64_t num_edges, const int64_t threshold,
-      torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes);
+      torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,
+      bool has_original_edge_ids);
 
   GpuGraphCache() = default;
 
@@ -109,7 +112,8 @@ class GpuGraphCache : public torch::CustomClassHolder {
 
   static c10::intrusive_ptr<GpuGraphCache> Create(
       const int64_t num_edges, const int64_t threshold,
-      torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes);
+      torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,
+      bool has_original_edge_ids);
 
  private:
   void* map_;  // pointer to the hash table.
@@ -119,7 +123,8 @@ class GpuGraphCache : public torch::CustomClassHolder {
   int64_t num_nodes_;     // The number of cached nodes in the cache.
   int64_t num_edges_;     // The number of cached edges in the cache.
   torch::Tensor indptr_;  // The cached graph structure indptr tensor.
-  torch::Tensor offset_;  // The original graph's sliced_indptr tensor.
+  torch::optional<torch::Tensor>
+      offset_;  // The original graph's sliced_indptr tensor.
   std::vector<torch::Tensor> cached_edge_tensors_;  // The cached graph
                                                     // structure edge tensors.
   std::mutex mtx_;  // Protects the data structure and makes it threadsafe.