diff --git a/.github/workflows/gpu_skyrl_train.yaml b/.github/workflows/gpu_skyrl_train.yaml
index 96c728a8b7..d4a216af1a 100644
--- a/.github/workflows/gpu_skyrl_train.yaml
+++ b/.github/workflows/gpu_skyrl_train.yaml
@@ -48,5 +48,5 @@ jobs:
           ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
           ANYSCALE_HOST: https://console.anyscale.com
         run: |
-          anyscale job submit -f ci/anyscale_gpu_ci_skyrl_train.yaml --timeout 10000
-          anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-train-gpu-ci --timeout 10000
+          anyscale job submit -f ci/anyscale_gpu_ci_skyrl_train.yaml --timeout 12000
+          anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-train-gpu-ci --timeout 12000
diff --git a/examples/train/search/run_search.sh b/examples/train/search/run_search.sh
index fb2d12c287..92714d0d05 100755
--- a/examples/train/search/run_search.sh
+++ b/examples/train/search/run_search.sh
@@ -1,10 +1,32 @@
 set -x
 
-# Colocated GRPO training+generation for Qwen2.5-Coder-3B-Instruct on SearchR1 data.
-# follow the instructions in examples/search/README.md for setting up the dataset
-# and for starting the local search server
-# export WANDB_API_KEY=
-# bash examples/train/search/run_search.sh
+# Colocated GRPO training+generation for Qwen2.5-3B-Instruct on SearchR1 data.
+# Follow the instructions in docs/content/docs/recipes/searchr1.mdx for setup.
+#
+# Usage:
+#   export WANDB_API_KEY=
+#   bash examples/train/search/run_search.sh
+#
+# Configurable knobs (override via env vars or command-line args):
+#   USE_CONVERSATION_MULTI_TURN - set to "true" to use the conversation multi-turn
+#       format (default: false). When true, also enables
+#       append_eos_token_after_stop_str_in_multi_turn=true so that each turn's response
+#       ends with the model's EOS token (required for correct behavior when stop strings
+#       like "</search>" or "</answer>" terminate generation instead of EOS).
+#   STEP_WISE - set to "true" to enable step-wise training (default: false).
+#       Requires USE_CONVERSATION_MULTI_TURN=true.
+#
+# Examples:
+#   # Default (non-conversation, non-step-wise):
+#   bash examples/train/search/run_search.sh
+#
+#   # Conversation multi-turn format:
+#   USE_CONVERSATION_MULTI_TURN=true bash examples/train/search/run_search.sh
+#
+#   # Step-wise with conversation multi-turn:
+#   USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh
+#
+#   # Override any config via positional args (passed to Hydra):
+#   bash examples/train/search/run_search.sh trainer.epochs=2 trainer.eval_interval=10
 
 # path for dataset (.parquet files) containing the prompts and metadata for each question
 DATA_DIR="$HOME/data/searchR1"
@@ -14,6 +36,28 @@ RUN_NAME="skyrl-search_4turns_maxgeneratelen_500-multiturn-sync-TIS_2.0"
 TIS_TYPE=token
 TIS_IMP_RATIO_CAP=2.0
 
+# Configurable knobs with defaults
+: "${USE_CONVERSATION_MULTI_TURN:=false}"
+: "${STEP_WISE:=false}"
+
+# Build conditional args
+MULTI_TURN_ARGS=""
+if [ "$USE_CONVERSATION_MULTI_TURN" = "true" ]; then
+  MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true"
+else
+  MULTI_TURN_ARGS="generator.use_conversation_multi_turn=false"
+fi
+
+STEP_WISE_ARGS=""
+if [ "$STEP_WISE" = "true" ]; then
+  STEP_WISE_ARGS="generator.step_wise_trajectories=true"
+  # Step-wise requires conversation multi-turn
+  if [ "$USE_CONVERSATION_MULTI_TURN" != "true" ]; then
+    echo "WARNING: STEP_WISE=true requires USE_CONVERSATION_MULTI_TURN=true. Enabling it automatically."
+    MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true"
+  fi
+fi
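+
+# For reference, USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true expands to:
+#   MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true"
+#   STEP_WISE_ARGS="generator.step_wise_trajectories=true"
+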
 uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \
   data.train_data="['${DATA_DIR}/train.parquet']" \
   data.val_data="['${DATA_DIR}/validation.parquet']" \
@@ -49,7 +93,8 @@ uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \
   generator.sampling_params.max_generate_length=500 \
   generator.inference_engine.async_engine=true \
   generator.batched=false \
-  generator.use_conversation_multi_turn=false \
+  $MULTI_TURN_ARGS \
+  $STEP_WISE_ARGS \
   generator.n_samples_per_prompt=5 \
   generator.max_turns=4 \
   generator.sampling_params.temperature=1.0 \
diff --git a/examples/train/search/run_search_conversation_format.sh b/examples/train/search/run_search_conversation_format.sh
deleted file mode 100755
index 92a6c1d154..0000000000
--- a/examples/train/search/run_search_conversation_format.sh
+++ /dev/null
@@ -1,87 +0,0 @@
-set -x
-
-# The exact same script as `run_search.sh` but with `use_conversation_multi_turn=true`
-# and hence `append_eos_token_after_stop_str_in_multi_turn=true`
-# See https://docs.skyrl.ai/docs/tutorials/skyrl_gym_generator on the
-# difference between the two options. You might want to change the data generation prompt
-# to let the model know that we are doing multi-turn conversations (i.e. user will provide
-# the search result for each turn).
-
-# Colocated GRPO training+generation for Qwen2.5-Coder-3B-Instruct on SearchR1 data.
-# follow the instructions in examples/train/search/README.md for setting up the dataset
-# and for starting the local search server
-# export WANDB_API_KEY=
-# bash examples/train/search/run_search_conversation_format.sh
-
-# path for dataset (.parquet files) containing the prompts and metadata for each question
-DATA_DIR="$HOME/data/searchR1"
-
-RUN_NAME="skyrl-search_4turns_maxgeneratelen_500"
-
-TIS_TYPE=token
-TIS_IMP_RATIO_CAP=2.0
-
-uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \
-  data.train_data="['${DATA_DIR}/train.parquet']" \
-  data.val_data="['${DATA_DIR}/validation.parquet']" \
-  trainer.algorithm.advantage_estimator="grpo" \
-  trainer.policy.optimizer_config.lr=1.0e-6 \
-  trainer.policy.optimizer_config.max_grad_norm=0.5 \
-  trainer.policy.optimizer_config.num_warmup_steps=94 \
-  trainer.algorithm.use_kl_loss=true \
-  trainer.algorithm.kl_loss_coef=0.001 \
-  trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \
-  trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \
-  trainer.policy.model.path="Qwen/Qwen2.5-3B-Instruct" \
-  trainer.placement.colocate_all=true \
-  trainer.strategy=fsdp2 \
-  trainer.policy.fsdp_config.cpu_offload=false \
-  trainer.ref.fsdp_config.cpu_offload=true \
-  trainer.placement.policy_num_gpus_per_node=8 \
-  trainer.placement.ref_num_gpus_per_node=8 \
-  generator.inference_engine.num_engines=4 \
-  generator.inference_engine.tensor_parallel_size=2 \
-  generator.inference_engine.backend=vllm \
-  generator.inference_engine.run_engines_locally=true \
-  generator.inference_engine.weight_sync_backend=nccl \
-  generator.inference_engine.gpu_memory_utilization=0.5 \
-  trainer.epochs=1 \
-  trainer.update_epochs_per_batch=1 \
-  trainer.train_batch_size=512 \
-  trainer.policy_mini_batch_size=256 \
-  trainer.micro_forward_batch_size_per_gpu=4 \
-  trainer.micro_train_batch_size_per_gpu=4 \
-  trainer.max_prompt_length=2048 \
-  generator.max_input_length=4096 \
-  generator.sampling_params.max_generate_length=500 \
-  generator.inference_engine.async_engine=true \
-  generator.batched=false \
-  generator.use_conversation_multi_turn=true \
-  generator.n_samples_per_prompt=5 \
-  generator.max_turns=4 \
-  generator.sampling_params.temperature=1.0 \
-  generator.sampling_params.top_p=1.0 \
-  generator.sampling_params.stop='["</search>", "</answer>"]' \
-  generator.append_eos_token_after_stop_str_in_multi_turn=true \
-  environment.env_class="search" \
-  environment.skyrl_gym.max_env_workers=16 \
-  environment.skyrl_gym.search.log_requests=false \
-  environment.skyrl_gym.search.search_url="http://127.0.0.1:8000/retrieve" \
-  environment.skyrl_gym.search.topk=3 \
-  trainer.logger="wandb" \
-  trainer.project_name="skyrl-search" \
-  trainer.run_name="${RUN_NAME}" \
-  trainer.ckpt_interval=20 \
-  trainer.hf_save_interval=100 \
-  trainer.max_ckpts_to_keep=5 \
-  trainer.resume_mode=latest \
-  trainer.ckpt_path="$HOME/${RUN_NAME}" \
-  trainer.eval_batch_size=256 \
-  trainer.eval_before_train=false \
-  generator.eval_sampling_params.temperature=0 \
-  generator.eval_sampling_params.stop='["</search>", "</answer>"]' \
-  generator.eval_sampling_params.max_generate_length=500 \
-  trainer.export_path="$HOME/${RUN_NAME}/exports" \
-  trainer.eval_interval=50 \
-  $@
-
\ No newline at end of file
diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py
index df86161c9a..7a43e69f6c 100644
--- a/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py
+++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_model_wrapper.py
@@ -367,6 +367,10 @@ def loss_func(logits, data):
             return loss, metrics
 
         def forward_step(batch_iter, model):
+            # NOTE(Charlie): despite its name, `remove_left_padding()` is padding-agnostic
+            # (left or right), since it uses the attention_mask to locate real tokens. The
+            # same holds for `recover_left_padding` and `setup_per_microbatch_replay_forward`.
+            # Especially relevant after https://github.com/NovaSky-AI/SkyRL/pull/1285.
            batch = next(batch_iter)
            rollout_expert_indices = batch.pop("rollout_expert_indices", None)
diff --git a/skyrl/train/dataset/preprocess.py b/skyrl/train/dataset/preprocess.py
index f16bf79b17..7b083bc06f 100644
--- a/skyrl/train/dataset/preprocess.py
+++ b/skyrl/train/dataset/preprocess.py
@@ -1,9 +1,12 @@
+import logging
 from typing import List, Optional, Tuple
 
 import torch
 from jaxtyping import Float, Integer
 from transformers import AutoTokenizer
 
+logger = logging.getLogger(__name__)
+
 
 def _verify_inputs(
     prompts: List[List[int]],
@@ -34,6 +37,7 @@ def convert_prompts_responses_to_batch_tensors(
     loss_masks: List[List[int]],
     logprobs: Optional[List[List[float]]] = None,
     rollout_expert_indices: Optional[List[List[List[List[int]]]]] = None,
+    max_seq_len: Optional[int] = None,
 ) -> Tuple[
     Float[torch.Tensor, "batch seq_len"],
     Float[torch.Tensor, "batch seq_len"],
@@ -46,12 +50,33 @@ def convert_prompts_responses_to_batch_tensors(
     """
     Convert prompts and responses to batch tensors for training.
 
-    This function concatenates all prompts and responses to the following format:
+    Each sequence is laid out as a single left-padded block:
+
+    | [PAD]  [PAD]  prompt prompt prompt respon respon |
+    | [PAD]  prompt prompt prompt respon respon respon |
+    | prompt prompt prompt respon respon respon respon |
+                          |<---- max_response_len --->|
+
+    The padded sequence length is ``max(prompt_len_i + response_len_i)``.
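+
+    A worked example (illustrative token ids, assuming ``pad_token_id=0``; mirrors
+    ``test_unified_left_padding_layout`` in tests/train/dataset/test_preprocess.py)::
+
+        prompts   = [[1, 2], [3, 4, 5, 6]]
+        responses = [[10, 11, 12], [20, 21]]
+        # max_total = max(2 + 3, 4 + 2) = 6
+        sequences      = [[0, 1, 2, 10, 11, 12],
+                          [3, 4, 5, 6, 20, 21]]
+        attention_mask = [[0, 1, 1, 1, 1, 1],
+                          [1, 1, 1, 1, 1, 1]]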
+    This is never larger than the old layout's ``max(prompt_len_i) + max(response_len_i)``.
+    In particular, when the generator enforces ``prompt_len_i + response_len_i <= max_seq_len``,
+    the padded sequence length stays within ``max_seq_len`` (a warning is logged otherwise).
+
+    The last ``max_response_len`` positions of each sequence, from which the response-level
+    tensors (action_mask, rewards, loss_masks, logprobs) are read, then look like:
+
+    | prompt prompt respon respon |
+    | prompt respon respon respon |
+    | respon respon respon respon |
+
+    So the action_mask is:
+
+    | 0 0 1 1 |
+    | 0 1 1 1 |
+    | 1 1 1 1 |
+
+    Attention mask is 1 for all real tokens, 0 for padding.
+    Action mask is 1 for the last ``response_len_i`` positions, 0 for padding.
 
-    | [PAD] [PAD] token token token | token token [PAD] [PAD] |
-    | token token token token token | token token [PAD] [PAD] |
-    | [PAD] [PAD] [PAD] token token | token token token [PAD] |
-    |<---------- prompt ----------->|<-------- answer ------->|
+    Response-level tensors are **right-aligned** within ``(batch, max_response_len)``: non-padded
+    values occupy the last ``response_len_i`` positions, with leading zeros. This matches the model
+    forward pass, which extracts ``log_probs[:, -num_actions-1:-1]``: response tokens are always at
+    the end of the sequence, so their logprobs are right-aligned in the slice.
 
     Assumes that the responses already contain an eos token at index -1.
 
@@ -62,88 +87,89 @@ def convert_prompts_responses_to_batch_tensors(
         rewards: List of rewards for each response
         loss_masks: List of loss masks for each response
         logprobs: List of rollout log probs for each response
+        max_seq_len: Optional. If provided and ``max(prompt_i + response_i)``
+            exceeds it, a warning is logged (no truncation is performed).
 
     Returns:
-        sequences: Full trajectories (padded and concatenated prompts and responses). Size: (batch, seq_len).
-        attention_mask: Attention mask for the model. Size: (batch, seq_len)
-        action_mask: Response mask for the model. Size: (batch, response_len)
-        rewards: Rewards for each output. Size: (batch, response_len)
-        loss_masks: Loss masks for each output. Size: (batch, response_len)
+        sequences: ``(batch, max_total)`` where ``max_total = max(prompt_i + response_i)``.
+        attention_mask: ``(batch, max_total)``
+        action_mask: ``(batch, max_response)`` — right-aligned response indicator.
+        rewards: ``(batch, max_response)`` — right-aligned.
+        loss_masks: ``(batch, max_response)`` — right-aligned.
+        logprobs: ``(batch, max_response)`` — right-aligned, or ``None``.
     """
     _verify_inputs(prompts, responses, rewards, loss_masks)
 
-    max_input_len, max_output_len = 0, 0
-    prompt_token_lens, response_token_lens = [], []
-    inputs_token_ids, outputs_token_ids = [], []
-    for prompt, response in zip(prompts, responses):
+    prompt_token_lens = [len(p) for p in prompts]
+    response_token_lens = [len(r) for r in responses]
 
-        inputs_token_ids.append(prompt)
-        outputs_token_ids.append(response)
+    max_response = max(response_token_lens)
+    # Pad to the tightest bound: max per-sample total.
+    max_total = max(p + r for p, r in zip(prompt_token_lens, response_token_lens))
 
-        prompt_token_len = len(prompt)
-        response_token_len = len(response)
-        prompt_token_lens.append(prompt_token_len)
-        response_token_lens.append(response_token_len)
-
-        max_input_len = max(max_input_len, prompt_token_len)
-        max_output_len = max(max_output_len, response_token_len)
+    if max_seq_len is not None and max_total > max_seq_len:
+        logger.warning(
+            f"Max sequence length in batch ({max_total}) exceeds max_seq_len ({max_seq_len}). "
+            f"No truncation is performed; consider checking generator settings."
+        )
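+    # e.g. prompt lens [10, 90] with response lens [90, 10] give max_total = 100,
+    # whereas per-side padding (max prompt len + max response len) would give 180.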
 
     pad_token_id = tokenizer.pad_token_id
     sequences = []
     attention_masks = []
     action_masks = []
 
-    for i, prompt in enumerate(prompts):
-        # left padding input
-        input_len = prompt_token_lens[i]
-        input_ids = [pad_token_id] * (max_input_len - input_len) + list(inputs_token_ids[i])
-        input_attention_mask = [0] * (max_input_len - input_len) + [1] * input_len
-
-        # right padding output
-        output_len = response_token_lens[i]
-        output_ids = list(outputs_token_ids[i]) + [pad_token_id] * (max_output_len - output_len)
-        output_attention_mask = [1] * output_len + [0] * (max_output_len - output_len)
-
-        # concat input and output
-        sequences.append(input_ids + output_ids)
-        attention_masks.append(input_attention_mask + output_attention_mask)
-        action_masks.append(output_attention_mask)
+    for i in range(len(prompts)):
+        total_real = prompt_token_lens[i] + response_token_lens[i]
+        pad_len = max_total - total_real
+
+        # Unified left-pad: [PAD ... PAD PROMPT RESPONSE]
+        seq = [pad_token_id] * pad_len + prompts[i] + responses[i]
+        attention_mask_i = [0] * pad_len + [1] * total_real
+
+        # Response indicator within the last max_response positions (right-aligned).
+        resp_pad = max_response - response_token_lens[i]
+        action_mask_i = [0] * resp_pad + [1] * response_token_lens[i]
+
+        sequences.append(seq)
+        attention_masks.append(attention_mask_i)
+        action_masks.append(action_mask_i)
 
     sequences = torch.tensor(sequences)
     attention_mask = torch.tensor(attention_masks, dtype=torch.int64)
     action_mask = torch.tensor(action_masks, dtype=torch.int64)
 
-    # initialize ret loss masks to be the same as action mask
-    ret_loss_masks = torch.zeros_like(action_mask, dtype=torch.float)
-    for i, loss_mask in enumerate(loss_masks):
-        ret_loss_masks[i, : len(loss_mask)] = torch.tensor(loss_mask)
+    # Response-level tensors are RIGHT-ALIGNED to match the model output.
+    # The model's log_probs[:, -num_actions-1:-1] returns logprobs where
+    # response tokens occupy the last response_len_i positions.
+    ret_loss_masks = torch.zeros(len(prompts), max_response, dtype=torch.float)
+    for i, lm in enumerate(loss_masks):
+        ret_loss_masks[i, max_response - len(lm) :] = torch.tensor(lm, dtype=torch.float)
 
-    # do the same for custom rewards
-    ret_rewards = torch.zeros_like(action_mask, dtype=torch.float)
+    # Same thing for rewards.
+    ret_rewards = torch.zeros(len(prompts), max_response, dtype=torch.float)
     for i, custom_reward in enumerate(rewards):
         if isinstance(custom_reward, list):
             custom_reward = torch.tensor(custom_reward)
-        ret_rewards[i, : len(custom_reward)] = custom_reward
+        ret_rewards[i, max_response - len(custom_reward) :] = custom_reward
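+    # e.g. with max_response=3, a 2-token loss mask / reward fills positions [1:],
+    # keeping one leading zero so values stay aligned with the sequence tail.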
 
+    # Same thing for logprobs.
     logprobs_tensor = None
     if logprobs:
-        max_output_len = action_mask.size(1)
-        padded_logprobs = [
-            sample_logprobs + [0.0] * (max_output_len - len(sample_logprobs)) for sample_logprobs in logprobs
-        ]
-        logprobs_tensor = torch.tensor(padded_logprobs, dtype=torch.float)
+        logprobs_tensor = torch.zeros(len(prompts), max_response, dtype=torch.float)
+        for i, sample_logprobs in enumerate(logprobs):
+            lp = torch.tensor(sample_logprobs, dtype=torch.float)
+            logprobs_tensor[i, max_response - len(sample_logprobs) :] = lp
 
     rollout_expert_indices_tensor = None
     if rollout_expert_indices:
         first_non_empty = next((x for x in rollout_expert_indices if x), None)
         if first_non_empty:
-            total_seq_len = max_input_len + max_output_len
             num_layers = len(first_non_empty[0])
             topk = len(first_non_empty[0][0]) if num_layers > 0 else 0
-            padded = torch.zeros(len(rollout_expert_indices), total_seq_len, num_layers, topk, dtype=torch.int32)
+            padded = torch.zeros(len(rollout_expert_indices), max_total, num_layers, topk, dtype=torch.int32)
             for i, sample_indices in enumerate(rollout_expert_indices):
                 if sample_indices:
-                    left_pad = max_input_len - prompt_token_lens[i]
-                    n = min(len(sample_indices), total_seq_len - left_pad)
+                    left_pad = max_total - (prompt_token_lens[i] + response_token_lens[i])
+                    n = min(len(sample_indices), max_total - left_pad)
                     padded[i, left_pad : left_pad + n] = torch.tensor(sample_indices[:n], dtype=torch.int32)
             rollout_expert_indices_tensor = padded
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index dbca1334ea..1f84a8c4a5 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -260,7 +260,6 @@ async def train(self):
             # 3. Convert GeneratorOutput to TrainingInputBatch
             with Timer("convert_to_training_input", self.all_timings):
                 training_input: TrainingInputBatch = self.convert_to_training_input(generator_output, uids)
-            logger.info(f"Number of sequences: {len(training_input['sequences'])}")
 
             # 4. Inference and calculate values, log probs, rewards, kl divergence
             with Timer("fwd_logprobs_values_reward", self.all_timings):
@@ -630,6 +629,7 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]):
             loss_masks,
             logprobs,
             rollout_expert_indices,
+            max_seq_len=self.cfg.trainer.algorithm.max_seq_len,
         )
 
         # sanity check for off_policy_correction
@@ -661,6 +661,14 @@
         training_input.metadata = {"uids": uids}
         # padded response length
         training_input.metadata["response_length"] = response_masks_tensor.shape[1]
+        batch_num_seq, batch_padded_seq_len = sequences_tensor.shape
+        logger.info(f"batch_num_seq: {batch_num_seq}, batch_padded_seq_len: {batch_padded_seq_len}")
+        self.all_metrics.update(
+            {
+                "generate/batch_num_seq": batch_num_seq,
+                "generate/batch_padded_seq_len": batch_padded_seq_len,
+            }
+        )
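+        # batch_padded_seq_len is the tight bound max(prompt_i + response_i) computed in
+        # convert_prompts_responses_to_batch_tensors, so padding savings from the unified
+        # layout show up directly in the logged metrics.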
         if self.cfg.generator.step_wise_trajectories:
             assert (
                 "trajectory_ids" in generator_output
diff --git a/tests/train/dataset/test_preprocess.py b/tests/train/dataset/test_preprocess.py
index b6aac3396f..1df1d002a1 100644
--- a/tests/train/dataset/test_preprocess.py
+++ b/tests/train/dataset/test_preprocess.py
@@ -57,6 +57,15 @@ def fake_tokenizer_decode_list(ids, **kwargs):
 
 
 def test_convert_prompts_responses_to_batch_tensors_exact(tokenizer):
+    """
+    Test with inputs of exact lengths.
+
+    | [PAD] [PAD] [PAD] [PAD] prompt prompt prompt respon respon respon |
+    | prompt prompt prompt prompt prompt respon respon respon respon respon |
+                                        |<------ max_response_len ----->|
+    """
+    # prompts: "abc" (3 tokens), "12345" (5 tokens)
+    # outputs: "def" (3 tokens), "67890" (5 tokens)
     prompts = ["abc", "12345"]
     outputs = ["def", "67890"]
     prompts = tokenizer(prompts)["input_ids"]
@@ -75,17 +84,25 @@ def test_convert_prompts_responses_to_batch_tensors_exact(tokenizer):
         )
     )
 
-    # loss mask should be the same length as the action mask (padded to the longest input)
+    # max_total = max(3+3, 5+5) = 10, max_response = 5
     assert sequences.shape[0] == len(prompts)
+    assert sequences.shape == (2, 10)
     assert action_mask.shape == ret_loss_masks.shape
-    assert torch.equal(ret_loss_masks[0], torch.tensor([1, 1, 0, 0, 0]))
+    # Response data is RIGHT-ALIGNED within (batch, max_response)
+    # Sample 0: response len=3, so 2 leading zeros then 3 values
+    assert torch.equal(ret_loss_masks[0], torch.tensor([0, 0, 1, 1, 0]))
     assert torch.equal(ret_loss_masks[1], torch.tensor([1, 1, 1, 0, 0]))
-    assert torch.equal(ret_rewards[0], torch.tensor([0, 1, 0, 0, 0]))
+    assert torch.equal(ret_rewards[0], torch.tensor([0, 0, 0, 1, 0]))
     assert torch.equal(ret_rewards[1], torch.tensor([1, 0, 0, 0, 0]))
+    # max_total=10: sample 0 has total=6, so 4 left-pads; sample 1 has total=10, no padding
+    assert torch.equal(attention_mask[0], torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]))
+    assert torch.equal(attention_mask[1], torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))
 
 
 def test_convert_prompts_responses_to_batch_tensors_different_lengths(tokenizer):
     # Test with inputs of different lengths
+    # "Short" = 5 tokens, "This is a longer prompt" = 23 tokens
+    # "Long response here" = 18 tokens, "Short" = 5 tokens
     prompts = ["Short", "This is a longer prompt"]
    outputs = ["Long response here", "Short"]
    prompts = tokenizer(prompts)["input_ids"]
@@ -104,20 +121,23 @@ def test_convert_prompts_responses_to_batch_tensors_different_lengths(tokenizer):
     )
 
     max_response_len = max([len(output) for output in outputs])
+    # max_total = max(5+18, 23+5) = 28
+    max_total = max(len(p) + len(r) for p, r in zip(prompts, outputs))
 
     # Check shapes
-    assert sequences.shape[0] == 2  # batch size
+    assert sequences.shape == (2, max_total)
     assert attention_mask.shape == sequences.shape
-    # Tensor.shape can be directly compared with tuples
     assert action_mask.shape == (2, max_response_len)
     assert ret_rewards.shape == (2, max_response_len)
     assert ret_loss_masks.shape == (2, max_response_len)
 
-    # Verify padding is applied correctly
-    # First input is shorter than second input. the input is left padded
+    # Unified left-padding: shorter total gets left-padded
+    # Sample 0: total=23, pad=28-23=5 left pads
     assert sequences[0, 0] == tokenizer.pad_token_id
-    # second output is shorter than first output. the output is right padded
-    assert sequences[1, -1] == tokenizer.pad_token_id
+    assert sequences[1, 0] != tokenizer.pad_token_id
+    # All sequences end with real tokens (response at end), no right padding
+    assert sequences[0, -1] != tokenizer.pad_token_id
+    assert sequences[1, -1] != tokenizer.pad_token_id
 
 
 def test_convert_prompts_responses_to_batch_tensors_empty_input(tokenizer):
@@ -154,3 +174,195 @@ def test_convert_prompts_responses_to_batch_tensors_mismatched_lengths(tokenizer):
         rewards,
         loss_masks,
     )
+
+
+# ---------------------------------------------------------------------------
+# Unified padding layout tests
+# ---------------------------------------------------------------------------
+
+
+def test_unified_left_padding_layout(tokenizer):
+    """Sequences are laid out as [PAD ... PROMPT RESPONSE] with all padding on the left."""
+    # Sample 0: prompt=[1,2], response=[10,11,12] -> total=5
+    # Sample 1: prompt=[3,4,5,6], response=[20,21] -> total=6
+    # max_total=6, max_response=3
+    prompts = [[1, 2], [3, 4, 5, 6]]
+    responses = [[10, 11, 12], [20, 21]]
+    rewards = [[0.0] * 3, [0.0] * 2]
+    loss_masks = [[1] * 3, [1] * 2]
+
+    seq, attn, action, rew, lm, _, _ = convert_prompts_responses_to_batch_tensors(
+        tokenizer,
+        prompts,
+        responses,
+        rewards,
+        loss_masks,
+    )
+    assert seq.shape == (2, 6)
+
+    # Sample 0: pad=1, then [1,2,10,11,12]
+    assert seq[0].tolist() == [0, 1, 2, 10, 11, 12]
+    assert attn[0].tolist() == [0, 1, 1, 1, 1, 1]
+    # Response ends at the end of the sequence (no right-padding in sequences)
+    assert seq[0, -1] == 12
+
+    # Sample 1: no pad, [3,4,5,6,20,21]
+    assert seq[1].tolist() == [3, 4, 5, 6, 20, 21]
+    assert attn[1].tolist() == [1, 1, 1, 1, 1, 1]
+
+
+def test_right_aligned_response_data(tokenizer):
+    """Response-level tensors are right-aligned: actual values at the end, zeros at the start."""
+    prompts = [[1, 2, 3], [4, 5]]
+    responses = [[10], [20, 21, 22]]
+    rewards = [[1.0], [0.5, 0.6, 0.7]]
+    loss_masks = [[1], [1, 0, 1]]
+    logprobs = [[-0.1], [-0.2, -0.3, -0.4]]
+    prompts_copy = [p[:] for p in prompts]
+    responses_copy = [r[:] for r in responses]
+
+    seq, attn, action, rew, lm, lp, _ = convert_prompts_responses_to_batch_tensors(
+        tokenizer,
+        prompts,
+        responses,
+        rewards,
+        loss_masks,
+        logprobs,
+    )
+    # max_response=3
+    assert action.shape == (2, 3)
+
+    # Sample 0: response_len=1, right-aligned -> [0, 0, 1]
+    assert action[0].tolist() == [0, 0, 1]
+    assert rew[0].tolist() == [0.0, 0.0, 1.0]
+    assert lm[0].tolist() == [0.0, 0.0, 1.0]
+    assert lp[0].tolist() == pytest.approx([0.0, 0.0, -0.1])
+
+    # Sample 1: response_len=3, right-aligned -> [1, 1, 1] (no padding)
+    assert action[1].tolist() == [1, 1, 1]
+    assert rew[1].tolist() == pytest.approx([0.5, 0.6, 0.7])
+    assert lm[1].tolist() == [1.0, 0.0, 1.0]
+    assert lp[1].tolist() == pytest.approx([-0.2, -0.3, -0.4])
+
+    # The function does not mutate its inputs
+    assert prompts == prompts_copy
+    assert responses == responses_copy
+
+
+def test_max_seq_len_warns_but_does_not_truncate(tokenizer):
+    """max_seq_len only warns; no tokens are lost."""
+    prompts = [[1] * 50, [2] * 10]
+    responses = [[3] * 10, [4] * 50]
+    rewards = [[0.0] * 10, [0.0] * 50]
+    loss_masks = [[1] * 10, [1] * 50]
+
+    seq, _, action, _, _, _, _ = convert_prompts_responses_to_batch_tensors(
+        tokenizer,
+        prompts,
+        responses,
+        rewards,
+        loss_masks,
+        max_seq_len=30,
+    )
+    # max_total = max(60, 60) = 60, which exceeds max_seq_len=30
+    # But no truncation: all tokens preserved
+    assert seq.shape == (2, 60)
+    assert action.shape == (2, 50)
+
+
+# ---------------------------------------------------------------------------
+# R3 (Router Replay) — rollout_expert_indices padding tests
+# ---------------------------------------------------------------------------
+
+
+def test_rollout_expert_indices_shape_padding_and_alignment(tokenizer):
+    """rollout_expert_indices tensor should have shape [batch, max_total, layers, topk]
+    with left-padding aligned to the attention_mask."""
+    # Sample 0: prompt=2, response=3 → total=5
+    # Sample 1: prompt=4, response=2 → total=6
+    # max_total=6
+    prompts = [[1, 2], [3, 4, 5, 6]]
+    responses = [[10, 11, 12], [20, 21]]
+    rewards = [[0.0] * 3, [0.0] * 2]
+    loss_masks = [[1] * 3, [1] * 2]
+
+    num_layers = 2
+    topk = 2
+    # rollout_expert_indices[i] has shape [prompt_len_i + response_len_i, num_layers, topk]
+    # Sample 0: 5 tokens, sample 1: 6 tokens
+    rei_0 = [[[1, 2]] * num_layers for _ in range(5)]  # 5 tokens
+    rei_1 = [[[3, 4]] * num_layers for _ in range(6)]  # 6 tokens
+
+    seq, attn, action, rew, lm, lp, rei_tensor = convert_prompts_responses_to_batch_tensors(
+        tokenizer,
+        prompts,
+        responses,
+        rewards,
+        loss_masks,
+        rollout_expert_indices=[rei_0, rei_1],
+    )
+
+    assert rei_tensor is not None
+    # Shape: [batch=2, max_total=6, layers=2, topk=2]
+    assert rei_tensor.shape == (2, 6, num_layers, topk)
+
+    # Sample 0 has total=5, so 1 left-pad position → first position should be zeros
+    assert rei_tensor[0, 0].tolist() == [[0, 0]] * num_layers  # padding
+    assert rei_tensor[0, 1].tolist() == [[1, 2]] * num_layers  # first real token
+
+    # Sample 1 has total=6, no padding
+    assert rei_tensor[1, 0].tolist() == [[3, 4]] * num_layers  # first real token
+
+    # Non-zero positions in rei_tensor align exactly with attention_mask==1
+    for i in range(2):
+        for pos in range(6):
+            if attn[i, pos] == 0:
+                assert rei_tensor[i, pos].tolist() == [[0, 0]] * num_layers
+            else:
+                assert rei_tensor[i, pos].tolist() != [[0, 0]] * num_layers
+
+
+def test_rollout_expert_indices_none_when_not_provided(tokenizer):
+    """When rollout_expert_indices is not provided, the returned tensor should be None."""
+    prompts = [[1, 2], [3, 4]]
+    responses = [[10], [20]]
+    rewards = [[0.0], [0.0]]
+    loss_masks = [[1], [1]]
+
+    *_, rei_tensor = convert_prompts_responses_to_batch_tensors(
+        tokenizer,
+        prompts,
+        responses,
+        rewards,
+        loss_masks,
+    )
+    assert rei_tensor is None
+
+
+def test_stepwise_anti_correlation_no_inflation(tokenizer):
+    """Step-wise anti-correlated prompt/response lengths: seq_len = max(prompt_i + response_i),
+    NOT max(prompt_i) + max(response_i)."""
+    # Early turn: prompt=10, response=90 (total=100)
+    # Late turn: prompt=90, response=10 (total=100)
+    prompts = [list(range(10)), list(range(90))]
+    responses = [list(range(100, 190)), list(range(200, 210))]
+    rewards = [[0.0] * 90, [0.0] * 10]
+    loss_masks = [[1] * 90, [1] * 10]
+
+    seq, attn, action, rew, lm, _, _ = convert_prompts_responses_to_batch_tensors(
+        tokenizer,
+        prompts,
+        responses,
+        rewards,
+        loss_masks,
+    )
+    # max(10+90, 90+10) = 100, NOT 90+90=180
+    assert seq.shape == (2, 100)
+    assert action.shape == (2, 90)
+
+    # All real tokens are preserved (no truncation)
+    assert seq[0].tolist() == list(range(10)) + list(range(100, 190))
+    assert seq[1].tolist() == list(range(90)) + list(range(200, 210))
+
+    # Response data right-aligned: sample 1 has 10 tokens -> [0] * 80 + [1] * 10
+    assert action[1].tolist() == [0] * 80 + [1] * 10
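
For reviewers: a minimal standalone sketch of the invariant the new tests and the
right-alignment comments in preprocess.py rely on. The token ids, the pad id (0), and
the use of token ids as a stand-in for per-token logprobs are illustrative assumptions,
not taken from the diff.

import torch

# Toy batch with the unified left-padded layout: [PAD ... PAD, PROMPT, RESPONSE].
prompts = [[1, 2], [3, 4, 5, 6]]
responses = [[10, 11, 12], [20, 21]]
pad_id = 0

max_total = max(len(p) + len(r) for p, r in zip(prompts, responses))  # 6
max_response = max(len(r) for r in responses)  # 3

seqs, action_masks = [], []
for p, r in zip(prompts, responses):
    pad = max_total - len(p) - len(r)
    seqs.append([pad_id] * pad + p + r)
    # Right-aligned response indicator over the last max_response positions.
    action_masks.append([0] * (max_response - len(r)) + [1] * len(r))

seqs = torch.tensor(seqs)
action_mask = torch.tensor(action_masks)

# Any per-token quantity computed over the full sequence (token ids stand in for
# logprobs here): because responses always end the sequence, slicing the last
# max_response positions and masking recovers exactly the response tokens.
tail = seqs.float()[:, -max_response:]
print(tail * action_mask)
# tensor([[10., 11., 12.],
#         [ 0., 20., 21.]])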