# [WIP] [Refactor] Add AReaLite API and examples. #125
## `.github/workflows/test-arealite.yml` (new file)

```yaml
name: Test AReaLite

on:
  push:
    paths:
      - .github/workflows/test-arealite.yml
      - arealite/**
      - ci/**
  workflow_dispatch:

jobs:
  test-arealite:
    runs-on: ubuntu-latest
    concurrency:
      group: test-arealite
    steps:
      - uses: actions/checkout@v4

      - uses: appleboy/ssh-action@v1
        env:
          GIT_REPO_URL: https://github.bibk.top/${{ github.repository }}
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          envs: GIT_REPO_URL,GIT_COMMIT_SHA
          script_path: ci/clone_repo.sh

      - uses: appleboy/ssh-action@v1
        env:
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          command_timeout: 2h
          envs: GIT_COMMIT_SHA
          script_path: ci/build_env_image.sh

      - uses: appleboy/ssh-action@v1
        env:
          GIT_COMMIT_SHA: ${{ github.sha }}
        with:
          host: ${{ secrets.CI_NODE_ADDR }}
          username: ${{ secrets.CI_NODE_USER }}
          key: ${{ secrets.REMOTE_SSH_KEY }}
          command_timeout: 1h
          envs: GIT_COMMIT_SHA
          script_path: ci/test_arealite.sh
```
# AReaLite Design Doc

## Motivation

AReaL is too heavy for AI researchers to use, understand, and develop with, for several reasons. The most important issue is that its code architecture is *system-centric* rather than *AI-centric*: the RL algorithm workflow consists of multiple *workers* that run consecutive *model function calls*, neither of which is a well-known concept for AI researchers. As a result, users must first understand these concepts before they can develop workflows and algorithms for their own use cases.

Additionally, for historical reasons, AReaL's code is not clean. Large pieces of code inherited from previous projects are no longer useful but significantly increase the burden on users and developers. Sometimes debugging is difficult even for core developers like myself.

Since the tools for building RL workflows are becoming increasingly mature, implementing a framework with comparable efficiency now requires far fewer lines of code. This is the proper time to revisit the API design and distill the giant codebase into a neat and clean one. The distilled codebase does not need to be ultra-efficient. Instead, we want to deliver 90% of the original AReaL's functionality while minimizing the lines of code and the burden on potential users. Our aim is to build an RL training framework that is fast to use, fast to read, and fast to execute. Hence the lite version of AReaL: AReaLite.

AReaLite is the first step in AReaL's refactoring process. It is not only a standalone training library with shallow interfaces, but will also provide the core API definitions to be used by AReaL in the future. AReaL will essentially transform its current worker-based architecture into an AI-centric architecture like AReaLite's, and will **extend** AReaLite's APIs and implementations to support more backends for efficient large-scale training.
## Expectations of AReaLite

### Highlights

+ Fully asynchronous training with decoupled inference and training.
+ Elastic inference device scaling: users can shut down or launch more inference processes independently during training.
+ Full SFT/RL algorithmic functionality matching AReaL.
+ Arbitrary agentic rollout collector customization in a single file.
+ Easy navigation to implementation references via Ctrl+click in VSCode.
+ Support for distributed launching with Ray/SLURM/torchrun.

### AReaLite's Scope

+ Not bound to Ray.
+ Only supports SGLang and PyTorch FSDP2 with SPMD launching.
+ No customized data structures like `SequenceSample`. All data are PyTorch tensors.
+ Uses HuggingFace (models, datasets) and PyTorch (FSDP, data structures) as much as possible.
## Architecture

### Core Components

```
arealite/
├── api/     # Abstract interfaces and data structures
├── impl/    # Concrete implementations
├── cli/     # Command-line interfaces
├── config/  # Configuration templates
└── tests/   # Standalone test scripts
```
#### 1. API Layer (`api/`)

The API layer defines abstract interfaces and data structures that provide a clean contract between components (a minimal sketch follows the list):

- **`engine_api.py`**: Defines `SPMDWrapper` for SPMD-based training backends (FSDP) and `EngineFactory`
- **`trainer_api.py`**: Defines the `Trainer` base class for different training algorithms and `TrainerFactory`
- **`rollout_api.py`**: Defines `RolloutCollector`, `Agent`, and `Environment` for RL data collection, plus `RolloutCollectorFactory`
- **`cli_args.py`**: Defines configuration dataclasses for all components
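To make the contract concrete, here is a minimal sketch of how these interfaces could look. It is patterned on the names above and on the customization guide later in this doc, not on the actual contents of `arealite/api/`; apart from `train(resume_from_checkpoint=None)` and `arun_episode(gconfig, env_option=None, seed=None)`, which appear in this doc, every signature below is an assumption.

```python
# Illustrative sketch of the API layer; actual definitions live in arealite/api/.
from abc import ABC, abstractmethod
from typing import Any, Optional

from torch import Tensor


class Trainer(ABC):
    """Fetches rollout data and updates model parameters (trainer_api.py)."""

    @abstractmethod
    def train(self, resume_from_checkpoint: Optional[str] = None) -> None: ...


class Environment(ABC):
    """Gymnasium-compatible environment (rollout_api.py)."""

    @abstractmethod
    def reset(self, seed: Optional[int] = None) -> Any: ...

    @abstractmethod
    def step(self, action: Any) -> Any: ...


class Agent(ABC):
    """Chooses actions from observations, typically by querying an LLM server."""

    @abstractmethod
    async def act(self, obs: Any) -> Any: ...


class RolloutCollector(ABC):
    """Combines agents and an environment to produce one complete trajectory."""

    @abstractmethod
    async def arun_episode(
        self, gconfig: Any, env_option: Any = None, seed: Optional[int] = None
    ) -> dict[str, Tensor]: ...
```

Returning plain dictionaries of PyTorch tensors (rather than a custom structure like `SequenceSample`) keeps this contract consistent with the scope constraints above.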
#### 2. Implementation Layer (`impl/`)

The implementation layer contains concrete implementations of the API interfaces:

- **`fsdp_wrapper.py`**: FSDP-based training engine using PyTorch FSDP2
- **`trainer/grpo.py`**: GRPO trainer implementation for reinforcement learning
- **`rollout_controller.py`**: Coordinates rollout data collection across workers
- **`rlvr/`**: RLVR collector implementations
- **`agentic/`**: Agentic collector implementations (math, code tasks)
#### 3. CLI Layer (`cli/`)

The CLI layer provides user-facing commands:

- **`main.py`**: Main entry point for launching complete training pipelines
- **`launch_server.py`**: Utility for launching standalone LLM servers

### Data Flow Architecture

AReaLite uses an **async producer-consumer pattern**:
```
┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   LLM Servers   │◄──►│ Rollout Workers  │───►│   Data Buffer   │
│    (SGLang)     │    │  (Async Batch)   │    │                 │
└─────────────────┘    └──────────────────┘    └─────────────────┘
         ▲                                              │
         │                                              ▼
┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   Checkpoints   │◄───│   FSDP Trainer   │◄───│  Training Loop  │
│                 │    │   (Sync Batch)   │    │                 │
└─────────────────┘    └──────────────────┘    └─────────────────┘
```
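As a toy illustration of this pattern (not AReaLite's actual `RolloutController` code; all names below are stand-ins), decoupled rollout workers can feed a bounded queue that the training loop drains in batches:

```python
# Toy sketch of the async producer-consumer pattern shown above. The stubs
# collect_episode/train_step stand in for real SGLang rollouts and FSDP updates;
# none of these names are actual AReaLite APIs.
import asyncio
import random


async def collect_episode(worker_id: int) -> dict:
    """Stub rollout: in AReaLite this would be async generation against SGLang."""
    await asyncio.sleep(random.random() * 0.1)  # episodes finish at different times
    return {"worker": worker_id, "reward": random.random()}


def train_step(batch: list) -> None:
    """Stub synchronous gradient update on a full batch."""
    mean_r = sum(x["reward"] for x in batch) / len(batch)
    print(f"train step on {len(batch)} episodes, mean reward {mean_r:.3f}")


async def rollout_worker(worker_id: int, buffer: asyncio.Queue) -> None:
    """Producer: keeps pushing finished episodes into the shared buffer."""
    while True:
        await buffer.put(await collect_episode(worker_id))


async def main(num_workers: int = 4, batch_size: int = 8, steps: int = 3) -> None:
    buffer: asyncio.Queue = asyncio.Queue(maxsize=64)  # bounds data staleness
    workers = [asyncio.create_task(rollout_worker(i, buffer)) for i in range(num_workers)]
    for _ in range(steps):  # consumer: the training loop drains batches synchronously
        train_step([await buffer.get() for _ in range(batch_size)])
    for w in workers:
        w.cancel()


asyncio.run(main())
```

The bounded queue is what decouples inference from training: rollout workers never block the trainer, and the `maxsize` cap limits how stale collected trajectories can become.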
### Key Design Principles

#### 1. **AI-Centric API Design**

Unlike the original AReaL's system-centric approach with workers and model functions, AReaLite uses familiar ML concepts:
- `Agent` and `Environment` (from the RL literature)
- `RolloutCollector` (combines multiple agents and the environment to generate rollout data)
- `Trainer` (from HuggingFace/PyTorch; fetches rollout data and updates model parameters)
#### 2. **Factory Pattern for Extensibility**

Each major component uses a factory pattern for easy customization:
- `EngineFactory` creates training backends
- `TrainerFactory` creates training algorithms
- `RolloutCollectorFactory` creates rollout collectors

#### 3. **Configuration-Driven Architecture**

All components are configured through dataclasses defined in `cli_args.py` (see the sketch after this list), enabling:
- Type-safe configuration
- Easy CLI argument generation
- Clear documentation of available options
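For illustration, here is roughly how a typed config dataclass can drive factory dispatch. The concrete names (`GRPOConfig`, `TrainerConfig.type`, the field defaults) are assumptions modeled on the customization guide below, not verbatim `cli_args.py` contents.

```python
# Hypothetical sketch: a typed config dataclass selects the implementation in a factory.
from dataclasses import dataclass, field


@dataclass
class GRPOConfig:
    actor_path: str = "Qwen/Qwen2-0.5B"
    learning_rate: float = 1e-5


@dataclass
class TrainerConfig:
    type: str = "grpo"  # tag that tells the factory which trainer to build
    grpo: GRPOConfig = field(default_factory=GRPOConfig)


class GRPOTrainer:
    def __init__(self, config: GRPOConfig):
        self.config = config

    def train(self, resume_from_checkpoint=None):
        print(f"GRPO on {self.config.actor_path} @ lr={self.config.learning_rate}")


class TrainerFactory:
    def make_trainer(self, config: TrainerConfig):
        # Dispatch on the type tag; extending the framework means adding a branch.
        if config.type == "grpo":
            return GRPOTrainer(config.grpo)
        raise ValueError(f"Unknown trainer type: {config.type}")


trainer = TrainerFactory().make_trainer(TrainerConfig())
trainer.train()
```

Because the config is a plain dataclass, CLI flags such as `trainer.type=grpo` in the usage examples below can be generated mechanically from its fields.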
## Implementation Details

### Training Pipeline

1. **Initialization**: Factory classes create configured instances of engines, trainers, and rollout collectors
2. **Rollout Phase**: `RolloutController` coordinates async data collection across multiple `RolloutWorker` instances
3. **Training Phase**: `Trainer` performs synchronous gradient updates using the collected data
4. **Weight Updates**: Updated model weights are pushed to the LLM servers via `update_weights_to()` (the full loop is sketched below)
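Put together, one iteration of the pipeline reduces to the loop below. Every class here is a stand-in stub for the components named in steps 1-4; only `update_weights_to()` is taken from this doc.

```python
# Stub reconstruction of the training pipeline above; not AReaLite's actual code.

class StubRolloutController:
    def get_batch(self):
        # Step 2: async workers have been filling a buffer; return one full batch.
        return [{"prompt": "1+1=?", "completion": "2", "reward": 1.0}]

class StubTrainer:
    def train_step(self, batch):
        # Step 3: synchronous FSDP gradient update over the batch.
        return {"loss": 0.0, "n": len(batch)}

class StubEngine:
    def update_weights_to(self, servers):
        # Step 4: push fresh weights so new rollouts use the latest policy.
        print(f"updated weights on {servers}")

# Step 1: factories would build these from the experiment config.
controller, trainer, engine = StubRolloutController(), StubTrainer(), StubEngine()
for step in range(2):
    batch = controller.get_batch()
    print(trainer.train_step(batch))
    engine.update_weights_to(["sglang-server-0"])
```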
### Rollout System

The rollout system supports arbitrary agentic rollout paradigms, implemented as `RolloutCollector` instances. `RolloutCollector` exposes a `run_episode` method in which users implement the logic for collecting one complete agentic trajectory. Users can first implement gymnasium-compatible `Agent` and `Environment` interfaces and combine them into a collector, as in the standard RL setup (see `arealite/impl/agentic/`), or they can implement the collector directly when the agent-environment split does not fit the use case (see `arealite/impl/rlvr/`). A sketch of the first style follows.
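For example, a gymnasium-style collector could look like the following sketch. `MathEnv` and `EchoAgent` are toys invented for illustration; only the `arun_episode` entry point mirrors the interface used elsewhere in this doc.

```python
# Hypothetical gymnasium-style collector combining an Agent and an Environment.
import asyncio


class MathEnv:
    """Toy environment: poses a question and rewards the correct answer."""

    def reset(self, seed=None):
        self.answer = "4"
        return "What is 2 + 2?"  # initial observation (the prompt)

    def step(self, action: str):
        reward = 1.0 if action.strip() == self.answer else 0.0
        return None, reward, True, {}  # obs, reward, done, info


class EchoAgent:
    """Toy agent; in AReaLite this would query an SGLang server asynchronously."""

    async def act(self, obs: str) -> str:
        await asyncio.sleep(0)  # stands in for an async LLM generation call
        return "4"


class MyCollector:
    def __init__(self, agent, env):
        self.agent, self.env = agent, env

    async def arun_episode(self, gconfig=None, env_option=None, seed=None):
        obs = self.env.reset(seed=seed)
        trajectory, done = [], False
        while not done:
            action = await self.agent.act(obs)
            obs, reward, done, _ = self.env.step(action)
            trajectory.append((action, reward))
        return trajectory


print(asyncio.run(MyCollector(EchoAgent(), MathEnv()).arun_episode()))
```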
## Expected Usage

### Basic RL Training
```bash
python3 -m arealite.cli.main \
    experiment_name=my-exp trial_name=my-trial \
    trainer.type=grpo \
    trainer.grpo.actor.path=Qwen/Qwen2-0.5B
```

### Rollout-Only Evaluation
```bash
python3 -m arealite.cli.main \
    trainer.type=null \
    valid_dataset.path=huggingface/dataset
```

### Distributed Training
```bash
python3 -m arealite.cli.main \
    mode=ray \
    allocation_mode=sglang.d16p1m1+d32p2m1 \
    trainer.type=grpo
```

## Customization Guide

### Adding New Trainers
1. **Implement the trainer class** in `impl/trainer/`:
```python
from arealite.api.trainer_api import Trainer


class MyTrainer(Trainer):
    def train(self, resume_from_checkpoint=None):
        # Implementation here
        pass
```
2. **Add a configuration dataclass** in `cli_args.py`:
```python
from dataclasses import dataclass


@dataclass
class MyTrainerConfig:
    learning_rate: float = 1e-4
```
3. **Register it in the factory** in `trainer_api.py`:
```python
# Inside TrainerFactory:
def make_trainer(self, config: TrainerConfig) -> Trainer:
    if config.type == "my_trainer":
        return MyTrainer(...)
```
### Adding New Rollout Collectors
1. **Implement the collector** in `impl/`:
```python
from arealite.api.rollout_api import RolloutCollector


class MyCollector(RolloutCollector):
    async def arun_episode(self, gconfig, env_option=None, seed=None):
        # Implementation here
        pass
```
2. **Register it in the factory** in `rollout_api.py`:
```python
# Inside RolloutCollectorFactory:
def make_collector(self, config: RolloutCollectorConfig):
    if config.type == "my_collector":
        return MyCollector(...)
```
> **Collaborator:** One thing production systems often lack is systematic support for reward models, verifiers, and generative reward models. From the RL point of view, these can be encapsulated in the concept of an environment, but they are cumbersome to integrate. Mentioning this a bit here might be good for those who are using RL in production.
>
> **Author:** Reward models can be implemented as additional …
## Roadmap

- [ ] Finalize API design. (In progress)
- [x] Implement standalone SGLang server (`impl/sglang_server.py`).
- [x] Implement SGLang client generation (`impl/sglang_client.py`).
- [x] Rollout pipeline (`tests/test_rollout.py`).
- [x] SGLang rollout interruption.
- [x] Asynchronous RL system-wide utilities (e.g., `RolloutController`).
- [ ] Various launching scripts: Ray, torchrun, SLURM.
- [ ] FSDP2 engine with transformers models. (In progress)
- [ ] SFT trainer. (In progress)
- [ ] SGLang update weights. (In progress)
- [x] GRPO trainer.
- [ ] Benchmarking against the original AReaL.
- [ ] CI and unit tests.
- [ ] Other RL algorithms (DPO, REINFORCE, etc.).
- [ ] Support for multi-modal models.
- [ ] User guide for transitioning from v0.3.0.
- [ ] Advanced agentic collectors (tool use, planning).
- [ ] Examples of training on GSM8K, TLDR, and a search agent.
- [ ] Allow external persistent SGLang servers for debugging purposes.
> **Reviewer:** Add a main entry point annotation somewhere?
>
> **Author:** Added a data flow arch graph in the code walk-through example: https://github.com/inclusionAI/AReaL/blob/lite/docs/arealite/gsm8k_grpo.md