@@ -115,14 +115,16 @@ constexpr int kIntBlockSize = 512;
 
 c10::intrusive_ptr<GpuGraphCache> GpuGraphCache::Create(
     const int64_t num_edges, const int64_t threshold,
-    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes) {
+    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,
+    bool has_original_edge_ids) {
   return c10::make_intrusive<GpuGraphCache>(
-      num_edges, threshold, indptr_dtype, dtypes);
+      num_edges, threshold, indptr_dtype, dtypes, has_original_edge_ids);
 }
 
 GpuGraphCache::GpuGraphCache(
     const int64_t num_edges, const int64_t threshold,
-    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes) {
+    torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes,
+    bool has_original_edge_ids) {
   const int64_t initial_node_capacity = 1024;
   AT_DISPATCH_INDEX_TYPES(
       dtypes.at(0), "GpuGraphCache::GpuGraphCache", ([&] {
@@ -149,7 +151,9 @@ GpuGraphCache::GpuGraphCache(
         num_edges_ = 0;
         indptr_ =
            torch::zeros(initial_node_capacity + 1, options.dtype(indptr_dtype));
-        offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options());
+        if (!has_original_edge_ids) {
+          offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options());
+        }
         for (auto dtype : dtypes) {
           cached_edge_tensors_.push_back(
               torch::empty(num_edges, options.dtype(dtype)));
@@ -249,8 +253,9 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
     torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
     int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
     std::vector<torch::Tensor> edge_tensors) {
+  const auto with_edge_ids = offset_.has_value();
   // The last element of edge_tensors has the edge ids.
-  const auto num_tensors = edge_tensors.size() - 1;
+  const auto num_tensors = edge_tensors.size() - with_edge_ids;
   TORCH_CHECK(
       num_tensors == cached_edge_tensors_.size(),
       "Same number of tensors need to be passed!");
@@ -312,21 +317,28 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
               auto input = allocator.AllocateStorage<std::byte*>(num_buffers);
               auto input_size =
                   allocator.AllocateStorage<size_t>(num_buffers + 1);
-              auto edge_id_offsets = torch::empty(
-                  num_nodes, seeds.options().dtype(offset_.scalar_type()));
+              torch::optional<torch::Tensor> edge_id_offsets;
+              if (with_edge_ids) {
+                edge_id_offsets = torch::empty(
+                    num_nodes,
+                    seeds.options().dtype(offset_.value().scalar_type()));
+              }
               const auto cache_missing_dtype_dev_ptr =
                   cache_missing_dtype_dev.get();
               const auto indices_ptr = indices.data_ptr<indices_t>();
               const auto positions_ptr = positions.data_ptr<indices_t>();
               const auto input_ptr = input.get();
               const auto input_size_ptr = input_size.get();
               const auto edge_id_offsets_ptr =
-                  edge_id_offsets.data_ptr<indptr_t>();
+                  edge_id_offsets ? edge_id_offsets->data_ptr<indptr_t>()
+                                  : nullptr;
               const auto cache_indptr = indptr_.data_ptr<indptr_t>();
               const auto missing_indptr = indptr.data_ptr<indptr_t>();
-              const auto cache_offset = offset_.data_ptr<indptr_t>();
+              const auto cache_offset =
+                  offset_ ? offset_->data_ptr<indptr_t>() : nullptr;
               const auto missing_edge_ids =
-                  edge_tensors.back().data_ptr<indptr_t>();
+                  edge_id_offsets ? edge_tensors.back().data_ptr<indptr_t>()
+                                  : nullptr;
               CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) {
                 const auto tensor_idx = i / num_nodes;
                 const auto idx = i % num_nodes;
@@ -340,14 +352,14 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
                 const auto offset_end = is_cached
                     ? cache_indptr[pos + 1]
                     : missing_indptr[idx - num_hit + 1];
-                const auto edge_id =
-                    is_cached ? cache_offset[pos] : missing_edge_ids[offset];
                 const auto out_idx = tensor_idx * num_nodes + original_idx;
 
                 input_ptr[out_idx] =
                     (is_cached ? cache_ptr : missing_ptr) + offset * size;
                 input_size_ptr[out_idx] = size * (offset_end - offset);
-                if (i < num_nodes) {
+                if (edge_id_offsets_ptr && i < num_nodes) {
+                  const auto edge_id =
+                      is_cached ? cache_offset[pos] : missing_edge_ids[offset];
                   edge_id_offsets_ptr[out_idx] = edge_id;
                 }
               });
@@ -390,10 +402,12 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
                     indptr_.size(0) * kIntGrowthFactor, indptr_.options());
                 new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
                 indptr_ = new_indptr;
-                auto new_offset =
-                    torch::empty(indptr_.size(0) - 1, offset_.options());
-                new_offset.slice(0, 0, offset_.size(0)) = offset_;
-                offset_ = new_offset;
+                if (offset_) {
+                  auto new_offset =
+                      torch::empty(indptr_.size(0) - 1, offset_->options());
+                  new_offset.slice(0, 0, offset_->size(0)) = *offset_;
+                  offset_ = new_offset;
+                }
               }
               torch::Tensor sindptr;
               bool enough_space;
@@ -415,22 +429,32 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
              }
              if (enough_space) {
                auto num_edges = num_edges_;
-               auto transform_input_it = thrust::make_zip_iterator(
-                   sindptr.data_ptr<indptr_t>() + 1,
-                   sliced_indptr.data_ptr<indptr_t>());
-               auto transform_output_it = thrust::make_zip_iterator(
-                   indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
-                   offset_.data_ptr<indptr_t>() + num_nodes_);
-               THRUST_CALL(
-                   transform, transform_input_it,
-                   transform_input_it + sindptr.size(0) - 1,
-                   transform_output_it,
-                   [=] __host__ __device__(
-                       const thrust::tuple<indptr_t, indptr_t>& x) {
-                     return thrust::make_tuple(
-                         thrust::get<0>(x) + num_edges,
-                         missing_edge_ids[thrust::get<1>(x)]);
-                   });
+               if (offset_) {
+                 auto transform_input_it = thrust::make_zip_iterator(
+                     sindptr.data_ptr<indptr_t>() + 1,
+                     sliced_indptr.data_ptr<indptr_t>());
+                 auto transform_output_it = thrust::make_zip_iterator(
+                     indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
+                     offset_->data_ptr<indptr_t>() + num_nodes_);
+                 THRUST_CALL(
+                     transform, transform_input_it,
+                     transform_input_it + sindptr.size(0) - 1,
+                     transform_output_it,
+                     [=] __host__ __device__(
+                         const thrust::tuple<indptr_t, indptr_t>& x) {
+                       return thrust::make_tuple(
+                           thrust::get<0>(x) + num_edges,
+                           missing_edge_ids[thrust::get<1>(x)]);
+                     });
+               } else {
+                 THRUST_CALL(
+                     transform, sindptr.data_ptr<indptr_t>() + 1,
+                     sindptr.data_ptr<indptr_t>() + sindptr.size(0),
+                     indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
+                     [=] __host__ __device__(const indptr_t& x) {
+                       return x + num_edges;
+                     });
+               }
               auto map = reinterpret_cast<map_t<indices_t>*>(map_);
               const dim3 block(kIntBlockSize);
               const dim3 grid(
@@ -467,10 +491,13 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
                         .view(edge_tensors[i].scalar_type())
                         .slice(0, 0, static_cast<indptr_t>(output_size)));
               }
-              // Append the edge ids as the last element of the output.
-              output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl(
-                  output_indptr, output_indptr.scalar_type(), edge_id_offsets,
-                  static_cast<int64_t>(static_cast<indptr_t>(output_size))));
+              if (edge_id_offsets) {
+                // Append the edge ids as the last element of the output.
+                output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl(
+                    output_indptr, output_indptr.scalar_type(),
+                    *edge_id_offsets,
+                    static_cast<int64_t>(static_cast<indptr_t>(output_size))));
+              }
 
               {
                 thrust::counting_iterator<int64_t> iota{0};
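
To summarize the behavioral change, here is a minimal usage sketch of the updated `Create` signature. It is not part of the diff: the header name, the `graphbolt::cuda` namespace, and the example argument values are assumptions; only the parameter list and the `offset_`/`edge_tensors` behavior described in the comments come from the change above.

```cpp
// Sketch only: header path and namespace are assumptions inferred from the
// identifiers in this diff, not verified against the repository layout.
#include <torch/torch.h>

#include "gpu_graph_cache.h"

using graphbolt::cuda::GpuGraphCache;

void MakeCaches() {
  // Dtypes of the cached edge tensors; the constructor dispatches on
  // dtypes.at(0) as an index type, so it must be kInt32 or kInt64.
  std::vector<torch::ScalarType> dtypes{torch::kInt64, torch::kInt8};

  // With has_original_edge_ids == false, the cache allocates offset_, so
  // Replace() treats the last element of edge_tensors as the missing edge
  // ids and appends an edge-id tensor to its output.
  auto cache_with_offsets = GpuGraphCache::Create(
      /*num_edges=*/1 << 20, /*threshold=*/3,
      /*indptr_dtype=*/torch::kInt64, dtypes,
      /*has_original_edge_ids=*/false);

  // With has_original_edge_ids == true, offset_ stays empty and Replace()
  // skips the edge-id bookkeeping: no trailing edge-id tensor is expected
  // in edge_tensors and none is appended to the output.
  auto cache_without_offsets = GpuGraphCache::Create(
      /*num_edges=*/1 << 20, /*threshold=*/3,
      /*indptr_dtype=*/torch::kInt64, dtypes,
      /*has_original_edge_ids=*/true);
}
```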