1
1
/* *
2
- * Copyright (c) 2023 , GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
2
+ * Copyright (c) 2024 , GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
3
3
* All rights reserved.
4
4
*
5
5
* Licensed under the Apache License, Version 2.0 (the "License");
24
24
#include < torch/custom_class.h>
25
25
#include < torch/torch.h>
26
26
27
+ #include < limits>
28
+
27
29
#include " ./circular_queue.h"
28
30
29
31
namespace graphbolt {
30
32
namespace storage {
31
33
32
34
struct CacheKey {
33
35
/**
 * @brief Constructs a cache entry for `key` stored at `position`.
 *
 * The frequency starts at 0, the reader count at 0 and the writer count at 1.
 * NOTE(review): presumably the creator of the entry is its first writer and is
 * expected to release that reference later — confirm against the callers.
 *
 * @param key The key this entry represents.
 * @param position The position of the key inside the cache.
 */
CacheKey (int64_t key, int64_t position)
34
- : freq_(0 ), key_(key), position_in_cache_(position), reference_count_(1 ) {
36
+ : freq_(0 ),
37
+ key_ (key),
38
+ position_in_cache_(position),
39
+ read_reference_count_(0 ),
40
+ write_reference_count_(1 ) {
35
41
// Compile-time guard: the whole entry must occupy exactly two int64_t words.
static_assert (sizeof (CacheKey) == 2 * sizeof (int64_t ));
36
42
}
37
43
@@ -63,17 +69,30 @@ struct CacheKey {
63
69
return *this ;
64
70
}
65
71
72
/**
 * @brief Registers one more user of this entry.
 *
 * @tparam write If true the writer count is incremented, otherwise the reader
 * count is incremented.
 * @return A reference to this entry, so that calls can be chained.
 *
 * The value before the increment is compared against the maximum of int16_t
 * (writers) or int8_t (readers).
 * NOTE(review): the post-increment is part of the checked expression, so the
 * counter is modified even when the comparison is false. This assumes
 * TORCH_CHECK evaluates its condition and raises on false — verify.
 */
+ template <bool write>
66
73
CacheKey& StartUse () {
67
- ++reference_count_;
74
+ if constexpr (write ) {
75
+ TORCH_CHECK (
76
+ write_reference_count_++ < std::numeric_limits<int16_t >::max ());
77
+ } else {
78
+ TORCH_CHECK (read_reference_count_++ < std::numeric_limits<int8_t >::max ());
79
+ }
68
80
return *this ;
69
81
}
70
82
83
/**
 * @brief Releases one user of this entry.
 *
 * @tparam write If true the writer count is decremented, otherwise the reader
 * count is decremented.
 * @return A reference to this entry, so that calls can be chained.
 *
 * No lower-bound check is performed here; the counter is decremented
 * unconditionally.
 */
+ template <bool write>
71
84
CacheKey& EndUse () {
72
- --reference_count_;
85
+ if constexpr (write ) {
86
+ --write_reference_count_;
87
+ } else {
88
+ --read_reference_count_;
89
+ }
73
90
return *this ;
74
91
}
75
92
76
/**
 * @brief Tells whether the entry currently has any reader or writer.
 * @return true if the reader count or the writer count is nonzero.
 */
- bool InUse () { return reference_count_ > 0 ; }
93
+ bool InUse () const { return read_reference_count_ || write_reference_count_; }
94
+
95
/**
 * @brief Tells whether the entry currently has at least one writer.
 * @return true if the writer count is nonzero.
 */
+ bool BeingWritten () const { return write_reference_count_; }
77
96
78
97
friend std::ostream& operator <<(std::ostream& os, const CacheKey& key_ref) {
79
98
return os << ' (' << key_ref.key_ << " , " << key_ref.freq_ << " , "
@@ -83,8 +102,10 @@ struct CacheKey {
83
102
private:
84
103
int64_t freq_ : 3 ;
85
104
int64_t key_ : 61 ;
86
- int64_t position_in_cache_ : 48 ;
87
- int64_t reference_count_ : 16 ;
105
+ int64_t position_in_cache_ : 40 ;
106
+ int64_t read_reference_count_ : 8 ;
107
// Writes can be chained, so the write count is given more bits (16) than the read count (8).
108
+ int64_t write_reference_count_ : 16 ;
88
109
};
89
110
90
111
class BaseCachePolicy {
@@ -123,6 +144,12 @@ class BaseCachePolicy {
123
144
*/
124
145
virtual void ReadingCompleted (torch::Tensor keys) = 0;
125
146
147
+ /* *
148
+ * @brief A writer has finished writing these keys, so they can be evicted.
149
+ * @param keys The keys to unmark.
150
+ */
151
+ virtual void WritingCompleted (torch::Tensor keys) = 0;
152
+
126
153
protected:
127
154
template <typename CachePolicy>
128
155
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -131,8 +158,9 @@ class BaseCachePolicy {
131
158
template <typename CachePolicy>
132
159
static torch::Tensor ReplaceImpl (CachePolicy& policy, torch::Tensor keys);
133
160
134
- template <typename CachePolicy>
135
- static void ReadingCompletedImpl (CachePolicy& policy, torch::Tensor keys);
161
+ template <bool write, typename CachePolicy>
162
+ static void ReadingWritingCompletedImpl (
163
+ CachePolicy& policy, torch::Tensor keys);
136
164
};
137
165
138
166
/* *
@@ -170,6 +198,11 @@ class S3FifoCachePolicy : public BaseCachePolicy {
170
198
*/
171
199
void ReadingCompleted (torch::Tensor keys);
172
200
201
+ /* *
202
+ * @brief See BaseCachePolicy::WritingCompleted.
203
+ */
204
+ void WritingCompleted (torch::Tensor keys);
205
+
173
206
friend std::ostream& operator <<(
174
207
std::ostream& os, const S3FifoCachePolicy& policy) {
175
208
return os << " small_queue_: " << policy.small_queue_ << " \n "
@@ -178,11 +211,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {
178
211
<< " capacity_: " << policy.capacity_ << " \n " ;
179
212
}
180
213
214
/**
 * @brief Looks up `key` and, on a usable hit, marks the entry as in use.
 *
 * @tparam write Whether the caller accesses the entry as a writer (true) or
 * as a reader (false).
 * @param key The key to look up.
 * @return The value of getPos() on the entry after Increment() and
 * StartUse<write>() have been applied, or std::nullopt if the key is not in
 * key_to_cache_key_, or if this is a read (write == false) and the entry is
 * being written. In the nullopt case no counter of the entry is modified.
 *
 * NOTE(review): Increment() and getPos() are defined outside this view;
 * presumably they bump the frequency and return the cache position — confirm.
 */
+ template <bool write>
181
215
std::optional<int64_t > Read (int64_t key) {
182
216
auto it = key_to_cache_key_.find (key);
183
217
if (it != key_to_cache_key_.end ()) {
184
218
auto & cache_key = *it->second ;
185
- return cache_key.Increment ().StartUse ().getPos ();
219
// Readers must not observe an entry that a writer still holds.
+ if (write || !cache_key.BeingWritten ())
220
+ return cache_key.Increment ().StartUse <write >().getPos ();
186
221
}
187
222
return std::nullopt;
188
223
}
@@ -195,7 +230,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
195
230
return pos;
196
231
}
197
232
198
/**
 * @brief Releases one reader or writer reference held on the entry of `key`.
 *
 * @tparam write If true a writer reference is released via EndUse<true>(),
 * otherwise a reader reference via EndUse<false>().
 * @param key The key whose entry is released.
 *
 * NOTE(review): the entry is fetched with operator[] and dereferenced without
 * a presence check; presumably callers only pass keys that are currently in
 * key_to_cache_key_ — confirm.
 */
- void Unmark (int64_t key) { key_to_cache_key_[key]->EndUse (); }
233
+ template <bool write>
234
+ void Unmark (int64_t key) {
235
+ key_to_cache_key_[key]->EndUse <write >();
236
+ }
199
237
200
238
private:
201
239
int64_t EvictMainQueue () {
@@ -282,11 +320,18 @@ class SieveCachePolicy : public BaseCachePolicy {
282
320
*/
283
321
void ReadingCompleted (torch::Tensor keys);
284
322
323
+ /* *
324
+ * @brief See BaseCachePolicy::WritingCompleted.
325
+ */
326
+ void WritingCompleted (torch::Tensor keys);
327
+
328
/**
 * @brief Looks up `key` and, on a usable hit, marks the entry as in use.
 *
 * @tparam write Whether the caller accesses the entry as a writer (true) or
 * as a reader (false).
 * @param key The key to look up.
 * @return The value of getPos() on the entry after SetFreq() and
 * StartUse<write>() have been applied, or std::nullopt if the key is not in
 * key_to_cache_key_, or if this is a read (write == false) and the entry is
 * being written. In the nullopt case no counter of the entry is modified.
 *
 * NOTE(review): SetFreq() and getPos() are defined outside this view;
 * presumably they set the frequency flag and return the cache position —
 * confirm.
 */
+ template <bool write>
285
329
std::optional<int64_t > Read (int64_t key) {
286
330
auto it = key_to_cache_key_.find (key);
287
331
if (it != key_to_cache_key_.end ()) {
288
332
auto & cache_key = *it->second ;
289
- return cache_key.SetFreq ().StartUse ().getPos ();
333
// Readers must not observe an entry that a writer still holds.
+ if (write || !cache_key.BeingWritten ())
334
+ return cache_key.SetFreq ().StartUse <write >().getPos ();
290
335
}
291
336
return std::nullopt;
292
337
}
@@ -298,7 +343,10 @@ class SieveCachePolicy : public BaseCachePolicy {
298
343
return pos;
299
344
}
300
345
301
/**
 * @brief Releases one reader or writer reference held on the entry of `key`.
 *
 * @tparam write If true a writer reference is released via EndUse<true>(),
 * otherwise a reader reference via EndUse<false>().
 * @param key The key whose entry is released.
 *
 * NOTE(review): the entry is fetched with operator[] and dereferenced without
 * a presence check; presumably callers only pass keys that are currently in
 * key_to_cache_key_ — confirm.
 */
- void Unmark (int64_t key) { key_to_cache_key_[key]->EndUse (); }
346
+ template <bool write>
347
+ void Unmark (int64_t key) {
348
+ key_to_cache_key_[key]->EndUse <write >();
349
+ }
302
350
303
351
private:
304
352
int64_t Evict () {
@@ -362,14 +410,22 @@ class LruCachePolicy : public BaseCachePolicy {
362
410
*/
363
411
void ReadingCompleted (torch::Tensor keys);
364
412
413
+ /* *
414
+ * @brief See BaseCachePolicy::WritingCompleted.
415
+ */
416
+ void WritingCompleted (torch::Tensor keys);
417
+
418
/**
 * @brief Looks up `key` and, on a usable hit, moves its entry to the front of
 * queue_ and marks it as in use.
 *
 * @tparam write Whether the caller accesses the entry as a writer (true) or
 * as a reader (false).
 * @param key The key to look up.
 * @return The value of getPos() on the entry, or std::nullopt if the key is
 * not in key_to_cache_key_, or if this is a read (write == false) and the
 * entry is being written. In the nullopt case queue_ is left untouched.
 *
 * NOTE(review): getPos() is defined outside this view; presumably it returns
 * the cache position — confirm.
 */
+ template <bool write>
365
419
std::optional<int64_t > Read (int64_t key) {
366
420
auto it = key_to_cache_key_.find (key);
367
421
if (it != key_to_cache_key_.end ()) {
368
422
// Taken by value: the queue node it came from is erased below, and this copy
// (with its use count bumped) is what gets re-inserted at the front.
auto cache_key = *it->second ;
369
- queue_.erase (it->second );
370
- queue_.push_front (cache_key.StartUse ());
371
- it->second = queue_.begin ();
372
- return cache_key.getPos ();
423
+ if (write || !cache_key.BeingWritten ()) {
424
+ queue_.erase (it->second );
425
+ queue_.push_front (cache_key.StartUse <write >());
426
// Repoint the map at the freshly inserted front node.
+ it->second = queue_.begin ();
427
+ return cache_key.getPos ();
428
+ }
373
429
}
374
430
return std::nullopt;
375
431
}
@@ -381,7 +437,10 @@ class LruCachePolicy : public BaseCachePolicy {
381
437
return pos;
382
438
}
383
439
384
/**
 * @brief Releases one reader or writer reference held on the entry of `key`.
 *
 * @tparam write If true a writer reference is released via EndUse<true>(),
 * otherwise a reader reference via EndUse<false>().
 * @param key The key whose entry is released.
 *
 * NOTE(review): the entry is fetched with operator[] and dereferenced without
 * a presence check; presumably callers only pass keys that are currently in
 * key_to_cache_key_ — confirm.
 */
- void Unmark (int64_t key) { key_to_cache_key_[key]->EndUse (); }
440
+ template <bool write>
441
+ void Unmark (int64_t key) {
442
+ key_to_cache_key_[key]->EndUse <write >();
443
+ }
385
444
386
445
private:
387
446
int64_t Evict () {
@@ -443,11 +502,18 @@ class ClockCachePolicy : public BaseCachePolicy {
443
502
*/
444
503
void ReadingCompleted (torch::Tensor keys);
445
504
505
+ /* *
506
+ * @brief See BaseCachePolicy::WritingCompleted.
507
+ */
508
+ void WritingCompleted (torch::Tensor keys);
509
+
510
/**
 * @brief Looks up `key` and, on a usable hit, marks the entry as in use.
 *
 * @tparam write Whether the caller accesses the entry as a writer (true) or
 * as a reader (false).
 * @param key The key to look up.
 * @return The value of getPos() on the entry after SetFreq() and
 * StartUse<write>() have been applied, or std::nullopt if the key is not in
 * key_to_cache_key_, or if this is a read (write == false) and the entry is
 * being written. In the nullopt case no counter of the entry is modified.
 *
 * NOTE(review): SetFreq() and getPos() are defined outside this view;
 * presumably they set the frequency flag and return the cache position —
 * confirm.
 */
+ template <bool write>
446
511
std::optional<int64_t > Read (int64_t key) {
447
512
auto it = key_to_cache_key_.find (key);
448
513
if (it != key_to_cache_key_.end ()) {
449
514
auto & cache_key = *it->second ;
450
- return cache_key.SetFreq ().StartUse ().getPos ();
515
// Readers must not observe an entry that a writer still holds.
+ if (write || !cache_key.BeingWritten ())
516
+ return cache_key.SetFreq ().StartUse <write >().getPos ();
451
517
}
452
518
return std::nullopt;
453
519
}
@@ -458,7 +524,10 @@ class ClockCachePolicy : public BaseCachePolicy {
458
524
return pos;
459
525
}
460
526
461
/**
 * @brief Releases one reader or writer reference held on the entry of `key`.
 *
 * @tparam write If true a writer reference is released via EndUse<true>(),
 * otherwise a reader reference via EndUse<false>().
 * @param key The key whose entry is released.
 *
 * NOTE(review): the entry is fetched with operator[] and dereferenced without
 * a presence check; presumably callers only pass keys that are currently in
 * key_to_cache_key_ — confirm.
 */
- void Unmark (int64_t key) { key_to_cache_key_[key]->EndUse (); }
527
+ template <bool write>
528
+ void Unmark (int64_t key) {
529
+ key_to_cache_key_[key]->EndUse <write >();
530
+ }
462
531
463
532
private:
464
533
int64_t Evict () {
0 commit comments