Commit dfda8f6

Feature score eviction frontend support

EddyLXJ authored and facebook-github-bot committed

Summary: Adding support for feature score eviction in the frontend.

Differential Revision: D79591336

1 parent 2fc7731 commit dfda8f6

File tree

3 files changed: +234 -23 lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 52 additions & 5 deletions

@@ -65,7 +65,7 @@ class EvictionPolicy(NamedTuple):
         0  # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
     )
     eviction_strategy: int = (
-        0  # 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
+        0  # 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
     )
     eviction_step_intervals: Optional[int] = (
         None  # trigger_step_interval if trigger mode is iteration
@@ -74,17 +74,32 @@ class EvictionPolicy(NamedTuple):
         None  # eviction trigger condition if trigger mode is mem_util
     )
     counter_thresholds: Optional[List[int]] = (
-        None  # count_thresholds for each table if eviction strategy is feature score
+        None  # count_thresholds for each table if eviction strategy is counter
     )
     ttls_in_mins: Optional[List[int]] = (
         None  # ttls_in_mins for each table if eviction strategy is timestamp
     )
     counter_decay_rates: Optional[List[float]] = (
-        None  # count_decay_rates for each table if eviction strategy is feature score
+        None  # count_decay_rates for each table if eviction strategy is counter
+    )
+    feature_score_counter_decay_rates: Optional[List[float]] = (
+        None  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+    )
+    max_training_id_num_per_table: Optional[List[int]] = (
+        None  # max_training_id_num_per_table for each table
+    )
+    target_eviction_percent_per_table: Optional[List[float]] = (
+        None  # target_eviction_percent_per_table for each table
     )
     l2_weight_thresholds: Optional[List[float]] = (
         None  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
     )
+    threshold_calculation_bucket_stride: Optional[float] = (
+        0.2  # threshold_calculation_bucket_stride if eviction strategy is feature score
+    )
+    threshold_calculation_bucket_num: Optional[int] = (
+        1000000  # 1M, threshold_calculation_bucket_num if eviction strategy is feature score
+    )
     interval_for_insufficient_eviction_s: int = (
         # wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
         # insufficient means we didn't evict enough rows, so we want to wait longer time to
@@ -95,6 +110,9 @@ class EvictionPolicy(NamedTuple):
         # wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
         60
     )
+    interval_for_feature_statistics_decay_s: int = (
+        24 * 3600  # 1 day, interval for feature statistics decay
+    )
     meta_header_lens: Optional[List[int]] = None  # metaheader length for each table

     def validate(self) -> None:
@@ -105,8 +123,8 @@ def validate(self) -> None:
         if self.eviction_trigger_mode == 0:
             return

-        assert self.eviction_strategy in [0, 1, 2, 3], (
-            "eviction_strategy must be 0, 1, 2, or 3, "
+        assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
+            "eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
             f"actual {self.eviction_strategy}"
         )
         if self.eviction_trigger_mode == 1:
@@ -161,6 +179,35 @@ def validate(self) -> None:
                 "counter_thresholds and ttls_in_mins must have the same length, "
                 f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
             )
+        elif self.eviction_strategy == 5:
+            assert self.feature_score_counter_decay_rates is not None, (
+                "feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
+                f"actual {self.feature_score_counter_decay_rates}"
+            )
+            assert self.max_training_id_num_per_table is not None, (
+                "max_training_id_num_per_table must be set if eviction_strategy is 5,"
+                f"actual {self.max_training_id_num_per_table}"
+            )
+            assert self.target_eviction_percent_per_table is not None, (
+                "target_eviction_percent_per_table must be set if eviction_strategy is 5,"
+                f"actual {self.target_eviction_percent_per_table}"
+            )
+            assert self.threshold_calculation_bucket_stride is not None, (
+                "threshold_calculation_bucket_stride must be set if eviction_strategy is 5,"
+                f"actual {self.threshold_calculation_bucket_stride}"
+            )
+            assert self.threshold_calculation_bucket_num is not None, (
+                "threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
+                f"actual {self.threshold_calculation_bucket_num}"
+            )
+            assert (
+                len(self.target_eviction_percent_per_table)
+                == len(self.feature_score_counter_decay_rates)
+                == len(self.max_training_id_num_per_table)
+            ), (
+                "feature_score_thresholds, max_training_id_num_per_table and target_eviction_percent_per_table must have the same length, "
+                f"actual {self.target_eviction_percent_per_table} vs {self.feature_score_counter_decay_rates} vs {self.max_training_id_num_per_table}"
+            )


 class KVZCHParams(NamedTuple):

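For orientation, here is a minimal caller-side sketch of how the new feature score fields might be filled in. The field names and the validate() call follow the diff above; the two-table setup, the trigger mode, and every numeric value are illustrative assumptions, not values taken from the commit.

from fbgemm_gpu.split_table_batched_embeddings_ops_common import EvictionPolicy

# Hypothetical two-table configuration using eviction_strategy=5 (feature score).
eviction_policy = EvictionPolicy(
    eviction_trigger_mode=1,  # 1: trigger eviction on an iteration interval
    eviction_strategy=5,  # 5: feature score
    eviction_step_intervals=1000,  # attempt eviction every 1000 iterations
    feature_score_counter_decay_rates=[0.99, 0.99],  # one decay rate per table
    max_training_id_num_per_table=[10_000_000, 5_000_000],
    target_eviction_percent_per_table=[0.2, 0.2],
    threshold_calculation_bucket_stride=0.2,
    threshold_calculation_bucket_num=1_000_000,
    interval_for_feature_statistics_decay_s=24 * 3600,
)
eviction_policy.validate()  # raises AssertionError if the strategy-5 fields are missing or mismatched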

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 57 additions & 5 deletions

@@ -672,16 +672,22 @@ def __init__(
                 )
                 eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
                     self.kv_zch_params.eviction_policy.eviction_trigger_mode,  # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
-                    self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
+                    self.kv_zch_params.eviction_policy.eviction_strategy,  # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
                     self.kv_zch_params.eviction_policy.eviction_step_intervals,  # trigger_step_interval if trigger mode is iteration
                     eviction_mem_threshold_gb,  # mem_util_threshold_in_GB if trigger mode is mem_util
                     self.kv_zch_params.eviction_policy.ttls_in_mins,  # ttls_in_mins for each table if eviction strategy is timestamp
-                    self.kv_zch_params.eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is feature score
-                    self.kv_zch_params.eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is feature score
+                    self.kv_zch_params.eviction_policy.counter_thresholds,  # counter_thresholds for each table if eviction strategy is counter
+                    self.kv_zch_params.eviction_policy.counter_decay_rates,  # counter_decay_rates for each table if eviction strategy is counter
+                    self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates,  # feature_score_counter_decay_rates for each table if eviction strategy is feature score
+                    self.kv_zch_params.eviction_policy.max_training_id_num_per_table,  # max_training_id_num for each table
+                    self.kv_zch_params.eviction_policy.target_eviction_percent_per_table,  # target_eviction_percent for each table
                     self.kv_zch_params.eviction_policy.l2_weight_thresholds,  # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
                     table_dims.tolist() if table_dims is not None else None,
+                    self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride,  # threshold_calculation_bucket_stride if eviction strategy is feature score
+                    self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num,  # threshold_calculation_bucket_num if eviction strategy is feature score
                     self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
                     self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
+                    self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s,
                 )
                 self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
                     self.cache_row_dim,
@@ -1013,6 +1019,9 @@ def __init__(
             self.stats_reporter.register_stats(
                 "eviction.feature_table.exec_duration_ms"
             )
+            self.stats_reporter.register_stats(
+                "eviction.feature_table.dry_run_exec_duration_ms"
+            )
             self.stats_reporter.register_stats(
                 "eviction.feature_table.exec_div_full_duration_rate"
             )
@@ -1600,6 +1609,7 @@ def prefetch(
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,  # todo: need to update caller
         forward_stream: Optional[torch.cuda.Stream] = None,
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
     ) -> None:
@@ -1625,6 +1635,7 @@ def prefetch(
         self._prefetch(
             indices,
             offsets,
+            weights,
            vbe_metadata,
            forward_stream,
        )
@@ -1633,6 +1644,7 @@ def _prefetch(  # noqa C901
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
         forward_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
@@ -1660,6 +1672,12 @@ def _prefetch(  # noqa C901

         self.timestep += 1
         self.timesteps_prefetched.append(self.timestep)
+        if self.backend_type == BackendType.DRAM and weights is not None:
+            # DRAM backend supports feature score eviction, if there is weights available
+            # in the prefetch call, we will set metadata for feature score eviction asynchronously
+            cloned_linear_cache_indices = linear_cache_indices.clone()
+        else:
+            cloned_linear_cache_indices = None

         # Lookup and virtually insert indices into L1. After this operator,
         # we know:
@@ -1691,6 +1709,18 @@ def _prefetch(  # noqa C901
                 lxu_cache_locking_counter=self.lxu_cache_locking_counter,
             )

+        # acc_weights is a 2d tensor, dim0: engagement counter, dim1: show counter
+        # how to get unique indices atm? we only have inserted_indices, evicted_indices and unique_indices_length
+        acc_weights = (
+            torch.ops.fbgemm.jagged_acc_weights_and_counts_2d_tensor(
+                weights.view(torch.float32).view(-1, 2),
+                linear_index_inverse_indices,
+                unique_indices_length,
+            )
+            if weights is not None
+            else None
+        )
+
         # Compute cache locations (rows that are hit are missed but can be
         # inserted will have cache locations != -1)
         with record_function("## ssd_tbe_lxu_cache_lookup ##"):

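The accumulation above goes through the fbgemm jagged op; as a rough standalone illustration of the intended reduction (collapsing per-occurrence (engagement, show) pairs into one row per unique index), here is a plain-PyTorch sketch. The inverse-index and length naming mirrors the diff, but the sum semantics are my reading of the op, not something the commit spells out.

import torch

# Five lookups of three unique ids; inverse_indices maps each occurrence back
# to its unique id, as torch.unique(..., return_inverse=True) would produce.
weights_2d = torch.tensor(
    [[1.0, 3.0], [0.0, 1.0], [2.0, 2.0], [1.0, 1.0], [4.0, 5.0]]
)  # one (engagement, show) pair per occurrence
inverse_indices = torch.tensor([0, 1, 0, 2, 1])
unique_indices_length = 3

acc_weights = torch.zeros(unique_indices_length, 2)
acc_weights.index_add_(0, inverse_indices, weights_2d)
# acc_weights[i] now holds the summed (engagement, show) counters for unique id i
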
@@ -2015,6 +2045,16 @@ def _prefetch(  # noqa C901
                 is_bwd=False,
             )

+        if self.backend_type == BackendType.DRAM and weights is not None:
+            # Write feature score metadata to DRAM
+            self.record_function_via_dummy_profile(
+                "## ssd_write_feature_score_metadata ##",
+                self.ssd_db.set_feature_score_metadata_cuda,
+                cloned_linear_cache_indices.cpu(),
+                torch.tensor([weights.shape[0]], device="cpu", dtype=torch.long),
+                weights.cpu().view(torch.float32).view(-1, 2),
+            )
+
         # Generate row addresses (pointing to either L1 or the current
         # iteration's scratch pad)
         with record_function("## ssd_generate_row_addrs ##"):
@@ -2157,6 +2197,7 @@ def forward(
         self,
         indices: Tensor,
         offsets: Tensor,
+        weights: Optional[Tensor] = None,
         per_sample_weights: Optional[Tensor] = None,
         feature_requires_grad: Optional[Tensor] = None,
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
@@ -2178,7 +2219,7 @@ def forward(
                 context=self.step,
                 stream=self.ssd_eviction_stream,
             ):
-                self._prefetch(indices, offsets, vbe_metadata)
+                self._prefetch(indices, offsets, weights, vbe_metadata)

         assert len(self.ssd_prefetch_data) > 0

@@ -3738,8 +3779,13 @@ def _report_eviction_stats(self) -> None:
37383779
processed_counts = torch.zeros(T, dtype=torch.int64)
37393780
full_duration_ms = torch.tensor(0, dtype=torch.int64)
37403781
exec_duration_ms = torch.tensor(0, dtype=torch.int64)
3782+
dry_run_exec_duration_ms = torch.tensor(0, dtype=torch.int64)
37413783
self.ssd_db.get_feature_evict_metric(
3742-
evicted_counts, processed_counts, full_duration_ms, exec_duration_ms
3784+
evicted_counts,
3785+
processed_counts,
3786+
full_duration_ms,
3787+
exec_duration_ms,
3788+
dry_run_exec_duration_ms,
37433789
)
37443790

37453791
stats_reporter.report_data_amount(
@@ -3791,6 +3837,12 @@ def _report_eviction_stats(self) -> None:
37913837
duration_ms=exec_duration_ms.item(),
37923838
time_unit="ms",
37933839
)
3840+
stats_reporter.report_duration(
3841+
iteration_step=self.step,
3842+
event_name="eviction.feature_table.dry_run_exec_duration_ms",
3843+
duration_ms=dry_run_exec_duration_ms.item(),
3844+
time_unit="ms",
3845+
)
37943846
if full_duration_ms.item() != 0:
37953847
stats_reporter.report_data_amount(
37963848
iteration_step=self.step,
