Merged
Changes from 3 commits
130 changes: 115 additions & 15 deletions nemo_rl/models/megatron/common.py
@@ -33,10 +33,19 @@
from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank


def _round_up_to_multiple(value: int, multiple: int) -> int:
return (
((value + multiple - 1) // multiple * multiple)
if value % multiple != 0
else value
)


def _pack_sequences_for_megatron(
input_ids: torch.Tensor,
seq_lengths: torch.Tensor,
pad_individual_seqs_to_multiple_of: int = 1,
pad_packed_seq_to_multiple_of: int = 1,
pad_packed_seq_to: Optional[int] = None,
cp_rank: int = 0,
cp_size: int = 1,
@@ -47,7 +56,9 @@ def _pack_sequences_for_megatron(
input_ids: Input token IDs [batch_size, seq_length]
seq_lengths: Actual sequence lengths for each sample [batch_size]
pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value
pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value
pad_packed_seq_to: Pad packed sequences to this value (before CP)
- The three parameters above can be calculated with _get_pack_sequence_parameters_for_megatron; we do not recommend setting them manually.
cp_size: Context parallelism size

Returns:
@@ -61,14 +72,22 @@ def _pack_sequences_for_megatron(
batch_size = input_ids.shape[0]

# Build cumulative sequence lengths (cu_seqlens) and extract valid tokens
cu_seqlens = [0]
cu_seqlens_padded = (
[0]
if pad_individual_seqs_to_multiple_of > 1 or pad_packed_seq_to is not None
else None
needs_padding = (
pad_individual_seqs_to_multiple_of > 1
or pad_packed_seq_to_multiple_of > 1
or pad_packed_seq_to is not None
)

cu_seqlens = [0]
cu_seqlens_padded = [0] if needs_padding else None
valid_tokens = []

# Round pad_packed_seq_to up to the nearest multiple of pad_packed_seq_to_multiple_of
if pad_packed_seq_to is not None:
pad_packed_seq_to = _round_up_to_multiple(
pad_packed_seq_to, pad_packed_seq_to_multiple_of
)

pad_factor = pad_individual_seqs_to_multiple_of

for b in range(batch_size):
@@ -83,22 +102,26 @@
cu_seqlens.append(cu_seqlens[-1] + seq_len)

# For context parallelism, track padded sequence lengths
if pad_factor > 1 or pad_packed_seq_to is not None:
if needs_padding:
# Pad the sequence length to a multiple of pad_factor (pad_individual_seqs_to_multiple_of)
padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor
padded_seq_len = _round_up_to_multiple(seq_len, pad_factor)
cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_seq_len)

# Convert to tensors
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=input_ids.device)
if pad_factor > 1 or pad_packed_seq_to is not None:
if needs_padding:
cu_seqlens_padded = torch.tensor(
cu_seqlens_padded, dtype=torch.int32, device=input_ids.device
)
if pad_packed_seq_to is not None:
cu_seqlens_padded[-1] = pad_packed_seq_to
elif pad_packed_seq_to_multiple_of > 1:
cu_seqlens_padded[-1] = _round_up_to_multiple(
cu_seqlens_padded[-1], pad_packed_seq_to_multiple_of
)

# Calculate max sequence length (padded if using CP)
if pad_factor > 1 or (pad_packed_seq_to is not None):
if needs_padding:
seq_lens_padded = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]
max_seqlen = seq_lens_padded.max().item()
else:
@@ -119,11 +142,22 @@
else seq_lengths[b]
)
# if this is the last element, pad the packed sequence to its final target length
if b == batch_size - 1 and pad_packed_seq_to is not None:
padded_seq_len = pad_packed_seq_to - running_seq_len
running_seq_len += padded_seq_len
if b == batch_size - 1 and needs_padding:
if pad_packed_seq_to is not None:
padded_seq_len = pad_packed_seq_to - running_seq_len
elif pad_packed_seq_to_multiple_of > 1:
padded_seq_len = _round_up_to_multiple(seq_len, pad_factor)
padded_seq_len = (
_round_up_to_multiple(
running_seq_len + padded_seq_len,
pad_packed_seq_to_multiple_of,
)
- running_seq_len
)
else:
padded_seq_len = _round_up_to_multiple(seq_len, pad_factor)
else:
padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor
padded_seq_len = _round_up_to_multiple(seq_len, pad_factor)

running_seq_len += padded_seq_len

@@ -152,8 +186,17 @@ def _pack_sequences_for_megatron(
# For 'thd' format, the shape should be [1, T] where T is total tokens
packed_input_ids = torch.cat(valid_tokens, dim=0).unsqueeze(0)
all_input_ids = packed_input_ids
if pad_packed_seq_to is not None:
pad_len = pad_packed_seq_to - packed_input_ids.shape[1]
if needs_padding:
if pad_packed_seq_to is not None:
pad_len = pad_packed_seq_to - packed_input_ids.shape[1]
elif pad_packed_seq_to_multiple_of > 1:
current_seq_len = packed_input_ids.shape[1]
pad_this_seq_to = _round_up_to_multiple(
current_seq_len, pad_packed_seq_to_multiple_of
)
pad_len = pad_this_seq_to - current_seq_len
else:
pad_len = 0
if pad_len > 0:
packed_input_ids = torch.nn.functional.pad(
packed_input_ids, (0, pad_len), value=0
@@ -184,6 +227,61 @@ def _pack_sequences_for_megatron(
)


def _get_pack_sequence_parameters_for_megatron(
megatron_cfg: dict,
max_seq_len_in_batch: int,
):
"""Get pack sequence parameters for Megatron model processing with optional context parallelism.

Args:
megatron_cfg: Megatron configuration
max_seq_len_in_batch: Maximum sequence length in batch

Returns:
Tuple of:
- pad_individual_seqs_to_multiple_of: Pad individual sequences to a multiple of this value
- pad_packed_seq_to_multiple_of: Pad packed sequences to a multiple of this value
- pad_packed_seq_to: Pad packed sequences to this value (before CP)
"""
tp_size = megatron_cfg["tensor_model_parallel_size"]
sp = megatron_cfg["sequence_parallel"]
pp_size = megatron_cfg["pipeline_model_parallel_size"]
cp_size = megatron_cfg["context_parallel_size"]
fp8_cfg = megatron_cfg.get("fp8_cfg", None) or {}
use_fp8 = fp8_cfg.get("enabled", False)
use_blockwise_fp8 = fp8_cfg.get("fp8_recipe", None) == "blockwise"

# each individual sequence needs to be split across the CP domain, and across the TP domain when SP is enabled.
pad_individual_seqs_to_multiple_of = 1
if cp_size > 1:
pad_individual_seqs_to_multiple_of *= cp_size * 2
if tp_size > 1 and sp:
pad_individual_seqs_to_multiple_of *= tp_size

# the packed sequence length, after being split across the TP and CP domains, needs to be divisible by 128 when using blockwise FP8, and by 16 for other FP8 recipes.
if use_fp8:
divisor = 128 if use_blockwise_fp8 else 16
pad_packed_seq_to_multiple_of = divisor
if cp_size > 1:
pad_packed_seq_to_multiple_of *= cp_size * 2
if tp_size > 1 and sp:
pad_packed_seq_to_multiple_of *= tp_size
else:
pad_packed_seq_to_multiple_of = 1

# when PP is used, all sequences must have the same length, so we need to pad the packed sequence to the max sequence length in the batch.
if pp_size > 1:
pad_packed_seq_to = max_seq_len_in_batch
else:
pad_packed_seq_to = None

return (
pad_individual_seqs_to_multiple_of,
pad_packed_seq_to_multiple_of,
pad_packed_seq_to,
)
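# NOTE (illustrative usage sketch, not part of this change): the two helpers are
# meant to be used together. For an assumed megatron_cfg with
# tensor_model_parallel_size=2, sequence_parallel=True, context_parallel_size=2,
# pipeline_model_parallel_size=1 and FP8 disabled, the call below would return
# pad_individual_seqs_to_multiple_of = (2 * 2) * 2 = 8 (cp_size * 2, then * tp_size),
# pad_packed_seq_to_multiple_of = 1 and pad_packed_seq_to = None, which are then
# forwarded to _pack_sequences_for_megatron:
#
#     pad_individual, pad_packed_multiple, pad_packed_to = (
#         _get_pack_sequence_parameters_for_megatron(
#             megatron_cfg, max_seq_len_in_batch=int(seq_lengths.max())
#         )
#     )
#     packed = _pack_sequences_for_megatron(
#         input_ids,
#         seq_lengths,
#         pad_individual_seqs_to_multiple_of=pad_individual,
#         pad_packed_seq_to_multiple_of=pad_packed_multiple,
#         pad_packed_seq_to=pad_packed_to,
#     )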


def _unpack_sequences_from_megatron(
output_tensor: torch.Tensor,
seq_lengths: torch.Tensor,
@@ -258,6 +356,7 @@ def forward_step_arbitrary_loss(
pack_sequences: bool = False,
seq_length_key: Optional[str] = None,
pad_individual_seqs_to_multiple_of: int = 1,
pad_packed_seq_to_multiple_of: int = 1,
pad_full_seq_to: Optional[int] = None,
cp_normalize: bool = True,
policy_cfg: Optional[dict] = None,
@@ -321,6 +420,7 @@ def forward_step_arbitrary_loss(
input_ids,
seq_lengths,
pad_individual_seqs_to_multiple_of,
pad_packed_seq_to_multiple_of,
pad_full_seq_to,
cp_rank=get_context_parallel_rank(),
cp_size=get_context_parallel_world_size(),
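To make the padding bookkeeping above concrete, here is a minimal worked sketch of the cu_seqlens arithmetic. The sequence lengths and the two multiples (4 and 16) are assumed purely for illustration and do not correspond to any particular parallelism or FP8 configuration:

def round_up(value: int, multiple: int) -> int:
    # same rounding as _round_up_to_multiple above
    if value % multiple == 0:
        return value
    return (value + multiple - 1) // multiple * multiple

seq_lengths = [5, 9]                    # two sequences packed into one 'thd' stream
pad_individual_seqs_to_multiple_of = 4  # assumed, e.g. derived from CP/TP settings
pad_packed_seq_to_multiple_of = 16      # assumed, e.g. derived from an FP8 recipe

# Real token boundaries vs. padded boundaries, built the same way as the loop above.
cu_seqlens = [0]
cu_seqlens_padded = [0]
for s in seq_lengths:
    cu_seqlens.append(cu_seqlens[-1] + s)  # -> [0, 5, 14]
    cu_seqlens_padded.append(
        cu_seqlens_padded[-1] + round_up(s, pad_individual_seqs_to_multiple_of)
    )  # -> [0, 8, 20]

# The total packed length is then rounded up once more (20 -> 32), so the last
# sequence absorbs the extra padding and the padded boundaries end at 32.
cu_seqlens_padded[-1] = round_up(
    cu_seqlens_padded[-1], pad_packed_seq_to_multiple_of
)
print(cu_seqlens, cu_seqlens_padded)  # [0, 5, 14] [0, 8, 32]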
11 changes: 0 additions & 11 deletions nemo_rl/models/policy/lm_policy.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import warnings
from collections import defaultdict
@@ -235,16 +234,6 @@ def __init__(
sequence_length_pad_multiple = (
cp_size * 2 * tp_size if cp_size > 1 else tp_size
)
if (
config["megatron_cfg"]["enabled"]
and config["megatron_cfg"].get("fp8_cfg", None) is not None
and config["megatron_cfg"]["fp8_cfg"].get("enabled", False)
):
# if fp8 is enabled, ensure the sequence is padded to multiples of 16
# Ref: https://github.com/NVIDIA/TransformerEngine/blob/5b3092a0e40654436bec5ea0a0b0f7ad2887b20d/transformer_engine/pytorch/utils.py#L437-L441
sequence_length_pad_multiple = math.lcm(
16, sequence_length_pad_multiple
)
self.sequence_packing_args: SequencePackingArgs = {
"algorithm": config["sequence_packing"]["algorithm"],
"input_key": "input_ids",
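With the FP8 padding requirement now expressed through pad_packed_seq_to_multiple_of in _get_pack_sequence_parameters_for_megatron (see common.py above), here is a rough sketch of that helper's FP8 branch under one assumed configuration (tensor_model_parallel_size=4 with sequence_parallel=True, context_parallel_size=2, a non-blockwise FP8 recipe); the values are illustrative only:

tp_size, sp, cp_size = 4, True, 2   # assumed parallelism settings
use_blockwise_fp8 = False           # assumed: a non-blockwise FP8 recipe

# FP8 path of _get_pack_sequence_parameters_for_megatron: the packed length must be
# divisible by 16 (128 for blockwise FP8) on every shard after the CP and TP splits,
# so the base divisor is scaled by cp_size * 2 and, with SP enabled, by tp_size.
pad_packed_seq_to_multiple_of = 128 if use_blockwise_fp8 else 16
if cp_size > 1:
    pad_packed_seq_to_multiple_of *= cp_size * 2
if tp_size > 1 and sp:
    pad_packed_seq_to_multiple_of *= tp_size

print(pad_packed_seq_to_multiple_of)  # 16 * 4 * 4 = 256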