Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni

## 🚀 News

* [2025-11] Introducing [BOTS](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots): dynamic RL task selection for efficient LLM fine-tuning ([paper](https://arxiv.org/pdf/2510.26374)).
* [2025-10] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.2)] Trinity-RFT v0.3.2 released: bug fixes and advanced task selection & scheduling.
* [2025-10] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.1)] Trinity-RFT v0.3.1 released: multi-stage training support, improved agentic RL examples, LoRA support, debug mode and new RL algorithms.
* [2025-09] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.0)] Trinity-RFT v0.3.0 released: enhanced Buffer, FSDP2 & Megatron support, multi-modal models, and new RL algorithms/examples.
Expand Down
69 changes: 69 additions & 0 deletions examples/bots/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# 🤖🤖🤖 BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning

<p align="center">
<a href="https://arxiv.org/abs/2510.26374">
<img alt="Paper" src="https://img.shields.io/badge/Paper-arXiv%3A2510.26374-b31b1b?style=flat&logo=arxiv">
</a>
</p>

### Overview

<img src="https://gw.alicdn.com/imgextra/i2/O1CN01MO34b71y4VQnD3WRp_!!6000000006525-2-tps-1247-567.png" alt="Agentic workflows" width="700" />

BOTS operates in a continuous loop of task selection, model training, and posterior updating.
(1) **Selection**: Thompson sampling from the posterior beliefs selects a batch of tasks whose estimated success probabilities are near a target difficulty (e.g., $p^*=0.5$).
(2) **Training \& Evidence Collection**: The LLM is finetuned, yielding direct success/failure counts (_explicit evidence_) for the selected batch.
For unselected tasks, predicted counts (_implicit evidence_) are produced by a plug-in; We introduce an ultra-lightweight interpolation-based variant with negligible overhead.
(3) **Posterior Updating**: Explicit and implicit evidence are fused using our generalized Bayesian update rule.

### Usage

##### Step 1: Environment Preparation

Ensure Trinity-RFT is well installed ([Installation Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)). No extra dependence is required.

##### Step 2: Prepare Model & Dataset

Download the model your want to train (e.g. [Qwen2.5-1.5B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct)).

Download the [GURU](https://huggingface.co/datasets/LLM360/guru-RL-92k) dataset.
Also refer to the [Data Preparation Guide](https://github.com/LLM360/Reasoning360?tab=readme-ov-file#data-preparation) and the [Tech Report](https://www.arxiv.org/pdf/2506.14965) provided by the LLM360 team.

Remember to modify the model/data path in `bots.yaml` and `random.yaml` accordingly.

##### Step 3: Training
Launch training by executing:
```bash
trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/plugins
```
The improvement over random selection baseline can be stably obtained 🤖🤖🤖.

<img src="https://img.alicdn.com/imgextra/i3/O1CN016wQqpG1wFq00KWzV7_!!6000000006279-2-tps-1894-1066.png" alt="Agentic workflows" width="700" />

### Complete Reproduction

For complete reproduction of the results in our paper, please use the verl version implementation available [here](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip).

### Citation
If you find the repo helpful, please cite:
```
@misc{TrinityRFT,
title={Trinity-RFT: A General-Purpose and Unified Framework for Reinforcement Fine-Tuning of Large Language Models},
author={Xuchen Pan and Yanxi Chen and Yushuo Chen and Yuchang Sun and Daoyuan Chen and Wenhao Zhang and Yuexiang Xie and Yilun Huang and Yilei Zhang and Dawei Gao and Weijie Shi and Yaliang Li and Bolin Ding and Jingren Zhou},
year={2025},
eprint={2505.17826},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2505.17826},
}
@misc{BOTS,
title={BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning},
author={Qianli Shen and Daoyuan Chen and Yilun Huang and Zhenqing Ling and Yaliang Li and Bolin Ding and Jingren Zhou},
year={2025},
eprint={2510.26374},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2510.26374},
}
```
79 changes: 79 additions & 0 deletions examples/bots/bots.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
project: "BOTS-Selector"
name: "qwen2.5-1.5B-instruct-bots"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
data_processor:
experience_pipeline:
operators:
- name: pass_rate_calculator
algorithm:
algorithm_type: grpo
repeat_times: 16
optimizer:
lr: 1e-6
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_prompt_tokens: 4096
max_response_tokens: 8192
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 32
explorer_input:
taskset:
name: math-train
storage_type: file
path: '<DATA_ROOT>/LLM360/guru-RL-92k/train/math__combined_54.4k.parquet'
split: 'train'
format:
prompt_key: 'prompt'
response_key: 'reward_model.ground_truth'
rollout_args:
temperature: 1.0
task_selector:
selector_type: difficulty_based
feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ]
kwargs:
m: 16
lamb: 0.1
rho: 0.1
target_reward: 0.5
tau: 0
do_sample: true
eval_tasksets:
- name: math-eval
storage_type: file
path: '<DATA_ROOT>/LLM360/guru-RL-92k/online_eval/math__math_500.parquet'
format:
prompt_key: 'prompt'
response_key: 'reward_model.ground_truth'
rollout_args:
temperature: 1.0
default_workflow_type: 'bots_math_boxed_workflow'
trainer_input:
experience_buffer:
name: exp_buffer
storage_type: queue
path: 'sqlite:///bots_trainer_buffer.db'
explorer:
eval_interval: 40
runner_per_model: 8
rollout_model:
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 8
sync_timeout: 1200
trainer:
trainer_type: 'verl'
save_interval: 800
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 24576
ulysses_sequence_parallel_size: 1
33 changes: 33 additions & 0 deletions examples/bots/plugins/bots_math_boxed_reward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional

from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
from trinity.utils.eval_utils import validate_think_pattern

from .bots_reward import compute_score


@REWARD_FUNCTIONS.register_module("bots_math_boxed_reward")
class BOTSMathBoxedRewardFn(RewardFn):
"""A reward function that rewards for math task for BOTS."""

def __init__(
self,
**kwargs,
) -> None:
pass

def __call__( # type: ignore
self,
response: str,
truth: Optional[str] = None,
with_think: Optional[bool] = False,
format_score_coef: Optional[float] = 0.1,
**kwargs,
) -> dict[str, float]:
accuracy_score = compute_score(response, truth)

format_score = 0.0
if with_think and not validate_think_pattern(response):
format_score = (format_score_coef or 0.1) * -1.0

return {"accuracy": accuracy_score, "format_score": format_score}
17 changes: 17 additions & 0 deletions examples/bots/plugins/bots_math_boxed_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, Task
from trinity.common.workflows.workflow import WORKFLOWS

from .bots_math_boxed_reward import BOTSMathBoxedRewardFn


@WORKFLOWS.register_module("bots_math_boxed_workflow")
class BOTSMathBoxedWorkflow(MathBoxedWorkflow):
"""A workflow for math tasks that give answers in boxed format for BOTS."""

def reset(self, task: Task):
super().reset(task)
self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args)

def format_messages(self):
# the prompts are already in message format
return self.task_desc
Loading