@@ -672,16 +672,22 @@ def __init__(
             )
             eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
                 self.kv_zch_params.eviction_policy.eviction_trigger_mode,  # eviction trigger mode, 0: disabled, 1: iteration, 2: mem_util, 3: manual
-                self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
+                self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold, 5: feature score
                 self.kv_zch_params.eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
                 eviction_mem_threshold_gb,  # mem_util_threshold_in_GB if trigger mode is mem_util
                 self.kv_zch_params.eviction_policy.ttls_in_mins,  # ttls_in_mins for each table if eviction strategy is timestamp
-                self.kv_zch_params.eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is feature score
-                self.kv_zch_params.eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is feature score
+                self.kv_zch_params.eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is counter
+                self.kv_zch_params.eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is counter
+                self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates,  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+                self.kv_zch_params.eviction_policy.max_training_id_num_per_table,  # max_training_id_num for each table
+                self.kv_zch_params.eviction_policy.target_eviction_percent_per_table,  # target_eviction_percent for each table
                 self.kv_zch_params.eviction_policy.l2_weight_thresholds,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
                 table_dims.tolist() if table_dims is not None else None,
+                self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride,  # threshold_calculation_bucket_stride if eviction strategy is feature score
+                self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num,  # threshold_calculation_bucket_num if eviction strategy is feature score
                 self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
                 self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
+                self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s,
             )
             self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
                 self.cache_row_dim,
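The constructor above now reads several additional fields off `self.kv_zch_params.eviction_policy` for the new feature-score strategy (strategy 5). As a rough sketch of how those fields fit together, here is an illustrative policy object for two tables; the attribute names are the ones read above, while the container class and every value are placeholders, not the library's actual config type:

```python
from types import SimpleNamespace

# Illustrative only: feature-score eviction (strategy 5) triggered every
# 1000 iterations (trigger mode 1) for two tables. All values are placeholders.
eviction_policy = SimpleNamespace(
    eviction_trigger_mode=1,
    eviction_strategy=5,
    eviction_step_intervals=1000,
    ttls_in_mins=None,
    counter_thresholds=None,
    counter_decay_rates=None,
    feature_score_counter_decay_rates=[0.99, 0.99],
    max_training_id_num_per_table=[10_000_000, 5_000_000],
    target_eviction_percent_per_table=[0.2, 0.2],
    l2_weight_thresholds=None,
    threshold_calculation_bucket_stride=0.2,
    threshold_calculation_bucket_num=1_000_000,
    interval_for_insufficient_eviction_s=600,
    interval_for_sufficient_eviction_s=1200,
    interval_for_feature_statistics_decay_s=3600,
)
```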
@@ -1013,6 +1019,9 @@ def __init__(
             self.stats_reporter.register_stats(
                 "eviction.feature_table.exec_duration_ms"
             )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.dry_run_exec_duration_ms"
+            )
             self.stats_reporter.register_stats(
                 "eviction.feature_table.exec_div_full_duration_rate"
             )
@@ -1600,6 +1609,7 @@ def prefetch(
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,  # TODO: need to update caller
         forward_stream: Optional[torch.cuda.Stream] = None,
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
     ) -> None:
@@ -1625,6 +1635,7 @@ def prefetch(
         self._prefetch(
             indices,
             offsets,
+            weights,
             vbe_metadata,
             forward_stream,
         )
@@ -1633,6 +1644,7 @@ def _prefetch( # noqa C901
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         forward_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
@@ -1660,6 +1672,12 @@ def _prefetch( # noqa C901

         self.timestep += 1
         self.timesteps_prefetched.append(self.timestep)
+        if self.backend_type == BackendType.DRAM and weights is not None:
+            # The DRAM backend supports feature score eviction; if weights are available
+            # in the prefetch call, we set the feature score eviction metadata asynchronously.
+            cloned_linear_cache_indices = linear_cache_indices.clone()
+        else:
+            cloned_linear_cache_indices = None

         # Lookup and virtually insert indices into L1. After this operator,
         # we know:
@@ -1691,6 +1709,18 @@ def _prefetch( # noqa C901
                 lxu_cache_locking_counter=self.lxu_cache_locking_counter,
             )

+        # acc_weights is a 2D tensor, dim0: engagement counter, dim1: show counter.
+        # How do we get unique indices at this point? We only have inserted_indices, evicted_indices, and unique_indices_length.
+        acc_weights = (
+            torch.ops.fbgemm.jagged_acc_weights_and_counts_2d_tensor(
+                weights.view(torch.float32).view(-1, 2),
+                linear_index_inverse_indices,
+                unique_indices_length,
+            )
+            if weights is not None
+            else None
+        )
+
         # Compute cache locations (rows that are hit or missed but can be
         # inserted will have cache locations != -1)
         with record_function("## ssd_tbe_lxu_cache_lookup ##"):
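The fused `jagged_acc_weights_and_counts_2d_tensor` op itself is not shown in this diff; based on how it is called here, it appears to reduce the per-occurrence (engagement, show) pairs into one row per unique index using the inverse mapping from the dedup step. A plain-PyTorch approximation of that assumed behavior, for illustration only:

```python
import torch

def acc_weights_per_unique_index(
    weights_2d: torch.Tensor,       # [N, 2] float32, (engagement, show) per occurrence
    inverse_indices: torch.Tensor,  # [N] int64, position of each occurrence in the unique set
    num_unique: int,
) -> torch.Tensor:
    # Sum the (engagement, show) pair of every occurrence into the row of its
    # unique index; rows past the number of occurrences seen stay zero.
    out = torch.zeros(num_unique, 2, dtype=weights_2d.dtype)
    out.index_add_(0, inverse_indices, weights_2d)
    return out
```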
@@ -2015,6 +2045,16 @@ def _prefetch( # noqa C901
                 is_bwd=False,
             )

+        if self.backend_type == BackendType.DRAM and weights is not None:
+            # Write feature score metadata to DRAM
+            self.record_function_via_dummy_profile(
+                "## ssd_write_feature_score_metadata ##",
+                self.ssd_db.set_feature_score_metadata_cuda,
+                cloned_linear_cache_indices.cpu(),
+                torch.tensor([weights.shape[0]], device="cpu", dtype=torch.long),
+                weights.cpu().view(torch.float32).view(-1, 2),
+            )
+
         # Generate row addresses (pointing to either L1 or the current
         # iteration's scratch pad)
         with record_function("## ssd_generate_row_addrs ##"):
@@ -2157,6 +2197,7 @@ def forward(
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
@@ -2178,7 +2219,7 @@ def forward(
                 context=self.step,
                 stream=self.ssd_eviction_stream,
             ):
-                self._prefetch(indices, offsets, vbe_metadata)
+                self._prefetch(indices, offsets, weights, vbe_metadata)

         assert len(self.ssd_prefetch_data) > 0
@@ -3738,8 +3779,13 @@ def _report_eviction_stats(self) -> None:
         processed_counts = torch.zeros(T, dtype=torch.int64)
         full_duration_ms = torch.tensor(0, dtype=torch.int64)
         exec_duration_ms = torch.tensor(0, dtype=torch.int64)
+        dry_run_exec_duration_ms = torch.tensor(0, dtype=torch.int64)
         self.ssd_db.get_feature_evict_metric(
-            evicted_counts, processed_counts, full_duration_ms, exec_duration_ms
+            evicted_counts,
+            processed_counts,
+            full_duration_ms,
+            exec_duration_ms,
+            dry_run_exec_duration_ms,
         )

         stats_reporter.report_data_amount(
@@ -3791,6 +3837,12 @@ def _report_eviction_stats(self) -> None:
             duration_ms=exec_duration_ms.item(),
             time_unit="ms",
         )
+        stats_reporter.report_duration(
+            iteration_step=self.step,
+            event_name="eviction.feature_table.dry_run_exec_duration_ms",
+            duration_ms=dry_run_exec_duration_ms.item(),
+            time_unit="ms",
+        )
         if full_duration_ms.item() != 0:
             stats_reporter.report_data_amount(
                 iteration_step=self.step,