From 66025cf5ffee6a905a2d042ee00f50f95d54e774 Mon Sep 17 00:00:00 2001 From: Raghav Hrishikeshan Mukundan Date: Wed, 25 Feb 2026 18:42:21 -0800 Subject: [PATCH 1/6] Test removing the hardcoding from the current TFLOPs llama3 lora formula Signed-off-by: Raghav Hrishikeshan Mukundan --- .../bridge/data/datasets/packing_utils.py | 45 +++++++++++++++++ .../bridge/training/utils/flop_utils.py | 48 ++++++++++++++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index dd36f5e1f4..e9ac80575a 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -269,3 +269,48 @@ def fill_packing_strategy( assert all(not seq[0] for seq in ifile_handles.values()), "Error: There are items left over from the assignment" assert all(not seq[1] for seq in ifile_handles.values()), "Error: There are items left over from the assignment" return output_data + + +def get_seqlen_list(elem): + num_seq = len(elem['seq_start_id']) + tokens_total = len(elem['input_ids']) + tokens_minus_eos = tokens_total - num_seq + + seq_boundaries = elem['seq_start_id'] + [tokens_total] + + # subtract 1 to account for removing eos token + token_counts = [seq_boundaries[i + 1] - seq_boundaries[i] - 1 for i in range(num_seq)] + + assert sum(token_counts) == tokens_minus_eos, (sum(token_counts), tokens_minus_eos) + + return token_counts, tokens_minus_eos + + +def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder): + data = np.load(dataset_file, allow_pickle=True) + + total_len_accum = 0 + seqlen_sq_accum = 0 + seq_count_accum = 0 + + rows_total = len(data) + count = (rows_total // gbs)*gbs if drop_remainder else rows_total + + if count != rows_total: + print(f'Dropping {rows_total - count}, total was {rows_total}') + + for i, elem in enumerate(data): + if i >= count: + break + seqlen_list, total_count = 
get_seqlen_list(elem) + seqlen_sq_list = [s*s for s in seqlen_list] + total_len_accum += total_count + seqlen_sq_accum += sum(seqlen_sq_list) + seq_count_accum += len(seqlen_list) + + avg_seqlen_count = seq_count_accum/count + avg_seqlen_total = total_len_accum/count + avg_seqlen_sq_individual = seqlen_sq_accum/seq_count_accum + avg_seqlen_sq_per_row = seqlen_sq_accum/count + + return avg_seqlen_count, avg_seqlen_total, avg_seqlen_sq_individual, avg_seqlen_sq_per_row diff --git a/src/megatron/bridge/training/utils/flop_utils.py b/src/megatron/bridge/training/utils/flop_utils.py index 2dfdc7f054..e56492e986 100644 --- a/src/megatron/bridge/training/utils/flop_utils.py +++ b/src/megatron/bridge/training/utils/flop_utils.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + import torch.nn.functional as F +from megatron.bridge.data.datasets.packing_utils import calculate_avg_seqlen +from megatron.bridge.peft.lora import LoRA from megatron.bridge.training.config import ConfigContainer from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1): """Return the number of floating point operations""" - # If the model provider has a custom TFLOPS calculation method, use it. - if hasattr(cfg.model, "_get_num_floating_point_operations"): + peft = getattr(cfg, "peft", None) + is_lora = isinstance(peft, LoRA) + # If the model provider has a custom TFLOPS calculation method, use it (non-LoRA only). 
+ if not is_lora and hasattr(cfg.model, "_get_num_floating_point_operations"): return cfg.model._get_num_floating_point_operations(batch_size) def calculate_layer_counts(): @@ -183,6 +189,44 @@ def transformer_flops(): num_query_groups = ( cfg.model.num_attention_heads if cfg.model.num_query_groups is None else cfg.model.num_query_groups ) + + is_squad = getattr(getattr(cfg, "dataset", None), "dataset_name", None) == "squad" + is_llama3_70b = cfg.model.hidden_size == 8192 and cfg.model.num_layers == 80 + packed_specs = getattr(getattr(cfg, "dataset", None), "packed_sequence_specs", None) + packed_data_path = getattr(packed_specs, "packed_train_data_path", None) + if ( + is_lora + and is_squad + and is_llama3_70b + and packed_data_path is not None + and Path(packed_data_path).exists() + ): + gbs = cfg.train.global_batch_size + seq_len = cfg.model.seq_length + _, avg_tokens, _, avg_seqlen2 = calculate_avg_seqlen( + packed_data_path, gbs, seq_len, drop_remainder=True + ) + + hs = cfg.model.hidden_size + n_layers = cfg.model.num_layers + n_heads = cfg.model.num_attention_heads + ffn_hs = cfg.model.ffn_hidden_size + vocab_size = cfg.model.vocab_size + + model_flops_frozen = ( + avg_tokens + * n_layers + * hs**2 + * ( + 12 + + 12 * num_query_groups / n_heads + + 18 * ffn_hs / hs + + 6 * vocab_size / (n_layers * hs) + ) + ) + model_flops_unfrozen = n_layers * hs**2 * (12 * avg_seqlen2 / hs) + + return batch_size * (model_flops_frozen * (2.0 / 3.0) + model_flops_unfrozen) # MoE. if cfg.model.num_moe_experts is None: # Every Transformer MLP is dense. 
From df8a5aa89532d11dec95da39497752948c39cd54 Mon Sep 17 00:00:00 2001 From: Raghav Hrishikeshan Mukundan Date: Thu, 26 Feb 2026 10:11:48 -0800 Subject: [PATCH 2/6] Add checks for packed sequencing Signed-off-by: Raghav Hrishikeshan Mukundan --- .../bridge/training/utils/flop_utils.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/megatron/bridge/training/utils/flop_utils.py b/src/megatron/bridge/training/utils/flop_utils.py index e56492e986..fc26e3dab9 100644 --- a/src/megatron/bridge/training/utils/flop_utils.py +++ b/src/megatron/bridge/training/utils/flop_utils.py @@ -21,6 +21,8 @@ from megatron.bridge.training.config import ConfigContainer from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size +_lora_seq_stats_cache: dict = {} + def num_floating_point_operations(cfg: ConfigContainer, batch_size: int = 1): """Return the number of floating point operations""" @@ -194,6 +196,15 @@ def transformer_flops(): is_llama3_70b = cfg.model.hidden_size == 8192 and cfg.model.num_layers == 80 packed_specs = getattr(getattr(cfg, "dataset", None), "packed_sequence_specs", None) packed_data_path = getattr(packed_specs, "packed_train_data_path", None) + # If not explicitly set, try to find the file via dataset_root (the FinetuningDatasetBuilder + # computes this path dynamically, but dataset_root is available from the config). 
+ if packed_data_path is None and packed_specs is not None: + dataset_root = getattr(cfg.dataset, "dataset_root", None) + seq_size = getattr(packed_specs, "packed_sequence_size", None) + if dataset_root is not None and seq_size is not None: + matches = sorted(Path(dataset_root).glob(f"packed/*/training_{seq_size}.npy")) + if matches: + packed_data_path = str(matches[0]) if ( is_lora and is_squad @@ -203,9 +214,12 @@ def transformer_flops(): ): gbs = cfg.train.global_batch_size seq_len = cfg.model.seq_length - _, avg_tokens, _, avg_seqlen2 = calculate_avg_seqlen( - packed_data_path, gbs, seq_len, drop_remainder=True - ) + cache_key = (packed_data_path, gbs, seq_len) + if cache_key not in _lora_seq_stats_cache: + _lora_seq_stats_cache[cache_key] = calculate_avg_seqlen( + packed_data_path, gbs, seq_len, drop_remainder=True + ) + _, avg_tokens, _, avg_seqlen2 = _lora_seq_stats_cache[cache_key] hs = cfg.model.hidden_size n_layers = cfg.model.num_layers From b16213fe0ac9cf574d1e0a0bdd0e81f40a672301 Mon Sep 17 00:00:00 2001 From: Raghav Hrishikeshan Mukundan Date: Thu, 26 Feb 2026 21:54:57 -0800 Subject: [PATCH 3/6] Addressing MR comments Signed-off-by: Raghav Hrishikeshan Mukundan --- src/megatron/bridge/data/datasets/packing_utils.py | 2 +- src/megatron/bridge/training/utils/flop_utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index e9ac80575a..367392f51a 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -297,7 +297,7 @@ def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder): count = (rows_total // gbs)*gbs if drop_remainder else rows_total if count != rows_total: - print(f'Dropping {rows_total - count}, total was {rows_total}') + logger.warning(f'Dropping {rows_total - count}, total was {rows_total}') for i, elem in enumerate(data): if i >= 
count: diff --git a/src/megatron/bridge/training/utils/flop_utils.py b/src/megatron/bridge/training/utils/flop_utils.py index fc26e3dab9..719a8014e2 100644 --- a/src/megatron/bridge/training/utils/flop_utils.py +++ b/src/megatron/bridge/training/utils/flop_utils.py @@ -193,7 +193,8 @@ def transformer_flops(): ) is_squad = getattr(getattr(cfg, "dataset", None), "dataset_name", None) == "squad" - is_llama3_70b = cfg.model.hidden_size == 8192 and cfg.model.num_layers == 80 + hf_model_id = getattr(cfg.model, "hf_model_id", None) + is_llama3_70b = hf_model_id is not None and "Meta-Llama-3-70B" in hf_model_id packed_specs = getattr(getattr(cfg, "dataset", None), "packed_sequence_specs", None) packed_data_path = getattr(packed_specs, "packed_train_data_path", None) # If not explicitly set, try to find the file via dataset_root (the FinetuningDatasetBuilder From cb8f86c6c3f6c65295a9db1049015a34e94683fd Mon Sep 17 00:00:00 2001 From: Raghav Hrishikeshan Mukundan Date: Thu, 26 Feb 2026 22:09:05 -0800 Subject: [PATCH 4/6] Addressing comments Signed-off-by: Raghav Hrishikeshan Mukundan --- src/megatron/bridge/data/datasets/packing_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index 367392f51a..161040a00d 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -297,7 +297,7 @@ def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder): count = (rows_total // gbs)*gbs if drop_remainder else rows_total if count != rows_total: - logger.warning(f'Dropping {rows_total - count}, total was {rows_total}') + logger.info(f'Dropping {rows_total - count}, total was {rows_total}') for i, elem in enumerate(data): if i >= count: @@ -308,6 +308,11 @@ def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder): seqlen_sq_accum += sum(seqlen_sq_list) seq_count_accum 
+= len(seqlen_list) + if count == 0: + raise ValueError(f"No rows to process: dataset has {rows_total} rows but gbs={gbs} with drop_remainder={drop_remainder}.") + if seq_count_accum == 0: + raise ValueError("No sequences found in dataset; cannot compute average sequence length.") + avg_seqlen_count = seq_count_accum/count avg_seqlen_total = total_len_accum/count avg_seqlen_sq_individual = seqlen_sq_accum/seq_count_accum From 9e1f72b53993eac0dd2ab3e61085adf17998ac90 Mon Sep 17 00:00:00 2001 From: Raghav Hrishikeshan Mukundan Date: Fri, 27 Feb 2026 09:39:37 -0800 Subject: [PATCH 5/6] Adding Doc Strings and Type Hints Signed-off-by: Raghav Hrishikeshan Mukundan --- .../bridge/data/datasets/packing_utils.py | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index 161040a00d..ab7aed237f 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -271,7 +271,17 @@ def fill_packing_strategy( return output_data -def get_seqlen_list(elem): +def get_seqlen_list(elem: Dict) -> Tuple[List[int], int]: + """Extract per-sequence token counts from a packed dataset element. + + Args: + elem: A packed dataset element with 'input_ids' and 'seq_start_id' fields. + + Returns: + A tuple of (token_counts, tokens_minus_eos) where token_counts is a list of + per-sequence token counts (excluding EOS) and tokens_minus_eos is the total + token count excluding EOS tokens. 
+ """ num_seq = len(elem['seq_start_id']) tokens_total = len(elem['input_ids']) tokens_minus_eos = tokens_total - num_seq @@ -286,7 +296,25 @@ def get_seqlen_list(elem): return token_counts, tokens_minus_eos -def calculate_avg_seqlen(dataset_file, gbs, max_seq_len, drop_remainder): +def calculate_avg_seqlen(dataset_file: str, gbs: int, max_seq_len: int, drop_remainder: bool) -> Tuple[float, float, float, float]: + """Calculate average sequence length statistics from a packed dataset. + + Args: + dataset_file: Path to the .npy packed dataset file. + gbs: Global batch size used to determine how many rows to process. + max_seq_len: Maximum sequence length (reserved for future use). + drop_remainder: If True, drop rows that don't fill a complete batch. + + Returns: + A tuple of (avg_seqlen_count, avg_seqlen_total, avg_seqlen_sq_individual, avg_seqlen_sq_per_row): + - avg_seqlen_count: Average number of sequences per row. + - avg_seqlen_total: Average total tokens (excluding EOS) per row. + - avg_seqlen_sq_individual: Average of squared per-sequence lengths. + - avg_seqlen_sq_per_row: Average of summed squared sequence lengths per row. + + Raises: + ValueError: If no rows remain after applying drop_remainder, or if no sequences are found. 
+ """ data = np.load(dataset_file, allow_pickle=True) total_len_accum = 0 From 201730c48795bfef5e6d2a50561de83f671afaad Mon Sep 17 00:00:00 2001 From: Raghav Hrishikeshan Mukundan Date: Fri, 27 Feb 2026 12:40:16 -0800 Subject: [PATCH 6/6] Apply ruff formatting fixes Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Raghav Hrishikeshan Mukundan --- .../bridge/data/datasets/packing_utils.py | 28 +++++++++++-------- .../bridge/training/utils/flop_utils.py | 16 ++--------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/megatron/bridge/data/datasets/packing_utils.py b/src/megatron/bridge/data/datasets/packing_utils.py index ab7aed237f..b53edd3563 100644 --- a/src/megatron/bridge/data/datasets/packing_utils.py +++ b/src/megatron/bridge/data/datasets/packing_utils.py @@ -282,11 +282,11 @@ def get_seqlen_list(elem: Dict) -> Tuple[List[int], int]: per-sequence token counts (excluding EOS) and tokens_minus_eos is the total token count excluding EOS tokens. """ - num_seq = len(elem['seq_start_id']) - tokens_total = len(elem['input_ids']) + num_seq = len(elem["seq_start_id"]) + tokens_total = len(elem["input_ids"]) tokens_minus_eos = tokens_total - num_seq - seq_boundaries = elem['seq_start_id'] + [tokens_total] + seq_boundaries = elem["seq_start_id"] + [tokens_total] # subtract 1 to account for removing eos token token_counts = [seq_boundaries[i + 1] - seq_boundaries[i] - 1 for i in range(num_seq)] @@ -296,7 +296,9 @@ def get_seqlen_list(elem: Dict) -> Tuple[List[int], int]: return token_counts, tokens_minus_eos -def calculate_avg_seqlen(dataset_file: str, gbs: int, max_seq_len: int, drop_remainder: bool) -> Tuple[float, float, float, float]: +def calculate_avg_seqlen( + dataset_file: str, gbs: int, max_seq_len: int, drop_remainder: bool +) -> Tuple[float, float, float, float]: """Calculate average sequence length statistics from a packed dataset. 
Args: @@ -322,28 +324,30 @@ def calculate_avg_seqlen(dataset_file: str, gbs: int, max_seq_len: int, drop_rem seq_count_accum = 0 rows_total = len(data) - count = (rows_total // gbs)*gbs if drop_remainder else rows_total + count = (rows_total // gbs) * gbs if drop_remainder else rows_total if count != rows_total: - logger.info(f'Dropping {rows_total - count}, total was {rows_total}') + logger.info(f"Dropping {rows_total - count}, total was {rows_total}") for i, elem in enumerate(data): if i >= count: break seqlen_list, total_count = get_seqlen_list(elem) - seqlen_sq_list = [s*s for s in seqlen_list] + seqlen_sq_list = [s * s for s in seqlen_list] total_len_accum += total_count seqlen_sq_accum += sum(seqlen_sq_list) seq_count_accum += len(seqlen_list) if count == 0: - raise ValueError(f"No rows to process: dataset has {rows_total} rows but gbs={gbs} with drop_remainder={drop_remainder}.") + raise ValueError( + f"No rows to process: dataset has {rows_total} rows but gbs={gbs} with drop_remainder={drop_remainder}." 
+ ) if seq_count_accum == 0: raise ValueError("No sequences found in dataset; cannot compute average sequence length.") - avg_seqlen_count = seq_count_accum/count - avg_seqlen_total = total_len_accum/count - avg_seqlen_sq_individual = seqlen_sq_accum/seq_count_accum - avg_seqlen_sq_per_row = seqlen_sq_accum/count + avg_seqlen_count = seq_count_accum / count + avg_seqlen_total = total_len_accum / count + avg_seqlen_sq_individual = seqlen_sq_accum / seq_count_accum + avg_seqlen_sq_per_row = seqlen_sq_accum / count return avg_seqlen_count, avg_seqlen_total, avg_seqlen_sq_individual, avg_seqlen_sq_per_row diff --git a/src/megatron/bridge/training/utils/flop_utils.py b/src/megatron/bridge/training/utils/flop_utils.py index 719a8014e2..1442ff53dd 100644 --- a/src/megatron/bridge/training/utils/flop_utils.py +++ b/src/megatron/bridge/training/utils/flop_utils.py @@ -21,6 +21,7 @@ from megatron.bridge.training.config import ConfigContainer from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size + _lora_seq_stats_cache: dict = {} @@ -206,13 +207,7 @@ def transformer_flops(): matches = sorted(Path(dataset_root).glob(f"packed/*/training_{seq_size}.npy")) if matches: packed_data_path = str(matches[0]) - if ( - is_lora - and is_squad - and is_llama3_70b - and packed_data_path is not None - and Path(packed_data_path).exists() - ): + if is_lora and is_squad and is_llama3_70b and packed_data_path is not None and Path(packed_data_path).exists(): gbs = cfg.train.global_batch_size seq_len = cfg.model.seq_length cache_key = (packed_data_path, gbs, seq_len) @@ -232,12 +227,7 @@ def transformer_flops(): avg_tokens * n_layers * hs**2 - * ( - 12 - + 12 * num_query_groups / n_heads - + 18 * ffn_hs / hs - + 6 * vocab_size / (n_layers * hs) - ) + * (12 + 12 * num_query_groups / n_heads + 18 * ffn_hs / hs + 6 * vocab_size / (n_layers * hs)) ) model_flops_unfrozen = n_layers * hs**2 * (12 * avg_seqlen2 / hs)