@@ -675,18 +675,25 @@ def __init__(
675675 if self .kv_zch_params .eviction_policy .eviction_mem_threshold_gb
676676 else self .l2_cache_size
677677 )
678+ # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
678679 eviction_config = torch .classes .fbgemm .FeatureEvictConfig (
679680 self .kv_zch_params .eviction_policy .eviction_trigger_mode , # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
680- self .kv_zch_params .eviction_policy .eviction_strategy , # evict_trigger_strategy: 0: timestamp, 1: counter (feature score) , 2: counter (feature score) + timestamp, 3: feature l2 norm
681+ self .kv_zch_params .eviction_policy .eviction_strategy , # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
681682 self .kv_zch_params .eviction_policy .eviction_step_intervals , # trigger_step_interval if trigger mode is iteration
682683 eviction_mem_threshold_gb , # mem_util_threshold_in_GB if trigger mode is mem_util
683684 self .kv_zch_params .eviction_policy .ttls_in_mins , # ttls_in_mins for each table if eviction strategy is timestamp
684- self .kv_zch_params .eviction_policy .counter_thresholds , # counter_thresholds for each table if eviction strategy is feature score
685- self .kv_zch_params .eviction_policy .counter_decay_rates , # counter_decay_rates for each table if eviction strategy is feature score
685+ self .kv_zch_params .eviction_policy .counter_thresholds , # counter_thresholds for each table if eviction strategy is counter
686+ self .kv_zch_params .eviction_policy .counter_decay_rates , # counter_decay_rates for each table if eviction strategy is counter
687+ self .kv_zch_params .eviction_policy .feature_score_counter_decay_rates , # feature_score_counter_decay_rates for each table if eviction strategy is feature score
688+ self .kv_zch_params .eviction_policy .max_training_id_num_per_table , # max_training_id_num for each table
689+ self .kv_zch_params .eviction_policy .target_eviction_percent_per_table , # target_eviction_percent for each table
686690 self .kv_zch_params .eviction_policy .l2_weight_thresholds , # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
687691 table_dims .tolist () if table_dims is not None else None ,
692+ self .kv_zch_params .eviction_policy .threshold_calculation_bucket_stride , # threshold_calculation_bucket_stride if eviction strategy is feature score
693+ self .kv_zch_params .eviction_policy .threshold_calculation_bucket_num , # threshold_calculation_bucket_num if eviction strategy is feature score
688694 self .kv_zch_params .eviction_policy .interval_for_insufficient_eviction_s ,
689695 self .kv_zch_params .eviction_policy .interval_for_sufficient_eviction_s ,
696+ self .kv_zch_params .eviction_policy .interval_for_feature_statistics_decay_s ,
690697 )
691698 self ._ssd_db = torch .classes .fbgemm .DramKVEmbeddingCacheWrapper (
692699 self .cache_row_dim ,
@@ -1018,6 +1025,9 @@ def __init__(
10181025 self .stats_reporter .register_stats (
10191026 "eviction.feature_table.exec_duration_ms"
10201027 )
1028+ self .stats_reporter .register_stats (
1029+ "eviction.feature_table.dry_run_exec_duration_ms"
1030+ )
10211031 self .stats_reporter .register_stats (
10221032 "eviction.feature_table.exec_div_full_duration_rate"
10231033 )
@@ -1605,6 +1615,7 @@ def prefetch(
16051615 self ,
16061616 indices : Tensor ,
16071617 offsets : Tensor ,
1618+ weights : Optional [Tensor ] = None , # todo: need to update caller
16081619 forward_stream : Optional [torch .cuda .Stream ] = None ,
16091620 batch_size_per_feature_per_rank : Optional [List [List [int ]]] = None ,
16101621 ) -> None :
@@ -1630,6 +1641,7 @@ def prefetch(
16301641 self ._prefetch (
16311642 indices ,
16321643 offsets ,
1644+ weights ,
16331645 vbe_metadata ,
16341646 forward_stream ,
16351647 )
@@ -1638,6 +1650,7 @@ def _prefetch( # noqa C901
16381650 self ,
16391651 indices : Tensor ,
16401652 offsets : Tensor ,
1653+ weights : Optional [Tensor ] = None ,
16411654 vbe_metadata : Optional [invokers .lookup_args .VBEMetadata ] = None ,
16421655 forward_stream : Optional [torch .cuda .Stream ] = None ,
16431656 ) -> None :
@@ -1665,6 +1678,12 @@ def _prefetch( # noqa C901
16651678
16661679 self .timestep += 1
16671680 self .timesteps_prefetched .append (self .timestep )
1681+ if self .backend_type == BackendType .DRAM and weights is not None :
1682+ # DRAM backend supports feature score eviction, if there is weights available
1683+ # in the prefetch call, we will set metadata for feature score eviction asynchronously
1684+ cloned_linear_cache_indices = linear_cache_indices .clone ()
1685+ else :
1686+ cloned_linear_cache_indices = None
16681687
16691688 # Lookup and virtually insert indices into L1. After this operator,
16701689 # we know:
@@ -2022,6 +2041,16 @@ def _prefetch( # noqa C901
20222041 is_bwd = False ,
20232042 )
20242043
2044+ if self .backend_type == BackendType .DRAM and weights is not None :
2045+ # Write feature score metadata to DRAM
2046+ self .record_function_via_dummy_profile (
2047+ "## ssd_write_feature_score_metadata ##" ,
2048+ self .ssd_db .set_feature_score_metadata_cuda ,
2049+ cloned_linear_cache_indices .cpu (),
2050+ torch .tensor ([weights .shape [0 ]], device = "cpu" , dtype = torch .long ),
2051+ weights .cpu ().view (torch .float32 ).view (- 1 , 2 ),
2052+ )
2053+
20252054 # Generate row addresses (pointing to either L1 or the current
20262055 # iteration's scratch pad)
20272056 with record_function ("## ssd_generate_row_addrs ##" ):
@@ -2164,6 +2193,7 @@ def forward(
21642193 self ,
21652194 indices : Tensor ,
21662195 offsets : Tensor ,
2196+ weights : Optional [Tensor ] = None ,
21672197 per_sample_weights : Optional [Tensor ] = None ,
21682198 feature_requires_grad : Optional [Tensor ] = None ,
21692199 batch_size_per_feature_per_rank : Optional [List [List [int ]]] = None ,
@@ -2185,7 +2215,7 @@ def forward(
21852215 context = self .step ,
21862216 stream = self .ssd_eviction_stream ,
21872217 ):
2188- self ._prefetch (indices , offsets , vbe_metadata )
2218+ self ._prefetch (indices , offsets , weights , vbe_metadata )
21892219
21902220 assert len (self .ssd_prefetch_data ) > 0
21912221
@@ -3745,8 +3775,13 @@ def _report_eviction_stats(self) -> None:
37453775 processed_counts = torch .zeros (T , dtype = torch .int64 )
37463776 full_duration_ms = torch .tensor (0 , dtype = torch .int64 )
37473777 exec_duration_ms = torch .tensor (0 , dtype = torch .int64 )
3778+ dry_run_exec_duration_ms = torch .tensor (0 , dtype = torch .int64 )
37483779 self .ssd_db .get_feature_evict_metric (
3749- evicted_counts , processed_counts , full_duration_ms , exec_duration_ms
3780+ evicted_counts ,
3781+ processed_counts ,
3782+ full_duration_ms ,
3783+ exec_duration_ms ,
3784+ dry_run_exec_duration_ms ,
37503785 )
37513786
37523787 stats_reporter .report_data_amount (
@@ -3798,6 +3833,12 @@ def _report_eviction_stats(self) -> None:
37983833 duration_ms = exec_duration_ms .item (),
37993834 time_unit = "ms" ,
38003835 )
3836+ stats_reporter .report_duration (
3837+ iteration_step = self .step ,
3838+ event_name = "eviction.feature_table.dry_run_exec_duration_ms" ,
3839+ duration_ms = dry_run_exec_duration_ms .item (),
3840+ time_unit = "ms" ,
3841+ )
38013842 if full_duration_ms .item () != 0 :
38023843 stats_reporter .report_data_amount (
38033844 iteration_step = self .step ,
0 commit comments