Integrate fully async training to UnifiedTrainer #481
Merged
Commits (26)
- `f46aa54` init new feature on unified fully async design (listar2000)
- `fd69d8f` add coordinator control and refactor queue (listar2000)
- `fb85d2a` cherrypick Kyle's async design refinements from kyle/deepresearch (listar2000)
- `f9f01e5` Refactor chat parser and migrate experimental rollout to engine (#435) (listar2000)
- `9af8505` merge nightly (listar2000)
- `9a9cb76` dump changes to rollout_engine into main file (listar2000)
- `18ca0f4` refactor base rollout engine class to standardize gating behaviors (listar2000)
- `764d0e1` make tinker backend fully compatible (listar2000)
- `1da0085` merge Kyle's fork (kylemontgomery1)
- `8a2db48` Merge remote-tracking branch 'origin/main' into unified-fully-async (kylemontgomery1)
- `f77f94a` bump vllm, deepcopy msgs in Step's post_init (kylemontgomery1)
- `46b3356` [wip] make fully-async unified trainer compatible with agent flow eng… (kylemontgomery1)
- `497d35a` fix staleness thottling (kylemontgomery1)
- `8170c7a` enfore concurrency across engines (kylemontgomery1)
- `3e2eb8d` Merge remote-tracking branch 'origin/main' into unified-fully-async (kylemontgomery1)
- `2f8e2f1` fix fully async, refactor metrics (kylemontgomery1)
- `ec49de5` Merge origin/main into unified-fully-async (kylemontgomery1)
- `0f01be7` revert engine/rollout to main, restore experimental/rollout engines (kylemontgomery1)
- `c86083b` revert TinkerChatTemplateParser and parser changes for separate PR (kylemontgomery1)
- `a5b8b4f` revert bypass_render_with_parser and tinker parser-related changes (kylemontgomery1)
- `4b67829` remove engine/gateway-level gate mechanism (kylemontgomery1)
- `bc7c37f` refactor: move task tracking to coordinator, revert validation rename… (kylemontgomery1)
- `7550fda` restore load_balancer assertion in verl_engine, revert tool_base to main (kylemontgomery1)
- `4f05c8e` fix: add future annotations to rollout_engine for TYPE_CHECKING imports (kylemontgomery1)
- `44d95a5` Merge remote-tracking branch 'origin/main' into unified-fully-async (kylemontgomery1)
- `7993243` style: fix ruff lint and format issues on unified-fully-async branch (kylemontgomery1)
`examples/countdown/unified_trainer/train_countdown_unified_tinker.py` (28 additions, 0 deletions)
```python
import hydra

from rllm.data.dataset import DatasetRegistry
from rllm.experimental.unified_trainer import AgentTrainer
from rllm.rewards.countdown_reward import countdown_reward_fn
from rllm.workflows.simple_workflow import SimpleWorkflow


@hydra.main(config_path="pkg://rllm.experimental.config", config_name="unified", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("countdown", "train")
    test_dataset = DatasetRegistry.load_dataset("countdown", "test")

    trainer = AgentTrainer(
        workflow_class=SimpleWorkflow,
        workflow_args={
            "reward_function": countdown_reward_fn,
        },
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
        backend="tinker",
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
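The two shell scripts below drive this entry point through Hydra overrides: one with asynchronous training enabled, one with it disabled.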
`examples/countdown/unified_trainer/train_countdown_unified_tinker_async.sh` (36 additions, 0 deletions)
```bash
set -x

python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \
    rllm/backend=tinker \
    model.name=Qwen/Qwen3-8B \
    model.lora_rank=32 \
    training.group_size=8 \
    training.learning_rate=2e-5 \
    training.max_length=4096 \
    sampling.train.temperature=1.0 \
    sampling.train.top_p=1.0 \
    sampling.val.temperature=1.0 \
    sampling.val.top_p=1.0 \
    validation.group_size=1 \
    rllm.workflow.n_parallel_tasks=256 \
    rllm.workflow.retry_limit=1 \
    rllm.workflow.raise_on_error=false \
    data.max_prompt_length=2048 \
    data.max_response_length=2048 \
    data.train_batch_size=1 \
    data.val_batch_size=1024 \
    rllm.algorithm.adv_estimator=grpo \
    rllm.algorithm.norm_adv_by_std_in_grpo=true \
    rllm.async_training.enable=true \
    rllm.async_training.mini_batch_size=32 \
    rllm.async_training.fwd_bwd_group_size=8 \
    rllm.async_training.staleness_threshold=0.5 \
    rllm.async_training.trigger_parameter_sync_step=1 \
    rllm.async_training.partial_rollout=true \
    rllm.trainer.total_epochs=1 \
    rllm.trainer.logger='[wandb]' \
    rllm.trainer.project_name='rllm-countdown' \
    rllm.trainer.experiment_name='countdown-tinker-async-staleness-0.5' \
    rllm.trainer.val_before_train=true \
    rllm.trainer.test_freq=10 \
    rllm.trainer.save_freq=-1
```
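Relative to the synchronous script below, this run sets `data.train_batch_size=1` and enables `rllm.async_training` with `mini_batch_size=32`, `fwd_bwd_group_size=8`, `staleness_threshold=0.5`, `trigger_parameter_sync_step=1`, and `partial_rollout=true`; the remaining overrides match apart from the experiment name.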
`examples/countdown/unified_trainer/train_countdown_unified_tinker_sync.sh` (31 additions, 0 deletions)
```bash
set -x

python -m examples.countdown.unified_trainer.train_countdown_unified_tinker \
    rllm/backend=tinker \
    model.name=Qwen/Qwen3-8B \
    model.lora_rank=32 \
    training.group_size=8 \
    training.learning_rate=2e-5 \
    training.max_length=4096 \
    sampling.train.temperature=1.0 \
    sampling.train.top_p=1.0 \
    sampling.val.temperature=1.0 \
    sampling.val.top_p=1.0 \
    validation.group_size=1 \
    rllm.workflow.n_parallel_tasks=256 \
    rllm.workflow.retry_limit=1 \
    rllm.workflow.raise_on_error=false \
    data.max_prompt_length=2048 \
    data.max_response_length=2048 \
    data.train_batch_size=32 \
    data.val_batch_size=1024 \
    rllm.algorithm.adv_estimator=grpo \
    rllm.algorithm.norm_adv_by_std_in_grpo=true \
    rllm.async_training.enable=false \
    rllm.trainer.total_epochs=1 \
    rllm.trainer.logger='[wandb]' \
    rllm.trainer.project_name='rllm-countdown' \
    rllm.trainer.experiment_name='countdown-tinker-sync' \
    rllm.trainer.val_before_train=true \
    rllm.trainer.test_freq=10 \
    rllm.trainer.save_freq=-1
```
This deepcopy was incorrectly removed during the refactor from dataclasses to pydantic. Many old workflows build up a single running message list across turns; if `chat_completions` is not deepcopied, then appending a message on a future turn would mutate a previous turn's `step.chat_completions`.
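A minimal sketch of the aliasing problem and the deepcopy fix (the `Step` class here is a simplified stand-in with made-up fields, not the actual rLLM implementation, which is a pydantic model):

```python
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class Step:
    # Simplified stand-in for rLLM's Step; only the aliasing behaviour is shown.
    chat_completions: list = field(default_factory=list)

    def __post_init__(self):
        # Without this deepcopy, every Step aliases the caller's message list.
        self.chat_completions = deepcopy(self.chat_completions)


# One running message list shared across turns, as older workflows do.
messages = [{"role": "user", "content": "Reach 15 using 3, 5, 7."}]
step_1 = Step(chat_completions=messages)

# A later turn appends to the same list...
messages.append({"role": "assistant", "content": "3 * 5 = 15"})

# ...but step_1 still records only what it saw at its own turn.
assert len(step_1.chat_completions) == 1
```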
I guess in the future we should really have an rLLM built-in `messages` format & class (similar to Tinker's `Message`), and ensure (1) it's as easy to work with as a plain dictionary, while (2) every step only holds a "view" of it (so there's no need to keep lots of copies, while earlier steps are not affected).
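A rough sketch of that idea (hypothetical names throughout, not an existing rLLM or Tinker API): keep one append-only message log per trajectory and give each step a cheap length-bounded view of it, so earlier steps stay fixed without deep copies.

```python
class MessageView:
    """Read-only window over the first `end` messages of a shared log."""

    def __init__(self, log: "MessageLog", end: int):
        self._log = log
        self._end = end

    def __len__(self) -> int:
        return self._end

    def __iter__(self):
        return iter(self._log._messages[: self._end])


class MessageLog:
    """Hypothetical append-only log; one per trajectory, shared by all steps."""

    def __init__(self) -> None:
        self._messages: list = []

    def append(self, message: dict) -> None:
        self._messages.append(message)

    def view(self) -> MessageView:
        # A view records only the current length; nothing is copied.
        return MessageView(self, len(self._messages))


log = MessageLog()
log.append({"role": "user", "content": "Reach 15 using 3, 5, 7."})
step_1_messages = log.view()                  # sees one message
log.append({"role": "assistant", "content": "3 * 5 = 15"})
assert len(step_1_messages) == 1              # earlier step is unaffected
assert len(log.view()) == 2                   # the new step sees both
```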
Agreed, I think we can spend some time this week rethinking messages/parsers.