Skip to content

Commit e975694

Browse files
mfbalin authored and lijialin03 committed
[GraphBolt] CachePolicy Writer lock for read_async correctness. (dmlc#7581)
1 parent 1f7c8a6 commit e975694

9 files changed, +165 -40 lines changed

graphbolt/src/cache_policy.cc

+25 -9
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
4747
auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
4848
for (int64_t i = 0; i < keys.size(0); i++) {
4949
const auto key = keys_ptr[i];
50-
auto pos = policy.Read(key);
50+
auto pos = policy.template Read<false>(key);
5151
if (pos.has_value()) {
5252
positions_ptr[found_cnt] = *pos;
5353
filtered_keys_ptr[found_cnt] = key;
@@ -78,7 +78,7 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
7878
position_set.reserve(keys.size(0));
7979
for (int64_t i = 0; i < keys.size(0); i++) {
8080
const auto key = keys_ptr[i];
81-
const auto pos_optional = policy.Read(key);
81+
const auto pos_optional = policy.template Read<true>(key);
8282
const auto pos = pos_optional ? *pos_optional : policy.Insert(key);
8383
positions_ptr[i] = pos;
8484
TORCH_CHECK(
@@ -91,14 +91,14 @@ torch::Tensor BaseCachePolicy::ReplaceImpl(
9191
return positions;
9292
}
9393

94-
template <typename CachePolicy>
95-
void BaseCachePolicy::ReadingCompletedImpl(
94+
template <bool write, typename CachePolicy>
95+
void BaseCachePolicy::ReadingWritingCompletedImpl(
9696
CachePolicy& policy, torch::Tensor keys) {
9797
AT_DISPATCH_INDEX_TYPES(
9898
keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
9999
auto keys_ptr = keys.data_ptr<index_t>();
100100
for (int64_t i = 0; i < keys.size(0); i++) {
101-
policy.Unmark(keys_ptr[i]);
101+
policy.template Unmark<write>(keys_ptr[i]);
102102
}
103103
}));
104104
}
@@ -125,7 +125,11 @@ torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
125125
}
126126

127127
void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
128-
ReadingCompletedImpl(*this, keys);
128+
ReadingWritingCompletedImpl<false>(*this, keys);
129+
}
130+
131+
void S3FifoCachePolicy::WritingCompleted(torch::Tensor keys) {
132+
ReadingWritingCompletedImpl<true>(*this, keys);
129133
}
130134

131135
SieveCachePolicy::SieveCachePolicy(int64_t capacity)
@@ -145,7 +149,11 @@ torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
145149
}
146150

147151
void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
148-
ReadingCompletedImpl(*this, keys);
152+
ReadingWritingCompletedImpl<false>(*this, keys);
153+
}
154+
155+
void SieveCachePolicy::WritingCompleted(torch::Tensor keys) {
156+
ReadingWritingCompletedImpl<true>(*this, keys);
149157
}
150158

151159
LruCachePolicy::LruCachePolicy(int64_t capacity)
@@ -164,7 +172,11 @@ torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
164172
}
165173

166174
void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
167-
ReadingCompletedImpl(*this, keys);
175+
ReadingWritingCompletedImpl<false>(*this, keys);
176+
}
177+
178+
void LruCachePolicy::WritingCompleted(torch::Tensor keys) {
179+
ReadingWritingCompletedImpl<true>(*this, keys);
168180
}
169181

170182
ClockCachePolicy::ClockCachePolicy(int64_t capacity)
@@ -183,7 +195,11 @@ torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
183195
}
184196

185197
void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
186-
ReadingCompletedImpl(*this, keys);
198+
ReadingWritingCompletedImpl<false>(*this, keys);
199+
}
200+
201+
void ClockCachePolicy::WritingCompleted(torch::Tensor keys) {
202+
ReadingWritingCompletedImpl<true>(*this, keys);
187203
}
188204

189205
} // namespace storage

graphbolt/src/cache_policy.h

+89 -20
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
2+
* Copyright (c) 2024, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
33
* All rights reserved.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,14 +24,20 @@
2424
#include <torch/custom_class.h>
2525
#include <torch/torch.h>
2626

27+
#include <limits>
28+
2729
#include "./circular_queue.h"
2830

2931
namespace graphbolt {
3032
namespace storage {
3133

3234
struct CacheKey {
3335
CacheKey(int64_t key, int64_t position)
34-
: freq_(0), key_(key), position_in_cache_(position), reference_count_(1) {
36+
: freq_(0),
37+
key_(key),
38+
position_in_cache_(position),
39+
read_reference_count_(0),
40+
write_reference_count_(1) {
3541
static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
3642
}
3743

@@ -63,17 +69,30 @@ struct CacheKey {
6369
return *this;
6470
}
6571

72+
template <bool write>
6673
CacheKey& StartUse() {
67-
++reference_count_;
74+
if constexpr (write) {
75+
TORCH_CHECK(
76+
write_reference_count_++ < std::numeric_limits<int16_t>::max());
77+
} else {
78+
TORCH_CHECK(read_reference_count_++ < std::numeric_limits<int8_t>::max());
79+
}
6880
return *this;
6981
}
7082

83+
template <bool write>
7184
CacheKey& EndUse() {
72-
--reference_count_;
85+
if constexpr (write) {
86+
--write_reference_count_;
87+
} else {
88+
--read_reference_count_;
89+
}
7390
return *this;
7491
}
7592

76-
bool InUse() { return reference_count_ > 0; }
93+
bool InUse() const { return read_reference_count_ || write_reference_count_; }
94+
95+
bool BeingWritten() const { return write_reference_count_; }
7796

7897
friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
7998
return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
@@ -83,8 +102,10 @@ struct CacheKey {
83102
private:
84103
int64_t freq_ : 3;
85104
int64_t key_ : 61;
86-
int64_t position_in_cache_ : 48;
87-
int64_t reference_count_ : 16;
105+
int64_t position_in_cache_ : 40;
106+
int64_t read_reference_count_ : 8;
107+
// There could be a chain of writes so it is better to have larger bit count.
108+
int64_t write_reference_count_ : 16;
88109
};
89110

90111
class BaseCachePolicy {
@@ -123,6 +144,12 @@ class BaseCachePolicy {
123144
*/
124145
virtual void ReadingCompleted(torch::Tensor keys) = 0;
125146

147+
/**
148+
* @brief A writer has finished writing these keys, so they can be evicted.
149+
* @param keys The keys to unmark.
150+
*/
151+
virtual void WritingCompleted(torch::Tensor keys) = 0;
152+
126153
protected:
127154
template <typename CachePolicy>
128155
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -131,8 +158,9 @@ class BaseCachePolicy {
131158
template <typename CachePolicy>
132159
static torch::Tensor ReplaceImpl(CachePolicy& policy, torch::Tensor keys);
133160

134-
template <typename CachePolicy>
135-
static void ReadingCompletedImpl(CachePolicy& policy, torch::Tensor keys);
161+
template <bool write, typename CachePolicy>
162+
static void ReadingWritingCompletedImpl(
163+
CachePolicy& policy, torch::Tensor keys);
136164
};
137165

138166
/**
@@ -170,6 +198,11 @@ class S3FifoCachePolicy : public BaseCachePolicy {
170198
*/
171199
void ReadingCompleted(torch::Tensor keys);
172200

201+
/**
202+
* @brief See BaseCachePolicy::WritingCompleted.
203+
*/
204+
void WritingCompleted(torch::Tensor keys);
205+
173206
friend std::ostream& operator<<(
174207
std::ostream& os, const S3FifoCachePolicy& policy) {
175208
return os << "small_queue_: " << policy.small_queue_ << "\n"
@@ -178,11 +211,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {
178211
<< "capacity_: " << policy.capacity_ << "\n";
179212
}
180213

214+
template <bool write>
181215
std::optional<int64_t> Read(int64_t key) {
182216
auto it = key_to_cache_key_.find(key);
183217
if (it != key_to_cache_key_.end()) {
184218
auto& cache_key = *it->second;
185-
return cache_key.Increment().StartUse().getPos();
219+
if (write || !cache_key.BeingWritten())
220+
return cache_key.Increment().StartUse<write>().getPos();
186221
}
187222
return std::nullopt;
188223
}
@@ -195,7 +230,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
195230
return pos;
196231
}
197232

198-
void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
233+
template <bool write>
234+
void Unmark(int64_t key) {
235+
key_to_cache_key_[key]->EndUse<write>();
236+
}
199237

200238
private:
201239
int64_t EvictMainQueue() {
@@ -282,11 +320,18 @@ class SieveCachePolicy : public BaseCachePolicy {
282320
*/
283321
void ReadingCompleted(torch::Tensor keys);
284322

323+
/**
324+
* @brief See BaseCachePolicy::WritingCompleted.
325+
*/
326+
void WritingCompleted(torch::Tensor keys);
327+
328+
template <bool write>
285329
std::optional<int64_t> Read(int64_t key) {
286330
auto it = key_to_cache_key_.find(key);
287331
if (it != key_to_cache_key_.end()) {
288332
auto& cache_key = *it->second;
289-
return cache_key.SetFreq().StartUse().getPos();
333+
if (write || !cache_key.BeingWritten())
334+
return cache_key.SetFreq().StartUse<write>().getPos();
290335
}
291336
return std::nullopt;
292337
}
@@ -298,7 +343,10 @@ class SieveCachePolicy : public BaseCachePolicy {
298343
return pos;
299344
}
300345

301-
void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
346+
template <bool write>
347+
void Unmark(int64_t key) {
348+
key_to_cache_key_[key]->EndUse<write>();
349+
}
302350

303351
private:
304352
int64_t Evict() {
@@ -362,14 +410,22 @@ class LruCachePolicy : public BaseCachePolicy {
362410
*/
363411
void ReadingCompleted(torch::Tensor keys);
364412

413+
/**
414+
* @brief See BaseCachePolicy::WritingCompleted.
415+
*/
416+
void WritingCompleted(torch::Tensor keys);
417+
418+
template <bool write>
365419
std::optional<int64_t> Read(int64_t key) {
366420
auto it = key_to_cache_key_.find(key);
367421
if (it != key_to_cache_key_.end()) {
368422
auto cache_key = *it->second;
369-
queue_.erase(it->second);
370-
queue_.push_front(cache_key.StartUse());
371-
it->second = queue_.begin();
372-
return cache_key.getPos();
423+
if (write || !cache_key.BeingWritten()) {
424+
queue_.erase(it->second);
425+
queue_.push_front(cache_key.StartUse<write>());
426+
it->second = queue_.begin();
427+
return cache_key.getPos();
428+
}
373429
}
374430
return std::nullopt;
375431
}
@@ -381,7 +437,10 @@ class LruCachePolicy : public BaseCachePolicy {
381437
return pos;
382438
}
383439

384-
void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
440+
template <bool write>
441+
void Unmark(int64_t key) {
442+
key_to_cache_key_[key]->EndUse<write>();
443+
}
385444

386445
private:
387446
int64_t Evict() {
@@ -443,11 +502,18 @@ class ClockCachePolicy : public BaseCachePolicy {
443502
*/
444503
void ReadingCompleted(torch::Tensor keys);
445504

505+
/**
506+
* @brief See BaseCachePolicy::WritingCompleted.
507+
*/
508+
void WritingCompleted(torch::Tensor keys);
509+
510+
template <bool write>
446511
std::optional<int64_t> Read(int64_t key) {
447512
auto it = key_to_cache_key_.find(key);
448513
if (it != key_to_cache_key_.end()) {
449514
auto& cache_key = *it->second;
450-
return cache_key.SetFreq().StartUse().getPos();
515+
if (write || !cache_key.BeingWritten())
516+
return cache_key.SetFreq().StartUse<write>().getPos();
451517
}
452518
return std::nullopt;
453519
}
@@ -458,7 +524,10 @@ class ClockCachePolicy : public BaseCachePolicy {
458524
return pos;
459525
}
460526

461-
void Unmark(int64_t key) { key_to_cache_key_[key]->EndUse(); }
527+
template <bool write>
528+
void Unmark(int64_t key) {
529+
key_to_cache_key_[key]->EndUse<write>();
530+
}
462531

463532
private:
464533
int64_t Evict() {

graphbolt/src/feature_cache.cc

+1
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
7272
auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
7373
const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());
7474
const auto positions_ptr = positions.data_ptr<int64_t>();
75+
std::lock_guard lock(mtx_);
7576
torch::parallel_for(
7677
0, positions.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {
7778
for (int64_t i = begin; i < end; i++) {

graphbolt/src/feature_cache.h

+2
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,8 @@ struct FeatureCache : public torch::CustomClassHolder {
8989

9090
private:
9191
torch::Tensor tensor_;
92+
// Protects writes only as reads are guaranteed to be safe.
93+
std::mutex mtx_;
9294
};
9395

9496
} // namespace storage

graphbolt/src/partitioned_cache_policy.cc

+23 -3
Original file line number | Diff line number | Diff line change
@@ -242,10 +242,14 @@ c10::intrusive_ptr<Future<torch::Tensor>> PartitionedCachePolicy::ReplaceAsync(
242242
return async([=] { return Replace(keys); });
243243
}
244244

245-
void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
245+
template <bool write>
246+
void PartitionedCachePolicy::ReadingWritingCompletedImpl(torch::Tensor keys) {
246247
if (policies_.size() == 1) {
247248
std::lock_guard lock(mtx_);
248-
policies_[0]->ReadingCompleted(keys);
249+
if constexpr (write)
250+
policies_[0]->WritingCompleted(keys);
251+
else
252+
policies_[0]->ReadingCompleted(keys);
249253
return;
250254
}
251255
torch::Tensor offsets, indices, permuted_keys;
@@ -257,15 +261,31 @@ void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
257261
const auto tid = begin;
258262
begin = offsets_ptr[tid];
259263
end = offsets_ptr[tid + 1];
260-
policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
264+
if constexpr (write)
265+
policies_.at(tid)->WritingCompleted(permuted_keys.slice(0, begin, end));
266+
else
267+
policies_.at(tid)->ReadingCompleted(permuted_keys.slice(0, begin, end));
261268
});
262269
}
263270

271+
void PartitionedCachePolicy::ReadingCompleted(torch::Tensor keys) {
272+
ReadingWritingCompletedImpl<false>(keys);
273+
}
274+
275+
void PartitionedCachePolicy::WritingCompleted(torch::Tensor keys) {
276+
ReadingWritingCompletedImpl<true>(keys);
277+
}
278+
264279
c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::ReadingCompletedAsync(
265280
torch::Tensor keys) {
266281
return async([=] { return ReadingCompleted(keys); });
267282
}
268283

284+
c10::intrusive_ptr<Future<void>> PartitionedCachePolicy::WritingCompletedAsync(
285+
torch::Tensor keys) {
286+
return async([=] { return WritingCompleted(keys); });
287+
}
288+
269289
template <typename CachePolicy>
270290
c10::intrusive_ptr<PartitionedCachePolicy> PartitionedCachePolicy::Create(
271291
int64_t capacity, int64_t num_partitions) {

0 commit comments

Comments
 (0)