@@ -675,18 +675,25 @@ def __init__(
675675                    if  self .kv_zch_params .eviction_policy .eviction_mem_threshold_gb 
676676                    else  self .l2_cache_size 
677677                )
678+                 # Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters. 
678679                eviction_config  =  torch .classes .fbgemm .FeatureEvictConfig (
679680                    self .kv_zch_params .eviction_policy .eviction_trigger_mode ,  # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual 
680-                     self .kv_zch_params .eviction_policy .eviction_strategy ,  # evict_trigger_strategy: 0: timestamp, 1: counter (feature score) , 2: counter (feature score)  + timestamp, 3: feature l2 norm 
681+                     self .kv_zch_params .eviction_policy .eviction_strategy ,  # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold, 5: feature score  
681682                    self .kv_zch_params .eviction_policy .eviction_step_intervals ,  # trigger_step_interval if trigger mode is iteration 
682683                    eviction_mem_threshold_gb ,  # mem_util_threshold_in_GB if trigger mode is mem_util 
683684                    self .kv_zch_params .eviction_policy .ttls_in_mins ,  # ttls_in_mins for each table if eviction strategy is timestamp 
684-                     self .kv_zch_params .eviction_policy .counter_thresholds ,  # counter_thresholds for each table if eviction strategy is feature score 
685-                     self .kv_zch_params .eviction_policy .counter_decay_rates ,  # counter_decay_rates for each table if eviction strategy is feature score 
685+                     self .kv_zch_params .eviction_policy .counter_thresholds ,  # counter_thresholds for each table if eviction strategy is counter 
686+                     self .kv_zch_params .eviction_policy .counter_decay_rates ,  # counter_decay_rates for each table if eviction strategy is counter 
687+                     self .kv_zch_params .eviction_policy .feature_score_counter_decay_rates ,  # feature_score_counter_decay_rates for each table if eviction strategy is feature score 
688+                     self .kv_zch_params .eviction_policy .max_training_id_num_per_table ,  # max_training_id_num for each table 
689+                     self .kv_zch_params .eviction_policy .target_eviction_percent_per_table ,  # target_eviction_percent for each table 
686690                    self .kv_zch_params .eviction_policy .l2_weight_thresholds ,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm 
687691                    table_dims .tolist () if  table_dims  is  not None  else  None ,
692+                     self .kv_zch_params .eviction_policy .threshold_calculation_bucket_stride ,  # threshold_calculation_bucket_stride if eviction strategy is feature score 
693+                     self .kv_zch_params .eviction_policy .threshold_calculation_bucket_num ,  # threshold_calculation_bucket_num if eviction strategy is feature score 
688694                    self .kv_zch_params .eviction_policy .interval_for_insufficient_eviction_s ,
689695                    self .kv_zch_params .eviction_policy .interval_for_sufficient_eviction_s ,
696+                     self .kv_zch_params .eviction_policy .interval_for_feature_statistics_decay_s ,
690697                )
691698            self ._ssd_db  =  torch .classes .fbgemm .DramKVEmbeddingCacheWrapper (
692699                self .cache_row_dim ,
@@ -1018,6 +1025,9 @@ def __init__(
10181025            self .stats_reporter .register_stats (
10191026                "eviction.feature_table.exec_duration_ms" 
10201027            )
1028+             self .stats_reporter .register_stats (
1029+                 "eviction.feature_table.dry_run_exec_duration_ms" 
1030+             )
10211031            self .stats_reporter .register_stats (
10221032                "eviction.feature_table.exec_div_full_duration_rate" 
10231033            )
@@ -1605,6 +1615,7 @@ def prefetch(
16051615        self ,
16061616        indices : Tensor ,
16071617        offsets : Tensor ,
1618+         weights : Optional [Tensor ] =  None ,  # TODO: update callers to pass weights 
16081619        forward_stream : Optional [torch .cuda .Stream ] =  None ,
16091620        batch_size_per_feature_per_rank : Optional [List [List [int ]]] =  None ,
16101621    ) ->  None :
@@ -1630,6 +1641,7 @@ def prefetch(
16301641        self ._prefetch (
16311642            indices ,
16321643            offsets ,
1644+             weights ,
16331645            vbe_metadata ,
16341646            forward_stream ,
16351647        )
@@ -1638,6 +1650,7 @@ def _prefetch(  # noqa C901
16381650        self ,
16391651        indices : Tensor ,
16401652        offsets : Tensor ,
1653+         weights : Optional [Tensor ] =  None ,
16411654        vbe_metadata : Optional [invokers .lookup_args .VBEMetadata ] =  None ,
16421655        forward_stream : Optional [torch .cuda .Stream ] =  None ,
16431656    ) ->  None :
@@ -1665,6 +1678,12 @@ def _prefetch(  # noqa C901
16651678
16661679            self .timestep  +=  1 
16671680            self .timesteps_prefetched .append (self .timestep )
1681+             if  self .backend_type  ==  BackendType .DRAM  and  weights  is  not None :
1682+                 # DRAM backend supports feature score eviction; if weights are available 
1683+                 # in the prefetch call, we will set metadata for feature score eviction asynchronously 
1684+                 cloned_linear_cache_indices  =  linear_cache_indices .clone ()
1685+             else :
1686+                 cloned_linear_cache_indices  =  None 
16681687
16691688            # Lookup and virtually insert indices into L1. After this operator, 
16701689            # we know: 
@@ -2022,6 +2041,16 @@ def _prefetch(  # noqa C901
20222041                    is_bwd = False ,
20232042                )
20242043
2044+             if  self .backend_type  ==  BackendType .DRAM  and  weights  is  not None :
2045+                 # Write feature score metadata to DRAM 
2046+                 self .record_function_via_dummy_profile (
2047+                     "## ssd_write_feature_score_metadata ##" ,
2048+                     self .ssd_db .set_feature_score_metadata_cuda ,
2049+                     cloned_linear_cache_indices .cpu (),
2050+                     torch .tensor ([weights .shape [0 ]], device = "cpu" , dtype = torch .long ),
2051+                     weights .cpu ().view (torch .float32 ).view (- 1 , 2 ),
2052+                 )
2053+ 
20252054            # Generate row addresses (pointing to either L1 or the current 
20262055            # iteration's scratch pad) 
20272056            with  record_function ("## ssd_generate_row_addrs ##" ):
@@ -2164,6 +2193,7 @@ def forward(
21642193        self ,
21652194        indices : Tensor ,
21662195        offsets : Tensor ,
2196+         weights : Optional [Tensor ] =  None ,
21672197        per_sample_weights : Optional [Tensor ] =  None ,
21682198        feature_requires_grad : Optional [Tensor ] =  None ,
21692199        batch_size_per_feature_per_rank : Optional [List [List [int ]]] =  None ,
@@ -2185,7 +2215,7 @@ def forward(
21852215                context = self .step ,
21862216                stream = self .ssd_eviction_stream ,
21872217            ):
2188-                 self ._prefetch (indices , offsets , vbe_metadata )
2218+                 self ._prefetch (indices , offsets , weights ,  vbe_metadata )
21892219
21902220        assert  len (self .ssd_prefetch_data ) >  0 
21912221
@@ -3745,8 +3775,13 @@ def _report_eviction_stats(self) -> None:
37453775        processed_counts  =  torch .zeros (T , dtype = torch .int64 )
37463776        full_duration_ms  =  torch .tensor (0 , dtype = torch .int64 )
37473777        exec_duration_ms  =  torch .tensor (0 , dtype = torch .int64 )
3778+         dry_run_exec_duration_ms  =  torch .tensor (0 , dtype = torch .int64 )
37483779        self .ssd_db .get_feature_evict_metric (
3749-             evicted_counts , processed_counts , full_duration_ms , exec_duration_ms 
3780+             evicted_counts ,
3781+             processed_counts ,
3782+             full_duration_ms ,
3783+             exec_duration_ms ,
3784+             dry_run_exec_duration_ms ,
37503785        )
37513786
37523787        stats_reporter .report_data_amount (
@@ -3798,6 +3833,12 @@ def _report_eviction_stats(self) -> None:
37983833            duration_ms = exec_duration_ms .item (),
37993834            time_unit = "ms" ,
38003835        )
3836+         stats_reporter .report_duration (
3837+             iteration_step = self .step ,
3838+             event_name = "eviction.feature_table.dry_run_exec_duration_ms" ,
3839+             duration_ms = dry_run_exec_duration_ms .item (),
3840+             time_unit = "ms" ,
3841+         )
38013842        if  full_duration_ms .item () !=  0 :
38023843            stats_reporter .report_data_amount (
38033844                iteration_step = self .step ,
0 commit comments