Skip to content

Commit 3c46600

Browse files
EddyLXJ authored and facebook-github-bot committed
Feature score eviction frontend support (#4682)
Summary: X-link: meta-pytorch/torchrec#3273. Pull Request resolved: #4682. X-link: facebookresearch/FBGEMM#1708. Adding support for feature score eviction in the frontend. Reviewed By: emlin. Differential Revision: D79591336.
1 parent 72d7af7 commit 3c46600

File tree

3 files changed

+231
-25
lines changed

3 files changed

+231
-25
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class EvictionPolicy(NamedTuple):
6565
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
6666
)
6767
eviction_strategy: int = (
68-
0 # 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
68+
0 # 0: timestamp, 1: counter , 2: counter + timestamp, 3: feature l2 norm 4: timestamp threshold 5: feature score
6969
)
7070
eviction_step_intervals: Optional[int] = (
7171
None # trigger_step_interval if trigger mode is iteration
@@ -74,17 +74,32 @@ class EvictionPolicy(NamedTuple):
7474
None # eviction trigger condition if trigger mode is mem_util
7575
)
7676
counter_thresholds: Optional[List[int]] = (
77-
None # count_thresholds for each table if eviction strategy is feature score
77+
None # count_thresholds for each table if eviction strategy is counter
7878
)
7979
ttls_in_mins: Optional[List[int]] = (
8080
None # ttls_in_mins for each table if eviction strategy is timestamp
8181
)
8282
counter_decay_rates: Optional[List[float]] = (
83-
None # count_decay_rates for each table if eviction strategy is feature score
83+
None # count_decay_rates for each table if eviction strategy is counter
84+
)
85+
feature_score_counter_decay_rates: Optional[List[float]] = (
86+
None # feature_score_counter_decay_rates for each table if eviction strategy is feature score
87+
)
88+
max_training_id_num_per_table: Optional[List[int]] = (
89+
None # max_training_id_num_per_table for each table
90+
)
91+
target_eviction_percent_per_table: Optional[List[float]] = (
92+
None # target_eviction_percent_per_table for each table
8493
)
8594
l2_weight_thresholds: Optional[List[float]] = (
8695
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
8796
)
97+
threshold_calculation_bucket_stride: Optional[float] = (
98+
0.2 # threshold_calculation_bucket_stride if eviction strategy is feature score
99+
)
100+
threshold_calculation_bucket_num: Optional[int] = (
101+
1000000 # 1M, threshold_calculation_bucket_num if eviction strategy is feature score
102+
)
88103
interval_for_insufficient_eviction_s: int = (
89104
# wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
90105
# insufficient means we didn't evict enough rows, so we want to wait longer time to
@@ -95,6 +110,9 @@ class EvictionPolicy(NamedTuple):
95110
# wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
96111
60
97112
)
113+
interval_for_feature_statistics_decay_s: int = (
114+
24 * 3600 # 1 day, interval for feature statistics decay
115+
)
98116
meta_header_lens: Optional[List[int]] = None # metaheader length for each table
99117

100118
def validate(self) -> None:
@@ -105,8 +123,8 @@ def validate(self) -> None:
105123
if self.eviction_trigger_mode == 0:
106124
return
107125

108-
assert self.eviction_strategy in [0, 1, 2, 3], (
109-
"eviction_strategy must be 0, 1, 2, or 3, "
126+
assert self.eviction_strategy in [0, 1, 2, 3, 4, 5], (
127+
"eviction_strategy must be 0, 1, 2, 3, 4 or 5, "
110128
f"actual {self.eviction_strategy}"
111129
)
112130
if self.eviction_trigger_mode == 1:
@@ -161,6 +179,35 @@ def validate(self) -> None:
161179
"counter_thresholds and ttls_in_mins must have the same length, "
162180
f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
163181
)
182+
elif self.eviction_strategy == 5:
183+
assert self.feature_score_counter_decay_rates is not None, (
184+
"feature_score_counter_decay_rates must be set if eviction_strategy is 5, "
185+
f"actual {self.feature_score_counter_decay_rates}"
186+
)
187+
assert self.max_training_id_num_per_table is not None, (
188+
"max_training_id_num_per_table must be set if eviction_strategy is 5,"
189+
f"actual {self.max_training_id_num_per_table}"
190+
)
191+
assert self.target_eviction_percent_per_table is not None, (
192+
"target_eviction_percent_per_table must be set if eviction_strategy is 5,"
193+
f"actual {self.target_eviction_percent_per_table}"
194+
)
195+
assert self.threshold_calculation_bucket_stride is not None, (
196+
"threshold_calculation_bucket_stride must be set if eviction_strategy is 5,"
197+
f"actual {self.threshold_calculation_bucket_stride}"
198+
)
199+
assert self.threshold_calculation_bucket_num is not None, (
200+
"threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
201+
f"actual {self.threshold_calculation_bucket_num}"
202+
)
203+
assert (
204+
len(self.target_eviction_percent_per_table)
205+
== len(self.feature_score_counter_decay_rates)
206+
== len(self.max_training_id_num_per_table)
207+
), (
208+
"feature_score_thresholds, max_training_id_num_per_table and target_eviction_percent_per_table must have the same length, "
209+
f"actual {self.target_eviction_percent_per_table} vs {self.feature_score_counter_decay_rates} vs {self.max_training_id_num_per_table}"
210+
)
164211

165212

166213
class KVZCHParams(NamedTuple):

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -675,18 +675,25 @@ def __init__(
675675
if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
676676
else self.l2_cache_size
677677
)
678+
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
678679
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
679680
self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
680-
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
681+
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter, 2: counter + timestamp, 3: feature l2 norm, 4: timestamp threshold 5: feature score
681682
self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
682683
eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
683684
self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
684-
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is feature score
685-
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is feature score
685+
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is counter
686+
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is counter
687+
self.kv_zch_params.eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
688+
self.kv_zch_params.eviction_policy.max_training_id_num_per_table, # max_training_id_num for each table
689+
self.kv_zch_params.eviction_policy.target_eviction_percent_per_table, # target_eviction_percent for each table
686690
self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
687691
table_dims.tolist() if table_dims is not None else None,
692+
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
693+
self.kv_zch_params.eviction_policy.threshold_calculation_bucket_num, # threshold_calculation_bucket_num if eviction strategy is feature score
688694
self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
689695
self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
696+
self.kv_zch_params.eviction_policy.interval_for_feature_statistics_decay_s,
690697
)
691698
self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
692699
self.cache_row_dim,
@@ -1018,6 +1025,9 @@ def __init__(
10181025
self.stats_reporter.register_stats(
10191026
"eviction.feature_table.exec_duration_ms"
10201027
)
1028+
self.stats_reporter.register_stats(
1029+
"eviction.feature_table.dry_run_exec_duration_ms"
1030+
)
10211031
self.stats_reporter.register_stats(
10221032
"eviction.feature_table.exec_div_full_duration_rate"
10231033
)
@@ -1605,6 +1615,7 @@ def prefetch(
16051615
self,
16061616
indices: Tensor,
16071617
offsets: Tensor,
1618+
weights: Optional[Tensor] = None, # todo: need to update caller
16081619
forward_stream: Optional[torch.cuda.Stream] = None,
16091620
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
16101621
) -> None:
@@ -1630,6 +1641,7 @@ def prefetch(
16301641
self._prefetch(
16311642
indices,
16321643
offsets,
1644+
weights,
16331645
vbe_metadata,
16341646
forward_stream,
16351647
)
@@ -1638,6 +1650,7 @@ def _prefetch( # noqa C901
16381650
self,
16391651
indices: Tensor,
16401652
offsets: Tensor,
1653+
weights: Optional[Tensor] = None,
16411654
vbe_metadata: Optional[invokers.lookup_args.VBEMetadata] = None,
16421655
forward_stream: Optional[torch.cuda.Stream] = None,
16431656
) -> None:
@@ -1665,6 +1678,12 @@ def _prefetch( # noqa C901
16651678

16661679
self.timestep += 1
16671680
self.timesteps_prefetched.append(self.timestep)
1681+
if self.backend_type == BackendType.DRAM and weights is not None:
1682+
# DRAM backend supports feature score eviction, if there is weights available
1683+
# in the prefetch call, we will set metadata for feature score eviction asynchronously
1684+
cloned_linear_cache_indices = linear_cache_indices.clone()
1685+
else:
1686+
cloned_linear_cache_indices = None
16681687

16691688
# Lookup and virtually insert indices into L1. After this operator,
16701689
# we know:
@@ -2022,6 +2041,16 @@ def _prefetch( # noqa C901
20222041
is_bwd=False,
20232042
)
20242043

2044+
if self.backend_type == BackendType.DRAM and weights is not None:
2045+
# Write feature score metadata to DRAM
2046+
self.record_function_via_dummy_profile(
2047+
"## ssd_write_feature_score_metadata ##",
2048+
self.ssd_db.set_feature_score_metadata_cuda,
2049+
cloned_linear_cache_indices.cpu(),
2050+
torch.tensor([weights.shape[0]], device="cpu", dtype=torch.long),
2051+
weights.cpu().view(torch.float32).view(-1, 2),
2052+
)
2053+
20252054
# Generate row addresses (pointing to either L1 or the current
20262055
# iteration's scratch pad)
20272056
with record_function("## ssd_generate_row_addrs ##"):
@@ -2164,6 +2193,7 @@ def forward(
21642193
self,
21652194
indices: Tensor,
21662195
offsets: Tensor,
2196+
weights: Optional[Tensor] = None,
21672197
per_sample_weights: Optional[Tensor] = None,
21682198
feature_requires_grad: Optional[Tensor] = None,
21692199
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
@@ -2185,7 +2215,7 @@ def forward(
21852215
context=self.step,
21862216
stream=self.ssd_eviction_stream,
21872217
):
2188-
self._prefetch(indices, offsets, vbe_metadata)
2218+
self._prefetch(indices, offsets, weights, vbe_metadata)
21892219

21902220
assert len(self.ssd_prefetch_data) > 0
21912221

@@ -3745,8 +3775,13 @@ def _report_eviction_stats(self) -> None:
37453775
processed_counts = torch.zeros(T, dtype=torch.int64)
37463776
full_duration_ms = torch.tensor(0, dtype=torch.int64)
37473777
exec_duration_ms = torch.tensor(0, dtype=torch.int64)
3778+
dry_run_exec_duration_ms = torch.tensor(0, dtype=torch.int64)
37483779
self.ssd_db.get_feature_evict_metric(
3749-
evicted_counts, processed_counts, full_duration_ms, exec_duration_ms
3780+
evicted_counts,
3781+
processed_counts,
3782+
full_duration_ms,
3783+
exec_duration_ms,
3784+
dry_run_exec_duration_ms,
37503785
)
37513786

37523787
stats_reporter.report_data_amount(
@@ -3798,6 +3833,12 @@ def _report_eviction_stats(self) -> None:
37983833
duration_ms=exec_duration_ms.item(),
37993834
time_unit="ms",
38003835
)
3836+
stats_reporter.report_duration(
3837+
iteration_step=self.step,
3838+
event_name="eviction.feature_table.dry_run_exec_duration_ms",
3839+
duration_ms=dry_run_exec_duration_ms.item(),
3840+
time_unit="ms",
3841+
)
38013842
if full_duration_ms.item() != 0:
38023843
stats_reporter.report_data_amount(
38033844
iteration_step=self.step,

0 commit comments

Comments
 (0)