Skip to content

Commit 2fc7731

Browse files
EddyLXJ authored and facebook-github-bot committed
Feature score eviction backend support (#4681)
Summary: X-link: facebookresearch/FBGEMM#1707 ## Context We need a new eviction policy for large embeddings that have a high id growth rate. Feature score eviction is based on the engagement rate of an id instead of only time or a counter. This will help the model keep all relatively important ids during eviction. ## Detail * New Eviction Strategy: BY_FEATURE_SCORE Added a new eviction trigger strategy BY_FEATURE_SCORE in the eviction config and logic. This strategy uses feature scores derived from engagement rates to decide which ids to evict. * FeatureScoreBasedEvict Class Implements the feature-score-based eviction logic. Maintains buckets of feature scores per shard and table to compute eviction thresholds. * Supports a dry-run mode to calculate thresholds before actual eviction. Eviction decisions are based on thresholds computed from feature score distributions. Supports decay of feature score statistics over time. * Async Metadata Update API Added a set_kv_zch_eviction_metadata_async method to update feature score metadata asynchronously in the KV store. This method shards the input indices and engagement rates and updates the feature score statistics in parallel. * Dry Run Eviction Mode Introduced a dry-run mode that simulates eviction rounds to compute thresholds without actually evicting. Dry-run results are used to finalize thresholds for real eviction rounds. Differential Revision: D78138679
1 parent 3574c54 commit 2fc7731

13 files changed

+1087
-107
lines changed

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,111 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
633633
});
634634
}
635635

636+
/// Update feature scores metadata into kvstore.
637+
folly::SemiFuture<std::vector<folly::Unit>>
638+
set_kv_zch_eviction_metadata_async(
639+
at::Tensor indices,
640+
at::Tensor count,
641+
at::Tensor engege_rates) override {
642+
if (!feature_evict_ || !feature_evict_config_.has_value() ||
643+
feature_evict_config_.value()->trigger_mode_ ==
644+
EvictTriggerMode::DISABLED) {
645+
// featre eviction is disabled
646+
return folly::makeSemiFuture(std::vector<folly::Unit>());
647+
}
648+
649+
CHECK_EQ(engege_rates.scalar_type(), at::ScalarType::Float);
650+
auto* feature_score_evict =
651+
dynamic_cast<FeatureScoreBasedEvict<weight_type>*>(
652+
feature_evict_.get());
653+
654+
if (feature_score_evict == nullptr) {
655+
// Not a feature score based eviction
656+
return folly::makeSemiFuture(std::vector<folly::Unit>());
657+
}
658+
pause_ongoing_eviction();
659+
std::vector<folly::Future<int64_t>> futures;
660+
auto shardid_to_indexes = shard_input(indices, count);
661+
for (auto iter = shardid_to_indexes.begin();
662+
iter != shardid_to_indexes.end();
663+
iter++) {
664+
const auto shard_id = iter->first;
665+
const auto indexes = iter->second;
666+
auto f =
667+
folly::via(executor_.get())
668+
.thenValue([this,
669+
shard_id,
670+
indexes,
671+
indices,
672+
engege_rates,
673+
feature_score_evict](folly::Unit) {
674+
int64_t updated_id_count = 0;
675+
FBGEMM_DISPATCH_INTEGRAL_TYPES(
676+
indices.scalar_type(),
677+
"dram_set_kv_feature_score_metadata",
678+
[this,
679+
shard_id,
680+
indexes,
681+
indices,
682+
engege_rates,
683+
feature_score_evict,
684+
&updated_id_count] {
685+
using index_t = scalar_t;
686+
CHECK(indices.is_contiguous());
687+
CHECK(engege_rates.is_contiguous());
688+
CHECK_EQ(indices.size(0), engege_rates.size(0));
689+
auto indices_data_ptr = indices.data_ptr<index_t>();
690+
auto engage_rate_ptr = engege_rates.data_ptr<float>();
691+
int64_t stride = 2;
692+
{
693+
auto wlmap = kv_store_.by(shard_id).wlock();
694+
auto* pool = kv_store_.pool_by(shard_id);
695+
696+
for (auto index_iter = indexes.begin();
697+
index_iter != indexes.end();
698+
index_iter++) {
699+
const auto& id_index = *index_iter;
700+
auto id = int64_t(indices_data_ptr[id_index]);
701+
float engege_rate =
702+
float(engage_rate_ptr[id_index * stride + 0]);
703+
// use mempool
704+
weight_type* block = nullptr;
705+
auto it = wlmap->find(id);
706+
if (it != wlmap->end()) {
707+
block = it->second;
708+
} else {
709+
// Key doesn't exist, allocate new block and
710+
// insert.
711+
block = pool->template allocate_t<weight_type>();
712+
FixedBlockPool::set_key(block, id);
713+
wlmap->insert({id, block});
714+
}
715+
716+
feature_score_evict->update_feature_score_statistics(
717+
block, engege_rate);
718+
updated_id_count++;
719+
}
720+
}
721+
});
722+
return updated_id_count;
723+
});
724+
futures.push_back(std::move(f));
725+
}
726+
return folly::collect(std::move(futures))
727+
.via(executor_.get())
728+
.thenValue([this](const std::vector<int64_t>& results) {
729+
resume_ongoing_eviction();
730+
int total_updated_ids = 0;
731+
for (const auto& result : results) {
732+
total_updated_ids += result;
733+
}
734+
LOG(INFO)
735+
<< "[DRAM KV][Feature Score Eviction]Total updated IDs across all shards: "
736+
<< total_updated_ids;
737+
return std::vector<folly::Unit>(results.size());
738+
});
739+
}
740+
636741
/// Get embeddings from kvstore.
637742
///
638743
/// @param indices The 1D embedding index tensor, should skip on negative

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
7676
at::Tensor count,
7777
int64_t timestep,
7878
bool is_bwd) {
79-
return impl_->set_cuda(indices, weights, count, timestep);
79+
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
8080
}
8181

8282
void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
@@ -147,7 +147,8 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
147147
at::Tensor evicted_counts,
148148
at::Tensor processed_counts,
149149
at::Tensor full_duration_ms,
150-
at::Tensor exec_duration_ms) {
150+
at::Tensor exec_duration_ms,
151+
at::Tensor dry_run_exec_duration_ms) {
151152
auto metrics = impl_->get_feature_evict_metric();
152153
if (metrics.has_value()) {
153154
evicted_counts.copy_(
@@ -158,6 +159,8 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
158159
metrics.value().full_duration_ms); // full duration (Long)
159160
exec_duration_ms.copy_(
160161
metrics.value().exec_duration_ms); // exec duration (Long)
162+
dry_run_exec_duration_ms.copy_(
163+
metrics.value().dry_run_exec_duration_ms); // dry run exec duration
161164
}
162165
}
163166

@@ -169,6 +172,13 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
169172
impl_->set_backend_return_whole_row(backend_return_whole_row);
170173
}
171174

175+
/// Forwards feature score metadata (ids, valid-entry count, and
/// engagement/show counts) to the underlying DRAM KV cache implementation.
void set_feature_score_metadata_cuda(
176+
at::Tensor indices,
177+
at::Tensor count,
178+
at::Tensor engage_show_count) {
179+
impl_->set_feature_score_metadata_cuda(indices, count, engage_show_count);
180+
}
181+
172182
private:
173183
// friend class EmbeddingRocksDBWrapper;
174184
friend class ssd::KVTensorWrapper;

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,16 @@ void DramKVEmbeddingInferenceWrapper::init(
5050
std::nullopt /* ttls_in_mins */,
5151
std::nullopt /* counter_thresholds */,
5252
std::nullopt /* counter_decay_rates */,
53+
std::nullopt /* feature_score_counter_decay_rates */,
54+
std::nullopt /* max_training_id_num_per_table */,
55+
std::nullopt /* target_eviction_percent_per_table */,
5356
std::nullopt /* l2_weight_thresholds */,
5457
std::nullopt /* embedding_dims */,
58+
std::nullopt /* threshold_calculation_bucket_stride */,
59+
std::nullopt /* threshold_calculation_bucket_num */,
5560
0 /* interval for insufficient eviction s*/,
56-
0 /* interval for sufficient eviction s*/),
61+
0 /* interval for sufficient eviction s*/,
62+
0 /* interval_for_feature_statistics_decay_s_*/),
5763
num_shards_ /* num_shards */,
5864
num_shards_ /* num_threads */,
5965
8 /* row_storage_bitwidth */,

0 commit comments

Comments
 (0)