
Commit 5fc930b

Merge branch 'dmlc:master' into add-igbh-to-rgcn
2 parents 6bf10cb + db574f5 commit 5fc930b

File tree

6 files changed: +50 −8 lines changed

examples/multigpu/graphbolt/node_classification.py

+4 −1

@@ -135,7 +135,10 @@ def create_dataloader(
     if args.storage_device != "cpu":
         datapipe = datapipe.copy_to(device)
     datapipe = datapipe.sample_neighbor(
-        graph, args.fanout, overlap_fetch=args.storage_device == "pinned"
+        graph,
+        args.fanout,
+        overlap_fetch=args.storage_device == "pinned",
+        asynchronous=args.storage_device != "cpu",
     )
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
     if args.storage_device == "cpu":
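For context, the change threads a new asynchronous flag through the sampling step next to the existing overlap_fetch option. A minimal sketch of how such a pipeline is assembled and consumed, with illustrative itemset, fanout, and batch-size values that are not part of this commit:

import dgl.graphbolt as gb

# Illustrative pipeline wiring; only the two keyword arguments to
# sample_neighbor below reflect this commit.
datapipe = gb.ItemSampler(itemset, batch_size=1024)
datapipe = datapipe.copy_to(device)
datapipe = datapipe.sample_neighbor(
    graph,
    [10, 10, 10],  # fanout per layer, illustrative
    overlap_fetch=True,   # e.g. when the graph is in pinned memory
    asynchronous=True,    # new: overlap sampling work across stages
)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
for minibatch in gb.DataLoader(datapipe):
    ...  # training step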

graphbolt/src/cuda/extension/gpu_cache.cu

+8 −0

@@ -76,6 +76,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
   return std::make_tuple(values, missing_index, missing_keys);
 }
 
+c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> GpuCache::QueryAsync(
+    torch::Tensor keys) {
+  return async([=] {
+    auto [values, missing_index, missing_keys] = Query(keys);
+    return std::vector{values, missing_index, missing_keys};
+  });
+}
+
 void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
   TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
   TORCH_CHECK(
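The async helper declared in graphbolt/async.h runs the lambda off the calling thread and returns a Future whose wait() yields the captured result; the three tensors are packed into a std::vector so the binding can expose a single Future<std::vector<torch::Tensor>> type. A conceptual Python analogue (not the actual implementation) of what QueryAsync does:

from concurrent.futures import ThreadPoolExecutor

# Conceptual analogue only: dispatch the blocking query to a worker
# thread and hand back a future of its result, packed as a list of
# three tensors (values, missing_index, missing_keys).
_pool = ThreadPoolExecutor(max_workers=1)

def query_async(cache, keys):
    return _pool.submit(lambda: list(cache.query(keys)))

# future = query_async(cache, keys)
# values, missing_index, missing_keys = future.result()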

graphbolt/src/cuda/extension/gpu_cache.h

+4 −0

@@ -21,6 +21,7 @@
 #ifndef GRAPHBOLT_GPU_CACHE_H_
 #define GRAPHBOLT_GPU_CACHE_H_
 
+#include <graphbolt/async.h>
 #include <torch/custom_class.h>
 #include <torch/torch.h>
 
@@ -53,6 +54,9 @@ class GpuCache : public torch::CustomClassHolder {
   std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
       torch::Tensor keys);
 
+  c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(
+      torch::Tensor keys);
+
   void Replace(torch::Tensor keys, torch::Tensor values);
 
   static c10::intrusive_ptr<GpuCache> Create(

graphbolt/src/python_binding.cc

+1 −0

@@ -109,6 +109,7 @@ TORCH_LIBRARY(graphbolt, m) {
 #ifdef GRAPHBOLT_USE_CUDA
   m.class_<cuda::GpuCache>("GpuCache")
       .def("query", &cuda::GpuCache::Query)
+      .def("query_async", &cuda::GpuCache::QueryAsync)
       .def("replace", &cuda::GpuCache::Replace);
   m.def("gpu_cache", &cuda::GpuCache::Create);
   m.class_<cuda::GpuGraphCache>("GpuGraphCache")
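Once registered, query_async is callable on GpuCache instances from Python. A minimal sketch, assuming the gpu_cache factory takes a cache shape and dtype (the shape and key values here are illustrative, not from this commit):

import torch

# Hypothetical direct use of the TorchScript binding; in practice the
# GPUCache wrapper in python/dgl/graphbolt/impl/gpu_cache.py calls these.
cache = torch.ops.graphbolt.gpu_cache([8192, 16], torch.float32)
keys = torch.arange(128, device="cuda")

future = cache.query_async(keys)  # returns a Future without blocking
values, missing_index, missing_keys = future.wait()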

python/dgl/graphbolt/impl/gpu_cache.py

+27 −5

@@ -14,13 +14,16 @@ def __init__(self, cache_shape, dtype):
         self.total_miss = 0
         self.total_queries = 0
 
-    def query(self, keys):
+    def query(self, keys, async_op=False):
         """Queries the GPU cache.
 
         Parameters
         ----------
         keys : Tensor
             The keys to query the GPU cache with.
+        async_op: bool
+            Boolean indicating whether the call is asynchronous. If so, the
+            result can be obtained by calling wait on the returned future.
 
         Returns
         -------
@@ -29,10 +32,29 @@ def query(self, keys):
             values[missing_indices] corresponds to cache misses that should be
             filled by quering another source with missing_keys.
         """
-        self.total_queries += keys.shape[0]
-        values, missing_index, missing_keys = self._cache.query(keys)
-        self.total_miss += missing_keys.shape[0]
-        return values, missing_index, missing_keys
+
+        class _Waiter:
+            def __init__(self, gpu_cache, future):
+                self.gpu_cache = gpu_cache
+                self.future = future
+
+            def wait(self):
+                """Returns the stored value when invoked."""
+                gpu_cache = self.gpu_cache
+                values, missing_index, missing_keys = (
+                    self.future.wait() if async_op else self.future
+                )
+                # Ensure there is no leak.
+                self.gpu_cache = self.future = None
+
+                gpu_cache.total_queries += values.shape[0]
+                gpu_cache.total_miss += missing_keys.shape[0]
+                return values, missing_index, missing_keys
+
+        if async_op:
+            return _Waiter(self, self._cache.query_async(keys))
+        else:
+            return _Waiter(self, self._cache.query(keys)).wait()
 
     def replace(self, keys, values):
         """Inserts key-value pairs into the GPU cache using the Least-Recently

python/dgl/graphbolt/impl/gpu_cached_feature.py

+6 −2

@@ -114,7 +114,11 @@ def read_async(self, ids: torch.Tensor):
         >>> assert stage + 1 == feature.read_async_num_stages(ids.device)
         >>> result = future.wait()  # result contains the read values.
         """
-        values, missing_index, missing_keys = self._feature.query(ids)
+        future = self._feature.query(ids, async_op=True)
+
+        yield
+
+        values, missing_index, missing_keys = future.wait()
 
         fallback_reader = self._fallback_feature.read_async(missing_keys)
         fallback_num_stages = self._fallback_feature.read_async_num_stages(
@@ -175,7 +179,7 @@ def read_async_num_stages(self, ids_device: torch.device):
             The number of stages of the read_async operation.
         """
         assert ids_device.type == "cuda"
-        return self._fallback_feature.read_async_num_stages(ids_device)
+        return 1 + self._fallback_feature.read_async_num_stages(ids_device)
 
     def size(self):
         """Get the size of the feature.
