@@ -14,13 +14,16 @@ def __init__(self, cache_shape, dtype):
14
14
self .total_miss = 0
15
15
self .total_queries = 0
16
16
17
def query(self, keys, async_op=False):
    """Queries the GPU cache.

    Parameters
    ----------
    keys : Tensor
        The keys to query the GPU cache with.
    async_op: bool
        Boolean indicating whether the call is asynchronous. If so, the
        result can be obtained by calling wait on the returned future.

    Returns
    -------
    tuple or waitable object
        When ``async_op`` is False, the ``(values, missing_index,
        missing_keys)`` tuple produced by the underlying cache query.
        values[missing_index] corresponds to cache misses that should be
        filled by querying another source with missing_keys.
        When ``async_op`` is True, an object whose ``wait()`` method returns
        that same tuple. The hit/miss counters are updated only when
        ``wait()`` is called, and ``wait()`` may be called at most once.
    """
    if not async_op:
        # Synchronous fast path: the result is already available, so there
        # is no need to build a helper class and a waiter object per call.
        values, missing_index, missing_keys = self._cache.query(keys)
        self.total_queries += values.shape[0]
        self.total_miss += missing_keys.shape[0]
        return values, missing_index, missing_keys

    class _Waiter:
        """Defers the statistics update until the async result is read."""

        def __init__(self, gpu_cache, future):
            self.gpu_cache = gpu_cache
            self.future = future

        def wait(self):
            """Blocks on the pending query and returns its result tuple."""
            gpu_cache, future = self.gpu_cache, self.future
            if future is None:
                # Both references are dropped after the first wait(); fail
                # with an explicit message instead of an AttributeError
                # raised on None.
                raise RuntimeError(
                    "wait() has already been called on this query result."
                )
            values, missing_index, missing_keys = future.wait()
            # Ensure there is no leak: release the cache and the future as
            # soon as the result has been obtained.
            self.gpu_cache = self.future = None

            gpu_cache.total_queries += values.shape[0]
            gpu_cache.total_miss += missing_keys.shape[0]
            return values, missing_index, missing_keys

    return _Waiter(self, self._cache.query_async(keys))
36
58
37
59
def replace (self , keys , values ):
38
60
"""Inserts key-value pairs into the GPU cache using the Least-Recently
0 commit comments