[WIP] [Refactor] Add AReaLite API and examples. #125
Conversation
Commits:
- efficient loading
- format
(plus several commits with message ".")
    ┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
    │   Checkpoints   │◄───│   FSDP Trainer   │◄───│  Training Loop  │
    │                 │    │   (Sync Batch)   │    │                 │
    └─────────────────┘    └──────────────────┘    └─────────────────┘
Add a main entry point annotation somewhere?
Added a data flow arch graph in the code walk-through example: https://github.com/inclusionAI/AReaL/blob/lite/docs/arealite/gsm8k_grpo.md
    if config.type == "my_collector":
        return MyCollector(...)
One thing production systems often lack is systematic support for reward models, verifiers, and generative reward models. From the RL point of view these can be encapsulated in the concept of an environment, but it is cumbersome to integrate them. Addressing this a bit here might be good for those who are using RL in production.
Reward models can be implemented as additional TrainEngines, just like the reference model. Generative reward models are encapsulated in the RolloutWorkflow object (aka the previous RolloutCollector), e.g., the agent can call the same LLM with different prompts to judge whether the previously generated answer is correct.
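As a rough illustration of that design (the class and method names `JudgeRewardWorkflow`, `client.agenerate`, and `arun_episode` are hypothetical, not the actual AReaLite API), a generative reward model can live inside the rollout workflow and reuse the same LLM with a judging prompt:

```python
class JudgeRewardWorkflow:
    """Hypothetical sketch: a generative reward model inside a rollout workflow."""

    def __init__(self, client, judge_prompt_template: str):
        self.client = client  # the same LLM client used for rollout generation
        self.judge_prompt_template = judge_prompt_template

    async def arun_episode(self, prompt: str, reference_answer: str) -> dict:
        # Generate the answer with the policy prompt.
        answer = await self.client.agenerate(prompt)
        # Reuse the same LLM with a judge prompt to score the answer.
        judge_prompt = self.judge_prompt_template.format(
            question=prompt, answer=answer, reference=reference_answer
        )
        verdict = await self.client.agenerate(judge_prompt)
        reward = 1.0 if "correct" in verdict.lower() else 0.0
        return {"prompt": prompt, "answer": answer, "reward": reward}
```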
    def __str__(self):
        """Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'."""
        return (
            f"Parallel(mp={self.tensor_parallel_size},"
mp -> tp ?
No "mp" any more.
    c_clip: Optional[float] = field(
        default=None,
        metadata={
            "help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
Is there any data validation, similar to Pydantic, to check the range of parameters here?
Good idea. I'll mark that.
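A minimal sketch of what such a check could look like with a plain dataclass `__post_init__` (the `PPOConfig` class name is hypothetical; only the `c_clip` field from the quoted snippet is shown):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class PPOConfig:
    c_clip: Optional[float] = field(
        default=None,
        metadata={
            "help": "Dual clipping factor for policy ratio, must be > 1.0. "
                    "None disables dual clipping."
        },
    )

    def __post_init__(self):
        # Range validation without Pydantic: reject invalid values at construction time.
        if self.c_clip is not None and self.c_clip <= 1.0:
            raise ValueError(f"c_clip must be > 1.0 or None, got {self.c_clip}")
```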
    # Licensed under the Apache License, Version 2.0

    import abc
    import asyncio
I saw there is uvloop as the async event loop manager. Is it actually used in the project?
Changed to uvloop.run wherever asyncio.run was used.
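For reference, a small sketch of that pattern with a placeholder coroutine: `uvloop.run()` is a drop-in for `asyncio.run()` on recent uvloop versions, with a fallback to installing the uvloop event-loop policy on older ones.

```python
import asyncio
import uvloop

async def main():
    await asyncio.sleep(0)  # placeholder coroutine

if hasattr(uvloop, "run"):
    # Recent uvloop versions provide uvloop.run() as a drop-in for asyncio.run().
    uvloop.run(main())
else:
    # Older versions: install uvloop's event loop policy, then use asyncio.run().
    uvloop.install()
    asyncio.run(main())
```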
    # Cleanup registry
    try:
        self.registry.unregister_server(self.server_id)
    except Exception as e:
The catch seems a bit too wide
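One possible narrowing, sketched as a method body; the concrete exception types are an assumption, since they depend on what `unregister_server` can actually raise:

```python
import logging

logger = logging.getLogger(__name__)

def cleanup(self):
    # Illustrative only: catch the failures we expect during unregistration
    # (e.g. a missing entry or a lost connection) instead of every Exception.
    try:
        self.registry.unregister_server(self.server_id)
    except (KeyError, ConnectionError) as e:
        logger.warning("Failed to unregister server %s: %s", self.server_id, e)
```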
    return Trajectory(
        prompt=env_option,
        data=dict(rewards=torch.tensor(rewards), **pad_sequences_to_tensors(data)),
Why unpack **pad_sequences_to_tensors(data) inline here instead of returning data=data_dict? TensorDict is not that bad, though.
It now uses TensorDict as the basic data structure.
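A small sketch of what TensorDict-based rollout data could look like (the field names and shapes are illustrative, not the exact AReaLite schema):

```python
import torch
from tensordict import TensorDict

# Build a batch of rollout data where every field shares the same batch dimension.
batch = TensorDict(
    {
        "input_ids": torch.zeros(4, 128, dtype=torch.long),
        "attention_mask": torch.ones(4, 128, dtype=torch.bool),
        "rewards": torch.zeros(4),
    },
    batch_size=[4],
)

# TensorDict applies batched ops (indexing, device moves, stacking) to all fields at once.
first_two = batch[:2]
moved = batch.to("cpu")  # or "cuda" when available
```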
    # Run batched rollout by submitting requests to LLM servers
    trajs = self.rollout_controller.generate_batch(
        batch_size=len(data),
        env_options=data,
This is good; the split between prepare_batch and generate_batch is very clear.
    base_model_path=self.config.actor.path,
    )

    assert len(mb_stats) == self.config.ppo_n_minibatches
Throw some hints here such as f"Micro batch mismatch, current {mb_stats} and config {self.config.ppo_n_minibatches}. Check your configuration. "
haha, why is this empty?
It's not empty in the new lite branch. :)
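For reference, the assertion message suggested above could look roughly like this (attribute names follow the quoted snippet; the surrounding trainer code is omitted):

```python
        assert len(mb_stats) == self.config.ppo_n_minibatches, (
            f"Micro-batch count mismatch: got {len(mb_stats)} stats entries but "
            f"ppo_n_minibatches={self.config.ppo_n_minibatches}. Check your configuration."
        )
```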
Commits:
- ci: add test-arealite
- ci: add checkout before running test-arealite
- ci: add USERNAME
- ci: add test script
- ci: add GitHub mirror
- ci: fix typo
- ci: clone one commit
- ci: fix condition
- ci: set command timeout to 60m
- ci: enable pip cache
- ci: optimize container lifecycle
- ci: split into many stages
- ci(test-arealite): fix typo
- ci: fix wrong env
- ci: fix pytest
- ci: uninstall transformer-engine
- ci: uninstall transformer-engine
- ci: fix model paths
- ci: show stdout/stderr
- ci: fix not clean up
- ci: backup sglang
- ci: remove tmp repo dir when run
- ci: fix docker run exit 1 condition
- ci(test-arealite): limit the concurrency and extend command timeout
@tsaoyu Hi Yu, thank you for the thorough review! We're currently finalizing some internal discussions about the API design, and the final implementation will differ slightly from what's currently proposed. We'd like to defer addressing these specific issues for a few days while we complete those discussions. Your feedback is valuable and we'll incorporate it into our revised implementation. We'd appreciate having you review again once the final version is ready. Thanks for your patience!
Closed since #154 has been merged.
Motivation
AReaL is too heavy for AI researchers to use, understand, and develop with, for several reasons. The most important issue is that its code architecture is system-centric rather than AI-centric — the RL algorithm workflow consists of multiple workers that run consecutive model function calls, neither of which is a well-known concept for AI researchers. As a result, users must first understand these concepts before they can develop workflows and algorithms for their own use cases.
Additionally, due to historical reasons, AReaL's code is not clean. There are large pieces of code inherited from previous projects that are not useful but significantly increase the burden on users and developers. Sometimes debugging is difficult even for core developers like myself.
Since the tools for building RL workflows are becoming increasingly mature, implementing a framework with comparable efficiency requires far fewer lines of code. Now is the proper time to revisit the API design and distill the giant codebase into a neat and clean one. The distilled codebase does not need to be ultra-efficient; instead, we want to deliver 90% of the original AReaL's functionality while minimizing the lines of code and the burden on potential users. Our aim is to build an RL training framework that is fast to use, fast to read, and fast to execute. Here comes the lite version of AReaL — AReaLite.
AReaLite is the first step in AReaL's refactoring process. It is not only a standalone training library with shallow interfaces, but will also provide the core API definitions to be used by AReaL in the future. AReaL will essentially transform its current worker-based architecture into an AI-centric architecture like AReaLite. AReaL will extend AReaLite's APIs and implementations to support more backends for efficient large-scale training.
Expectations of AReaLite
Highlights
AReaLite's Scope
- No SequenceSample. All data are PyTorch tensors.

Architecture
Core Components
Data Flow Architecture
AReaLite uses an async producer-consumer pattern: rollout workflows asynchronously produce trajectories, which are buffered and consumed in batches by the trainer.
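A minimal, illustrative sketch of that pattern with `asyncio` queues (the trajectory contents and batch handling are placeholders, not the AReaLite implementation):

```python
import asyncio

async def producer(queue: asyncio.Queue):
    # Stand-in for rollout workflows pushing finished trajectories.
    for i in range(8):
        trajectory = {"prompt_id": i, "reward": 1.0}
        await queue.put(trajectory)
    await queue.put(None)  # sentinel: no more data

async def consumer(queue: asyncio.Queue, batch_size: int = 4):
    # Stand-in for the trainer consuming buffered trajectories in batches.
    batch = []
    while True:
        item = await queue.get()
        if item is None:
            break
        batch.append(item)
        if len(batch) == batch_size:
            print(f"train step on {len(batch)} trajectories")
            batch.clear()

async def main():
    queue: asyncio.Queue = asyncio.Queue(maxsize=16)
    await asyncio.gather(producer(queue), consumer(queue))

asyncio.run(main())
```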
Key Design Principles
1. AI-Centric API Design
Unlike the original AReaL's system-centric approach with workers and model functions, AReaLite uses familiar ML concepts:
- `Agent` and `Environment` (from RL literature)
- `RolloutWorkflow` (combines multiple agents and the environment to generate rollout data)
- `Trainer` (from HuggingFace/PyTorch, fetches rollout data and updates model parameters)
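A rough sketch of how these concepts can map onto interfaces (the method signatures are assumptions for illustration, not the actual AReaLite definitions):

```python
import abc

class RolloutWorkflow(abc.ABC):
    @abc.abstractmethod
    async def arun_episode(self, prompt) -> dict:
        """Run the agent(s) against the environment and return one trajectory."""

class Trainer(abc.ABC):
    @abc.abstractmethod
    def train_batch(self, batch) -> dict:
        """Consume a batch of rollout data and update model parameters."""
```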
2. Factory Pattern for Extensibility
Each major component uses a factory pattern for easy customization (a minimal sketch follows the list below):
- `EngineFactory` creates training backends
- `TrainerFactory` creates training algorithms
- `RolloutWorkflowFactory` creates rollout workflows
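A self-contained sketch of this factory-style dispatch, echoing the `config.type == "my_collector"` snippet quoted earlier in the review (the config layout and workflow class are placeholders):

```python
from dataclasses import dataclass

@dataclass
class WorkflowConfig:
    type: str = "my_collector"  # placeholder type tag

class MyCollector:
    """Placeholder for a user-defined rollout workflow."""
    def __init__(self, config: WorkflowConfig):
        self.config = config

def make_rollout_workflow(config: WorkflowConfig):
    # Dispatch on the config's type tag so users can register custom workflows.
    if config.type == "my_collector":
        return MyCollector(config)
    raise ValueError(f"Unknown rollout workflow type: {config.type}")

workflow = make_rollout_workflow(WorkflowConfig())
```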
3. Configuration-Driven Architecture
All components are configured through dataclasses defined in cli_args.py.
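A sketch of this dataclass-driven configuration style. The field names `actor.path`, `ppo_n_minibatches`, and `c_clip` appear elsewhere in this PR, but the class names, defaults, and exact nesting here are assumptions; the real definitions live in cli_args.py.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ActorConfig:
    path: str = "/path/to/base-model"  # placeholder model path

@dataclass
class TrainerConfig:
    actor: ActorConfig = field(default_factory=ActorConfig)
    ppo_n_minibatches: int = 4
    c_clip: Optional[float] = None  # dual clipping disabled by default
```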
Roadmap
- SGLang server (impl/sglang_server.py)
- SGLang client (impl/sglang_client.py)
- Rollout tests (tests/test_rollout.py)
- Rollout controller (RolloutController)