[stepwise] Plumb through step-wise training for fully async#1536
CharlieFRuan merged 1 commit into NovaSky-AI:main from
Conversation
Code Review
This pull request introduces support for step-wise trajectories by adjusting group size calculations in the trainer and propagating a step_wise flag through the generator output concatenation and validation logic. Feedback indicates that the rollout_metrics calculation should also be updated to respect this flag, as it currently produces turn-level statistics instead of the expected trajectory-level metrics when in step-wise mode.
```diff
  num_prompts = len(result["prompt_token_ids"])
- validate_generator_output(num_prompts, result)
+ validate_generator_output(num_prompts, result, step_wise=step_wise)
```
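For context, a minimal sketch of how the `step_wise` flag might relax the validation, assuming the step-wise-specific constraint is simply that the number of entries may exceed `num_prompts` (one entry per turn rather than per trajectory); the repo's actual checks likely cover more fields:

```python
# Hypothetical sketch; not the real validate_generator_output. In
# step-wise mode each prompt can emit one entry per turn, so an exact
# length match against num_prompts no longer holds.
def validate_generator_output(num_prompts, result, step_wise=False):
    num_responses = len(result["response_ids"])
    if step_wise:
        # At least one step per prompt, but possibly many more.
        assert num_responses >= num_prompts, (
            f"expected at least {num_prompts} step-wise entries, got {num_responses}"
        )
    else:
        # One trajectory-level entry per prompt.
        assert num_responses == num_prompts, (
            f"expected {num_prompts} responses, got {num_responses}"
        )
```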
While the `step_wise` flag is now correctly passed to `validate_generator_output`, the `rollout_metrics` re-calculation performed just above (at line 276 in the source) does not account for step-wise trajectories. In step-wise mode, `result["response_ids"]` contains individual turns, meaning `get_rollout_metrics` will compute turn-level statistics (e.g., average tokens per turn) instead of trajectory-level statistics.
Consider leveraging the new `step_wise` flag to filter or aggregate metrics appropriately (e.g., by filtering for `is_last_step` or grouping by `trajectory_id`) to ensure that the reported `rollout_metrics` are consistent with trajectory-level expectations; a sketch follows below.
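One possible shape for the aggregation, as a hedged sketch only: the field names `trajectory_id`/`trajectory_ids` come from the suggestion above and may not match the actual `GeneratorOutput` schema.

```python
# Hypothetical sketch of trajectory-level aggregation for step-wise mode.
from collections import defaultdict

def trajectory_response_lengths(result, step_wise=False):
    if not step_wise:
        # Each entry is already a full trajectory; report its token count.
        return [len(resp) for resp in result["response_ids"]]
    # In step-wise mode each entry is a single turn, so sum the turn
    # lengths per trajectory before computing any statistics over them.
    totals = defaultdict(int)
    for traj_id, resp in zip(result["trajectory_ids"], result["response_ids"]):
        totals[traj_id] += len(resp)
    return list(totals.values())
```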
will leave as future TODO
Some minimal changes to enable step-wise training + fully async.
The only change needed is to make `uids.extend([cur_generated_output_group.uid] * group_size)`'s `group_size` per-group, since each group can have a variable number of generator output entries as the number of turns varies. A rough sketch of the per-group version is below.
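A minimal sketch of the per-group accumulation; `generated_output_groups` and the `.output["response_ids"]` layout are assumptions based on the description above, not the trainer's exact names:

```python
def build_uids(generated_output_groups):
    """Repeat each group's uid once per generator output entry in the group."""
    uids = []
    for cur_generated_output_group in generated_output_groups:
        # In step-wise mode the entry count varies per group (one per turn),
        # so derive group_size from the group rather than a fixed config value.
        group_size = len(cur_generated_output_group.output["response_ids"])
        uids.extend([cur_generated_output_group.uid] * group_size)
    return uids
```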
In addition, we add a `step_wise` flag to `concatenate_generator_outputs` so that the `validate_generator_output` call can validate `step_wise` constraints when applicable.

We ran search-r1 with fully async and step-wise, with the following commands on 1x8xH100s:
Retrieval server (launched first, on GPUs 0-3):
The grey curve is the run described above, compared against the curves we had in #1529. It grows more slowly because the grey run uses `train_batch_size=256` while the other ones have `512`: at any given step count it has seen half as many samples, so the learning is at a similar pace.