From ad747e088837b86447b83360152d8526f36f217e Mon Sep 17 00:00:00 2001 From: Shen Qianli Date: Thu, 30 Oct 2025 16:09:53 +0800 Subject: [PATCH 01/12] examples/bots init commit --- examples/bots/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 examples/bots/README.md diff --git a/examples/bots/README.md b/examples/bots/README.md new file mode 100644 index 0000000000..0295c9b0dc --- /dev/null +++ b/examples/bots/README.md @@ -0,0 +1,13 @@ +# 🤖🤖🤖 BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning + +

+ + Paper + +

+ +### Repository Status + +This repository hosts the upcoming Trinity version of our code, which is still under development and not yet released. + +For complete reproduction of the results in our paper, please use the verl version available [here](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip). \ No newline at end of file From e3413d5e9dfb711b5c642a40457e4d56ece662ef Mon Sep 17 00:00:00 2001 From: Shen Qianli Date: Fri, 31 Oct 2025 10:11:35 +0800 Subject: [PATCH 02/12] examples/bots update README --- examples/bots/README.md | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/bots/README.md b/examples/bots/README.md index 0295c9b0dc..00dce0768c 100644 --- a/examples/bots/README.md +++ b/examples/bots/README.md @@ -1,8 +1,8 @@ # 🤖🤖🤖 BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning

- - Paper + + Paper

@@ -10,4 +10,18 @@ This repository hosts the upcoming Trinity version of our code, which is still under development and not yet released. -For complete reproduction of the results in our paper, please use the verl version available [here](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip). \ No newline at end of file +For complete reproduction of the results in our paper, please use the verl version available [here](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip). + +### Citation +If you find the repo helpful, please cite: +``` +@misc{shen2025botsunifiedframeworkbayesian, + title={BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning}, + author={Qianli Shen and Daoyuan Chen and Yilun Huang and Zhenqing Ling and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2510.26374}, + archivePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2510.26374}, +} +``` \ No newline at end of file From be47df79f583a4c42f8951de0d2b3ebccf9e3197 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=83=88=E9=9C=96?= Date: Wed, 5 Nov 2025 15:05:28 +0800 Subject: [PATCH 03/12] + add examples for bots --- examples/bots/bots.yaml | 79 ++ .../bots/plugins/bots_math_boxed_reward.py | 32 + .../bots/plugins/bots_math_boxed_workflow.py | 16 + examples/bots/plugins/bots_reward.py | 892 ++++++++++++++++++ examples/bots/random.yaml | 67 ++ .../operators/mappers/pass_rate_calculator.py | 29 +- trinity/buffer/task_scheduler.py | 2 +- trinity/common/experience.py | 2 +- trinity/common/workflows/workflow.py | 21 +- 9 files changed, 1135 insertions(+), 5 deletions(-) create mode 100644 examples/bots/bots.yaml create mode 100644 examples/bots/plugins/bots_math_boxed_reward.py create mode 100644 examples/bots/plugins/bots_math_boxed_workflow.py create mode 100644 examples/bots/plugins/bots_reward.py create mode 100644 examples/bots/random.yaml diff --git a/examples/bots/bots.yaml b/examples/bots/bots.yaml new file mode 100644 index 0000000000..a5bd722de9 --- /dev/null +++ b/examples/bots/bots.yaml @@ -0,0 +1,79 @@ +project: "BOTS-Selector" +name: "qwen2.5-1.5B-instruct-bots" +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +data_processor: + experience_pipeline: + operators: + - name: pass_rate_calculator +algorithm: + algorithm_type: grpo + repeat_times: 16 + optimizer: + lr: 1e-6 +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct} + max_prompt_tokens: 4096 + max_response_tokens: 8192 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 32 + explorer_input: + taskset: + name: math-train + storage_type: file + path: '/LLM360/guru-RL-92k/train/math__combined_54.4k.parquet' + split: 'train' + format: + prompt_key: 'prompt' + response_key: 'reward_model.ground_truth' + rollout_args: + temperature: 1.0 + task_selector: + selector_type: difficulty_based + feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ] + kwargs: + m: 16 + lamb: 0.1 + rho: 0.1 + target_reward: 0.5 + tau: 0 + do_sample: true + eval_tasksets: + - name: math-eval + storage_type: file + path: '/LLM360/guru-RL-92k/online_eval/math__math_500.parquet' + format: + prompt_key: 'prompt' + response_key: 'reward_model.ground_truth' + rollout_args: + temperature: 1.0 + default_workflow_type: 'bots_math_boxed_workflow' + trainer_input: + experience_buffer: + name: exp_buffer + storage_type: queue + path: 'sqlite:///bots_trainer_buffer.db' +explorer: + 
eval_interval: 40 + runner_per_model: 8 + rollout_model: + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 8 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + save_interval: 800 + grad_clip: 1.0 + use_dynamic_bsz: true + max_token_len_per_gpu: 24576 + ulysses_sequence_parallel_size: 1 \ No newline at end of file diff --git a/examples/bots/plugins/bots_math_boxed_reward.py b/examples/bots/plugins/bots_math_boxed_reward.py new file mode 100644 index 0000000000..e5b70ca637 --- /dev/null +++ b/examples/bots/plugins/bots_math_boxed_reward.py @@ -0,0 +1,32 @@ +from typing import Optional + +from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.utils.eval_utils import validate_think_pattern + +from .bots_reward import compute_score + +@REWARD_FUNCTIONS.register_module("bots_math_boxed_reward") +class BOTSMathBoxedRewardFn(RewardFn): + """A reward function that rewards for math task for BOTS.""" + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def __call__( # type: ignore + self, + response: str, + truth: Optional[str] = None, + with_think: Optional[bool] = False, + format_score_coef: Optional[float] = 0.1, + **kwargs, + ) -> dict[str, float]: + accuracy_score = compute_score(response, truth) + + format_score = 0.0 + if with_think and not validate_think_pattern(response): + format_score = (format_score_coef or 0.1) * -1.0 + + return {"accuracy": accuracy_score, "format_score": format_score} \ No newline at end of file diff --git a/examples/bots/plugins/bots_math_boxed_workflow.py b/examples/bots/plugins/bots_math_boxed_workflow.py new file mode 100644 index 0000000000..8c08d01874 --- /dev/null +++ b/examples/bots/plugins/bots_math_boxed_workflow.py @@ -0,0 +1,16 @@ +from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, Task +from trinity.common.workflows.workflow import WORKFLOWS + +from .bots_math_boxed_reward import BOTSMathBoxedRewardFn + +@WORKFLOWS.register_module("bots_math_boxed_workflow") +class BOTSMathBoxedWorkflow(MathBoxedWorkflow): + """A workflow for math tasks that give answers in boxed format for BOTS.""" + + def reset(self, task: Task): + super().reset(task) + self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args) + + def format_messages(self): + # the prompts are already in message format + return self.task_desc diff --git a/examples/bots/plugins/bots_reward.py b/examples/bots/plugins/bots_reward.py new file mode 100644 index 0000000000..6a7be7e692 --- /dev/null +++ b/examples/bots/plugins/bots_reward.py @@ -0,0 +1,892 @@ +# Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py + +import re +import signal +from typing import Optional, Union +import math +from math import isclose +import contextlib + +import sympy +from pylatexenc import latex2text +from sympy.parsing import sympy_parser +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr +from sympy import N, simplify +import os + +from verl.utils.py_functional import timeout_limit + +def handle_base(x) -> str: + if isinstance(x, str) and "_" in x: + # Due to base + x = x.split("_")[0] + x = float(x) + return int(x) + return x + + +def handle_pi(string, pi): + if isinstance(string, str) and "\\pi" in string: + # Find the first occurrence of "\pi" + idx = string.find("\\pi") + + # Iterate over the string and find all 
occurrences of "\pi" with a valid previous character + while idx != -1: + if idx > 0 and string[idx - 1].isdigit(): + # Replace "\pi" with "*math.pi" if the previous character is a digit + string = string[:idx] + f"*{pi}" + string[idx + 3 :] + else: + # Replace "\pi" with "1*math.pi" if the previous character is not a digit + string = string[:idx] + f"1*{pi}" + string[idx + 3 :] + + # Find the next occurrence of "\pi" + idx = string.find("\\pi", idx + 1) + + # Evaluate the expression using eval() function + with contextlib.suppress(Exception): + string = eval(string) + + return string + +def normalize(answer, pi) -> str: + # checking if answer is $ and removing $ in that case to compare + if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)): + return answer[1:] + + # checking if answer is % or \\% and removing % + if isinstance(answer, str) and (bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))): + return answer.replace("\\%", "").replace("%", "") + + # handle base + answer = handle_base(answer) + + # handle pi + answer = handle_pi(answer, pi) + + return answer + +def is_digit(s): + try: + if "{,}" in str(s): + num = float(str(s).replace("{,}", "")) + return True, num + + num = float(str(s).replace(",", "")) + return True, num + except ValueError: + return False, None + +def format_intervals(prediction): + patterns = { + "Interval(": r"^Interval\((.*)\)$", + "Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$", + "Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$", + "Interval.open(": r"^Interval\.open\((.*)\)$", + } + + for key, pattern in patterns.items(): + match = re.match(pattern, prediction) + if match: + inner_content = match.group(1) + + if key == "Interval(": # Intarval(a, b) == [a, b] + return f"[{inner_content}]" + elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b) + return f"[{inner_content})" + elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b] + return f"({inner_content}]" + elif key == "Interval.open(": # Intarval.open(a, b) == (a, b) + return f"({inner_content})" + + return prediction + + +def symbolic_equal(a, b, tolerance, timeout=10.0): + def _parse(s): + for f in [parse_expr, parse_latex]: + try: + with timeout_limit(seconds=timeout): + return f(s) + except TimeoutError: + print(f"Parsing timed out for {s}") + continue + except Exception: + continue + return s + + a = _parse(a) + b = _parse(b) + + try: + with timeout_limit(seconds=timeout): + if simplify(a - b) == 0: + return True + except TimeoutError: + print(f"Simplification timed out for {a} - {b}") + pass + except Exception: + pass + + try: + with timeout_limit(seconds=timeout): + if isclose(N(a), N(b), rel_tol=tolerance): + return True + except TimeoutError: + print(f"Numerical evaluation timed out for {a}, {b}") + pass + except Exception: + pass + return False + + +def math_equal( + prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + tolerance: float = 1e-4, + timeout: float = 10.0, + pi: float = math.pi, +) -> bool: + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + + prediction = normalize(prediction, pi) + reference = normalize(reference, pi) + + if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases + prediction = prediction[:1000] + + # 0. 
string comparison + if isinstance(prediction, str) and isinstance(reference, str): + if prediction.strip().lower() == reference.strip().lower(): + return True + if prediction.replace(" ", "") == reference.replace(" ", ""): + return True + + try: # 1. numerical equal + if is_digit(prediction)[0] and is_digit(reference)[0]: + prediction = is_digit(prediction)[1] + reference = is_digit(reference)[1] + # number questions + gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + for item in gt_result: + try: + if isclose(item, prediction, rel_tol=tolerance): + return True + except Exception: + continue + return False + except Exception: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + ## deal with [], (), {} + prediction = format_intervals(prediction) + + pred_str, ref_str = prediction, reference + if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): + pred_str = pred_str.strip("[]()") + ref_str = ref_str.strip("[]()") + for s in ["{", "}", "(", ")"]: + ref_str = ref_str.replace(s, "") + pred_str = pred_str.replace(s, "") + if pred_str == ref_str: + return True + + ## [a, b] vs. [c, d], return a==c and b==d + if prediction and reference and prediction[0] in "([" and prediction[-1] in ")]" and prediction[0] == reference[0] and prediction[-1] == reference[-1]: + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]): + return True + + if "," in prediction and "," in reference: + pred_parts = [item.strip() for item in prediction.split(",")] + ref_parts = [item.strip() for item in reference.split(",")] + + if len(pred_parts) == len(ref_parts): + return bool(all([math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) for i in range(len(pred_parts))])) + + # if we have point == tuple of values + if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": + pred_parts = prediction[prediction.find("(") + 1 : -1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]): + return True + + # if reference is a matrix + if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"): + try: + pred_matrix = parse_expr(prediction) + ref_matrix_items = reference.split()[1:-1:2] + if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]): + return True + except Exception: + pass + elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"): + if isinstance(eval(prediction), list): + try: + pred_matrix = eval(prediction) + # ref_matrix_items = reference.split()[1:-1:2] + ref_matrix_items = reference.lstrip("\\begin{pmatrix}").lstrip("\begin{pmatrix}").rstrip("\\end{pmatrix}").rstrip("\\end{pmatrix}") # noqa: B005 + ref_matrix_items = ref_matrix_items.split("\\") + ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] + if len(pred_matrix) == 
len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]): + return True + except Exception: + pass + + return symbolic_equal(prediction, reference, tolerance, timeout) + +class timeout: + + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. + + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +# sympy might hang -- we don't care about trying to be lenient in these cases +BAD_SUBSTRINGS = ["^{", "^("] +BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] +TUPLE_CHARS = "()[]" + + +def timeout(timeout_seconds: int = 8): + if os.name == "posix": + import signal + + def decorator(func): + + def handler(signum, frame): + raise TimeoutError("Operation timed out!") + + def wrapper(*args, **kwargs): + old_handler = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, handler) + signal.alarm(timeout_seconds) + + try: + return func(*args, **kwargs) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + return wrapper + + return decorator + else: + raise NotImplementedError(f"Unsupported 
OS: {os.name}") + + +def _sympy_parse(expr: str): + """Parses an expression with sympy.""" + py_expr = expr.replace("^", "**") + return sympy_parser.parse_expr( + py_expr, + transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), + ) + + +def _parse_latex(expr: str) -> str: + """Attempts to parse latex to an expression sympy can read.""" + expr = expr.replace("\\tfrac", "\\frac") + expr = expr.replace("\\dfrac", "\\frac") + expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. + expr = latex2text.LatexNodes2Text().latex_to_text(expr) + + # Replace the specific characters that this parser uses. + expr = expr.replace("√", "sqrt") + expr = expr.replace("π", "pi") + expr = expr.replace("∞", "inf") + expr = expr.replace("∪", "U") + expr = expr.replace("·", "*") + expr = expr.replace("×", "*") + + return expr.strip() + + +def _is_float(num: str) -> bool: + try: + float(num) + return True + except ValueError: + return False + + +def _is_int(x: float) -> bool: + try: + return abs(x - int(round(x))) <= 1e-7 + except: + return False + + +def _is_frac(expr: str) -> bool: + return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) + + +def _str_is_int(x: str) -> bool: + try: + x = _strip_properly_formatted_commas(x) + x = float(x) + return abs(x - int(round(x))) <= 1e-7 + except: + return False + + +def _str_to_int(x: str) -> bool: + x = x.replace(",", "") + x = float(x) + return int(x) + + +def _inject_implicit_mixed_number(step: str): + """ + Automatically make a mixed number evalable + e.g. 7 3/4 => 7+3/4 + """ + p1 = re.compile("([0-9]) +([0-9])") + step = p1.sub("\\1+\\2", step) ## implicit mults + return step + + +def _strip_properly_formatted_commas(expr: str): + # We want to be careful because we don't want to strip tuple commas + p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") + while True: + next_expr = p1.sub("\\1\\3\\4", expr) + if next_expr == expr: + break + expr = next_expr + return next_expr + + +def _normalize(expr: str) -> str: + """Normalize answer expressions.""" + if expr is None: + return None + + # Remove enclosing `\text{}`. + m = re.search("^\\\\text\{(?P.+?)\}$", expr) + if m is not None: + expr = m.group("text") + + expr = expr.replace("\\%", "%") + expr = expr.replace("\\$", "$") + expr = expr.replace("$", "") + expr = expr.replace("%", "") + expr = expr.replace(" or ", " , ") + expr = expr.replace(" and ", " , ") + + expr = expr.replace("million", "*10^6") + expr = expr.replace("billion", "*10^9") + expr = expr.replace("trillion", "*10^12") + + for unit in [ + "degree", + "cm", + "centimeter", + "meter", + "mile", + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + "foot", + "feet", + "inch", + "yard", + "liter", + ]: + expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) + expr = re.sub(f"\^ *\\\\circ", "", expr) + + if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": + expr = expr[1:-1] + + expr = re.sub(",\\\\! 
*", "", expr) + if _is_float(expr) and _is_int(float(expr)): + expr = str(int(round(float(expr)))) + if "\\" in expr: + try: + expr = _parse_latex(expr) + except: + pass + + # edge case with mixed numbers and negative signs + expr = re.sub("- *", "-", expr) + + expr = _inject_implicit_mixed_number(expr) + + # don't be case sensitive for text answers + expr = expr.lower() + + if _str_is_int(expr): + expr = str(_str_to_int(expr)) + + return expr + + +def count_unknown_letters_in_expr(expr: str): + expr = expr.replace("sqrt", "") + expr = expr.replace("frac", "") + letters_in_expr = set([x for x in expr if x.isalpha()]) + return len(letters_in_expr) + + +def should_allow_eval(expr: str): + # we don't want to try parsing unknown text or functions of more than two variables + if count_unknown_letters_in_expr(expr) > 2: + return False + + for bad_string in BAD_SUBSTRINGS: + if bad_string in expr: + return False + + for bad_regex in BAD_REGEXES: + if re.search(bad_regex, expr) is not None: + return False + + return True + + +# @timeout(timeout_seconds=10) +def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): + are_equal = False + try: + expr = f"({ground_truth_normalized})-({given_normalized})" + if should_allow_eval(expr): + sympy_diff = _sympy_parse(expr) + simplified = sympy.simplify(sympy_diff) + if simplified == 0: + are_equal = True + except: + pass + return are_equal + + +def split_tuple(expr: str): + """ + Split the elements in a tuple/interval, while handling well-formatted commas in large numbers + """ + expr = _strip_properly_formatted_commas(expr) + if len(expr) == 0: + return [] + if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and + all([ch not in expr[1:-1] for ch in TUPLE_CHARS])): + elems = [elem.strip() for elem in expr[1:-1].split(",")] + else: + elems = [expr] + return elems + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: # noqa: E722 + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: # noqa: E722 + return string + + +def _remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def _fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def _strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse 
spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = _remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\\%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = _fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def normalize_answer(answer: Optional[str]) -> Optional[str]: + if answer is None: + return None + answer = answer.strip() + try: + # Remove enclosing `\text{}`. 
+ m = re.search("^\\\\text\\{(?P.+?)\\}$", answer) + if m is not None: + answer = m.group("text").strip() + return _strip_string(answer) + except: # noqa: E722 + return answer + +def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: + """ + The answer will be considered correct if: + (a) it normalizes to the same string as the ground truth answer + OR + (b) sympy can simplify the difference between the expressions to 0 + """ + if given_answer is None: + return False + + ground_truth_normalized_mathd = normalize_answer(ground_truth) + given_answer_normalized_mathd = normalize_answer(given_answer) + + # be at least as lenient as mathd + if ground_truth_normalized_mathd == given_answer_normalized_mathd: + return True, given_answer_normalized_mathd + + ground_truth_normalized = _normalize(ground_truth) + given_normalized = _normalize(given_answer) + + if ground_truth_normalized is None: + return False, given_normalized + + if ground_truth_normalized == given_normalized: + return True, given_normalized + + if len(given_normalized) == 0: + return False, given_normalized + + ground_truth_elems = split_tuple(ground_truth_normalized) + given_elems = split_tuple(given_normalized) + + if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or + ground_truth_normalized[-1] != given_normalized[-1]): + is_correct = False + elif len(ground_truth_elems) != len(given_elems): + is_correct = False + else: + for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): + if _is_frac(ground_truth_elem) and _is_frac(given_elem): + # if fractions aren't reduced, then shouldn't be marked as correct + # so, we don't want to allow sympy.simplify in this case + is_correct = ground_truth_elem == given_elem + elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): + # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) + is_correct = False + else: + is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + if not is_correct: + break + + return is_correct, given_normalized + + +def _last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + left_brace_idx = None + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if left_brace_idx is None: + left_brace_idx = i + elif string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + + i += 1 + + if left_brace_idx is None or right_brace_idx is None: + return None + + return string[left_brace_idx + 1:right_brace_idx].strip() + + +def match_answer(response): + is_matched = False + response = response.split("")[-1] + + # Find boxed + ans_boxed = _last_boxed_only_string(response) + if ans_boxed: + is_matched = True + response = ans_boxed + + return is_matched, response + + +import math + + +def compute_score(solution_str: str, + ground_truth: str) -> float: + """Compute the reward score for a solution. 
This draws heavily from the LLM-as-judge and PRIME reward functions + + Args: + solution_str: The solution string + ground_truth: The ground truth answer + extra_info: dict with additional info for the score computation + + Returns: + Reward score (1.0 for correct, -1.0 for incorrect) + """ + # First assert intended generation and gt type + model_output = str(solution_str) + ground_truth = str(ground_truth) + + # Extract answer from generated output + is_matched, extracted_model_output = match_answer(model_output) + + # TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score + + # Verify the solution, first check simple comparisons. + correct, pred = grade_answer(extracted_model_output, ground_truth) + + if not correct: + try: + if "\\pi" in extracted_model_output or "\\pi" in ground_truth: + equivs = [] + for pi in [math.pi, 3.14]: + equivs.append(math_equal(extracted_model_output, ground_truth, tiemout=True, pi=pi)) + correct = any(equivs) + else: + correct = math_equal(extracted_model_output, ground_truth, timeout=True) + except: + correct = False + + # reward = 1.0 if correct else -1.0 + reward = 1.0 if correct else 0. + + return reward diff --git a/examples/bots/random.yaml b/examples/bots/random.yaml new file mode 100644 index 0000000000..e5ab442def --- /dev/null +++ b/examples/bots/random.yaml @@ -0,0 +1,67 @@ +project: "BOTS-Selector" +name: "qwen2.5-1.5B-instruct-random" +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 16 + optimizer: + lr: 1e-6 +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct} + max_prompt_tokens: 4096 + max_response_tokens: 8192 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 32 + explorer_input: + taskset: + name: math-train + storage_type: file + path: '/LLM360/guru-RL-92k/train/math__combined_54.4k.parquet' + split: 'train' + format: + prompt_key: 'prompt' + response_key: 'reward_model.ground_truth' + rollout_args: + temperature: 1.0 + task_selector: + selector_type: random + eval_tasksets: + - name: math-eval + storage_type: file + path: '/LLM360/guru-RL-92k/online_eval/math__math_500.parquet' + format: + prompt_key: 'prompt' + response_key: 'reward_model.ground_truth' + rollout_args: + temperature: 1.0 + default_workflow_type: 'bots_math_boxed_workflow' + trainer_input: + experience_buffer: + name: exp_buffer + storage_type: queue + path: 'sqlite:///random_trainer_buffer.db' +explorer: + eval_interval: 40 + runner_per_model: 8 + rollout_model: + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 8 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + save_interval: 800 + grad_clip: 1.0 + use_dynamic_bsz: true + max_token_len_per_gpu: 24576 + ulysses_sequence_parallel_size: 1 \ No newline at end of file diff --git a/trinity/buffer/operators/mappers/pass_rate_calculator.py b/trinity/buffer/operators/mappers/pass_rate_calculator.py index 38ff5627c5..a743c9c122 100644 --- a/trinity/buffer/operators/mappers/pass_rate_calculator.py +++ b/trinity/buffer/operators/mappers/pass_rate_calculator.py @@ -24,6 +24,7 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: assert "index" in task_index raw_metric[task_index["taskset_id"]][task_index["index"]].append(exp.reward) metric = {} + ref_pass_rates = [] for taskset_id, taskset_metric in raw_metric.items(): indices 
= [] reward_means = [] @@ -34,4 +35,30 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]: "indices": indices, "values": reward_means, } - return exps, {SELECTOR_METRIC: metric} + ref_pass_rates.extend(reward_means) + ret_metric = {SELECTOR_METRIC: metric} + + valid_ratio = np.mean([1 if 0 < pr < 1 else 0 for pr in ref_pass_rates]) + strict_valid_ratio = np.mean( + [1 if 1 / 16 + 1e-3 < pr < 15 / 16 - 1e-3 else 0 for pr in ref_pass_rates] + ) + less_than_one_ratio = np.mean([1 if pr < 1 else 0 for pr in ref_pass_rates]) + larger_than_zero_ratio = np.mean([1 if pr > 0 else 0 for pr in ref_pass_rates]) + less_than_15_over_16_ratio = np.mean( + [1 if pr < 15 / 16 - 1e-3 else 0 for pr in ref_pass_rates] + ) + larger_than_1_over_16_ratio = np.mean( + [1 if pr > 1 / 16 + 1e-3 else 0 for pr in ref_pass_rates] + ) + ret_metric.update( + { + "selection/valid_ratio": valid_ratio, + "selection/strict_valid_ratio": strict_valid_ratio, + "selection/<1_ratio": less_than_one_ratio, + "selection/>0_ratio": larger_than_zero_ratio, + "selection/<15_16_ratio": less_than_15_over_16_ratio, + "selection/>1_16_ratio": larger_than_1_over_16_ratio, + } + ) + + return exps, ret_metric diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 35a4eff2ce..9101f2cba9 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -190,7 +190,7 @@ def update(self, pipeline_metrics: Dict) -> None: """ if SELECTOR_METRIC not in pipeline_metrics: return - selector_metric = pipeline_metrics[SELECTOR_METRIC] + selector_metric = pipeline_metrics.pop(SELECTOR_METRIC, {}) for taskset_id, taskset_kwargs in selector_metric.items(): selector = self.selectors[taskset_id] selector.update(**taskset_kwargs) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 6847d8e655..42af635873 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -238,7 +238,7 @@ def deserialize(cls, data: bytes) -> Experience: def to_dict(self) -> dict: """Convert the experience to a dictionary.""" res = { - "eid": self.eid, + "eid": self.eid.to_dict(), "type": self.experience_type, "prompt_length": self.prompt_length, "response_length": len(self.tokens) - self.prompt_length, # type: ignore [arg-type] diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 8a493e161f..90d19a8784 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -19,6 +19,23 @@ WORKFLOWS = Registry("workflows") +def nested_query(query_key: str, query_obj: Union[dict, None]): + # support nested query for a dict given query_keys split by '.' + if query_obj is None: + return None + if "." 
in query_key: + query_keys = query_key.split(".") + else: + query_keys = [query_key] + ret = query_obj + for key in query_keys: + if isinstance(ret, dict) and key in ret: + ret = ret[key] + else: + return None + return ret + + @dataclass class Task(dict): """A Task class that defines a task and its associated reward function / workflow.""" @@ -64,13 +81,13 @@ def to_workflow( @property def task_desc(self) -> Union[str, None]: prompt_key = self.format_args.prompt_key - return self.raw_task[prompt_key] if prompt_key in self.raw_task else None # type: ignore + return nested_query(prompt_key, self.raw_task) # type: ignore # Deprecated property, will be removed in the future @property def truth(self) -> Union[str, None]: response_key = self.format_args.response_key - return self.raw_task[response_key] if response_key in self.raw_task else None # type: ignore + return nested_query(response_key, self.raw_task) def to_dict(self) -> dict: return self.raw_task # type: ignore From ac1735943bce34cad83e54b56e6c84db8c158021 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=83=E7=AF=B1?= Date: Wed, 5 Nov 2025 16:45:58 +0800 Subject: [PATCH 04/12] Update BOTS in News --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 7e79f1d956..302099e4df 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni ## 🚀 News +* [2025-11] Introducing [BOTS](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots): dynamic RL task selection for efficient LLM fine-tuning ([paper](https://arxiv.org/pdf/2510.26374)). * [2025-10] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.2)] Trinity-RFT v0.3.2 released: bug fixes and advanced task selection & scheduling. * [2025-10] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.1)] Trinity-RFT v0.3.1 released: multi-stage training support, improved agentic RL examples, LoRA support, debug mode and new RL algorithms. * [2025-09] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.0)] Trinity-RFT v0.3.0 released: enhanced Buffer, FSDP2 & Megatron support, multi-modal models, and new RL algorithms/examples. From 275072232752490807a6f659315974fd33d23c54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=83=E7=AF=B1?= Date: Wed, 5 Nov 2025 16:46:19 +0800 Subject: [PATCH 05/12] Update README --- examples/bots/README.md | 50 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/examples/bots/README.md b/examples/bots/README.md index 00dce0768c..7c4e5afd28 100644 --- a/examples/bots/README.md +++ b/examples/bots/README.md @@ -6,16 +6,58 @@

-### Repository Status +### Overview -This repository hosts the upcoming Trinity version of our code, which is still under development and not yet released. +Agentic workflows -For complete reproduction of the results in our paper, please use the verl version available [here](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip). +BOTS operates in a continuous loop of task selection, model training, and posterior updating. +(1) **Selection**: Thompson sampling from the posterior beliefs selects a batch of tasks whose estimated success probabilities are near a target difficulty (e.g., $p^*=0.5$). +(2) **Training \& Evidence Collection**: The LLM is finetuned, yielding direct success/failure counts (_explicit evidence_) for the selected batch. +For unselected tasks, predicted counts (_implicit evidence_) are produced by a plug-in; We introduce an ultra-lightweight interpolation-based variant with negligible overhead. +(3) **Posterior Updating**: Explicit and implicit evidence are fused using our generalized Bayesian update rule. + +### Usage + +##### Step 1: Environment Preparation + +Ensure Trinity-RFT is well installed ([Installation Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)). No extra dependence is required. + +##### Step 2: Prepare Model & Dataset + +Download the model your want to train (e.g. [Qwen2.5-1.5B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct)). + +Download the [GURU](https://huggingface.co/datasets/LLM360/guru-RL-92k) dataset. +Also refer to the [Data Preparation Guide](https://github.com/LLM360/Reasoning360?tab=readme-ov-file#data-preparation) and the [Tech Report](https://www.arxiv.org/pdf/2506.14965) provided by the LLM360 team. + +Remember to modify the model/data path in `bots.yaml` and `random.yaml` accordingly. + +##### Step 3: Training +Launch training by executing: +```bash +trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/plugins +``` +The improvement over random selection baseline can be stably obtained 🤖🤖🤖. + +Agentic workflows + +### Complete Reproduction + +For complete reproduction of the results in our paper, please use the verl version implementation available [here](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip). 
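Independent of which implementation you use, the boxed-math reward that drives the workflow can be sanity-checked in isolation. The snippet below is a minimal sketch and not part of the shipped example: the import assumes you run it from the plugins directory so the local module resolves, and the reward module needs `verl` installed because it borrows a timeout helper from there.

```python
# Hypothetical smoke test for the BOTS boxed-math reward (illustrative only).
# Run from the plugins directory (examples/bots/plugins) so the import resolves;
# requires verl, sympy and pylatexenc, which bots_reward.py imports.
from bots_reward import compute_score

# A response whose final \boxed{} answer is equivalent to the ground truth -> 1.0
print(compute_score(r"Some reasoning ... so the answer is \boxed{\frac{1}{2}}.", "0.5"))

# A response without a \boxed{} answer (or with a wrong one) -> 0.0
print(compute_score("I am not sure about this one.", "42"))
```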
### Citation If you find the repo helpful, please cite: ``` -@misc{shen2025botsunifiedframeworkbayesian, +@misc{TrinityRFT, + title={Trinity-RFT: A General-Purpose and Unified Framework for Reinforcement Fine-Tuning of Large Language Models}, + author={Xuchen Pan and Yanxi Chen and Yushuo Chen and Yuchang Sun and Daoyuan Chen and Wenhao Zhang and Yuexiang Xie and Yilun Huang and Yilei Zhang and Dawei Gao and Weijie Shi and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2505.17826}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2505.17826}, +} + +@misc{BOTS, title={BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning}, author={Qianli Shen and Daoyuan Chen and Yilun Huang and Zhenqing Ling and Yaliang Li and Bolin Ding and Jingren Zhou}, year={2025}, From 0a96529a90d1fcac63068ed6d433010e1ad2d088 Mon Sep 17 00:00:00 2001 From: ShenQianli Date: Wed, 5 Nov 2025 18:22:18 +0800 Subject: [PATCH 06/12] examples/bots --- examples/bots/README.md | 14 +- examples/bots/bots.yaml | 2 +- .../bots/plugins/bots_math_boxed_reward.py | 3 +- .../bots/plugins/bots_math_boxed_workflow.py | 1 + examples/bots/plugins/bots_reward.py | 209 ++++++++++-------- examples/bots/random.yaml | 2 +- 6 files changed, 125 insertions(+), 106 deletions(-) diff --git a/examples/bots/README.md b/examples/bots/README.md index 7c4e5afd28..728570f4a9 100644 --- a/examples/bots/README.md +++ b/examples/bots/README.md @@ -10,10 +10,10 @@ Agentic workflows -BOTS operates in a continuous loop of task selection, model training, and posterior updating. -(1) **Selection**: Thompson sampling from the posterior beliefs selects a batch of tasks whose estimated success probabilities are near a target difficulty (e.g., $p^*=0.5$). -(2) **Training \& Evidence Collection**: The LLM is finetuned, yielding direct success/failure counts (_explicit evidence_) for the selected batch. -For unselected tasks, predicted counts (_implicit evidence_) are produced by a plug-in; We introduce an ultra-lightweight interpolation-based variant with negligible overhead. +BOTS operates in a continuous loop of task selection, model training, and posterior updating. +(1) **Selection**: Thompson sampling from the posterior beliefs selects a batch of tasks whose estimated success probabilities are near a target difficulty (e.g., $p^*=0.5$). +(2) **Training \& Evidence Collection**: The LLM is finetuned, yielding direct success/failure counts (_explicit evidence_) for the selected batch. +For unselected tasks, predicted counts (_implicit evidence_) are produced by a plug-in; We introduce an ultra-lightweight interpolation-based variant with negligible overhead. (3) **Posterior Updating**: Explicit and implicit evidence are fused using our generalized Bayesian update rule. 
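To make the loop concrete, below is a deliberately simplified NumPy sketch of one BOTS iteration. Treat it as an illustration under stated assumptions rather than the released selector: the per-task Beta posteriors, the distance-to-target Thompson selection, and the fixed-weight implicit update are stand-ins for the paper's generalized Bayesian update rule and interpolation plug-in (whose actual knobs appear in `bots.yaml` as `m`, `lamb`, `rho`, and `target_reward`), and `pass_rate_fn` is a hypothetical callback for the rollout-and-scoring step.

```python
import numpy as np


def bots_step(alpha, beta, pass_rate_fn, batch_size=32, m=16, target=0.5, rho=0.1, rng=None):
    """One simplified BOTS iteration over per-task Beta(alpha, beta) posteriors."""
    if rng is None:
        rng = np.random.default_rng()

    # (1) Selection: Thompson-sample a success probability per task and keep
    #     the tasks whose samples fall closest to the target difficulty.
    draws = rng.beta(alpha, beta)
    selected = np.argsort(np.abs(draws - target))[:batch_size]

    # (2) Evidence collection: m rollouts per selected task, summarized here
    #     as a pass rate in [0, 1] (explicit evidence).
    pass_rates = np.array([pass_rate_fn(int(i)) for i in selected])

    # (3) Posterior update: a small interpolated pseudo-count (weight rho) for
    #     every task stands in for the implicit-evidence plug-in, added on top
    #     of the explicit counts for the selected batch.
    mean_rate = float(pass_rates.mean()) if len(pass_rates) else target
    alpha += rho * mean_rate
    beta += rho * (1.0 - mean_rate)
    alpha[selected] += m * pass_rates
    beta[selected] += m * (1.0 - pass_rates)
    return selected, alpha, beta


# Toy usage: uniform Beta(1, 1) priors over 1000 tasks and a dummy rollout step.
alpha = np.ones(1000)
beta = np.ones(1000)
selected, alpha, beta = bots_step(alpha, beta, pass_rate_fn=lambda task_id: 0.4)
```

In the actual example, these statistics are produced by the `pass_rate_calculator` experience-pipeline operator and consumed by the `difficulty_based` task selector configured in `bots.yaml`.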
### Usage @@ -58,12 +58,12 @@ If you find the repo helpful, please cite: } @misc{BOTS, - title={BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning}, + title={BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning}, author={Qianli Shen and Daoyuan Chen and Yilun Huang and Zhenqing Ling and Yaliang Li and Bolin Ding and Jingren Zhou}, year={2025}, eprint={2510.26374}, archivePrefix={arXiv}, primaryClass={cs.AI}, - url={https://arxiv.org/abs/2510.26374}, + url={https://arxiv.org/abs/2510.26374}, } -``` \ No newline at end of file +``` diff --git a/examples/bots/bots.yaml b/examples/bots/bots.yaml index a5bd722de9..e3a948fee3 100644 --- a/examples/bots/bots.yaml +++ b/examples/bots/bots.yaml @@ -76,4 +76,4 @@ trainer: grad_clip: 1.0 use_dynamic_bsz: true max_token_len_per_gpu: 24576 - ulysses_sequence_parallel_size: 1 \ No newline at end of file + ulysses_sequence_parallel_size: 1 diff --git a/examples/bots/plugins/bots_math_boxed_reward.py b/examples/bots/plugins/bots_math_boxed_reward.py index e5b70ca637..335f72378d 100644 --- a/examples/bots/plugins/bots_math_boxed_reward.py +++ b/examples/bots/plugins/bots_math_boxed_reward.py @@ -5,6 +5,7 @@ from .bots_reward import compute_score + @REWARD_FUNCTIONS.register_module("bots_math_boxed_reward") class BOTSMathBoxedRewardFn(RewardFn): """A reward function that rewards for math task for BOTS.""" @@ -29,4 +30,4 @@ def __call__( # type: ignore if with_think and not validate_think_pattern(response): format_score = (format_score_coef or 0.1) * -1.0 - return {"accuracy": accuracy_score, "format_score": format_score} \ No newline at end of file + return {"accuracy": accuracy_score, "format_score": format_score} diff --git a/examples/bots/plugins/bots_math_boxed_workflow.py b/examples/bots/plugins/bots_math_boxed_workflow.py index 8c08d01874..537b77ae15 100644 --- a/examples/bots/plugins/bots_math_boxed_workflow.py +++ b/examples/bots/plugins/bots_math_boxed_workflow.py @@ -3,6 +3,7 @@ from .bots_math_boxed_reward import BOTSMathBoxedRewardFn + @WORKFLOWS.register_module("bots_math_boxed_workflow") class BOTSMathBoxedWorkflow(MathBoxedWorkflow): """A workflow for math tasks that give answers in boxed format for BOTS.""" diff --git a/examples/bots/plugins/bots_reward.py b/examples/bots/plugins/bots_reward.py index 6a7be7e692..61ea7789ed 100644 --- a/examples/bots/plugins/bots_reward.py +++ b/examples/bots/plugins/bots_reward.py @@ -1,23 +1,21 @@ # Adapted from Reasoning360: https://github.com/LLM360/Reasoning360/blob/main/verl/utils/reward_score/naive_dapo.py -import re -import signal -from typing import Optional, Union +import contextlib import math +import re from math import isclose -import contextlib +from typing import Optional, Union import sympy from pylatexenc import latex2text +from sympy import N, simplify from sympy.parsing import sympy_parser from sympy.parsing.latex import parse_latex from sympy.parsing.sympy_parser import parse_expr -from sympy import N, simplify -import os - from verl.utils.py_functional import timeout_limit -def handle_base(x) -> str: + +def handle_base(x): if isinstance(x, str) and "_" in x: # Due to base x = x.split("_")[0] @@ -49,13 +47,16 @@ def handle_pi(string, pi): return string + def normalize(answer, pi) -> str: # checking if answer is $ and removing $ in that case to compare if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)): return answer[1:] # checking if answer is % or \\% and removing % - if 
isinstance(answer, str) and (bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer))): + if isinstance(answer, str) and ( + bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) + ): return answer.replace("\\%", "").replace("%", "") # handle base @@ -66,6 +67,7 @@ def normalize(answer, pi) -> str: return answer + def is_digit(s): try: if "{,}" in str(s): @@ -75,9 +77,10 @@ def is_digit(s): num = float(str(s).replace(",", "")) return True, num except ValueError: - return False, None + return False, 0.0 + -def format_intervals(prediction): +def format_intervals(prediction) -> str: patterns = { "Interval(": r"^Interval\((.*)\)$", "Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$", @@ -99,7 +102,7 @@ def format_intervals(prediction): elif key == "Interval.open(": # Intarval.open(a, b) == (a, b) return f"({inner_content})" - return prediction + return str(prediction) def symbolic_equal(a, b, tolerance, timeout=10.0): @@ -140,7 +143,7 @@ def _parse(s): return False -def math_equal( +def math_equal( # noqa prediction: Union[bool, float, str], reference: Union[float, str], include_percentage: bool = True, @@ -172,10 +175,14 @@ def math_equal( prediction = is_digit(prediction)[1] reference = is_digit(reference)[1] # number questions - gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] + gt_result = ( + [float(reference) / 100.0, float(reference), float(reference) * 100.0] + if include_percentage + else [float(reference)] + ) for item in gt_result: try: - if isclose(item, prediction, rel_tol=tolerance): + if isclose(float(item), float(prediction), rel_tol=tolerance): return True except Exception: continue @@ -194,7 +201,11 @@ def math_equal( prediction = format_intervals(prediction) pred_str, ref_str = prediction, reference - if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")): + if ( + prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(") + ) or ( + prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") + ): pred_str = pred_str.strip("[]()") ref_str = ref_str.strip("[]()") for s in ["{", "}", "(", ")"]: @@ -204,10 +215,22 @@ def math_equal( return True ## [a, b] vs. 
[c, d], return a==c and b==d - if prediction and reference and prediction[0] in "([" and prediction[-1] in ")]" and prediction[0] == reference[0] and prediction[-1] == reference[-1]: + if ( + prediction + and reference + and prediction[0] in "([" + and prediction[-1] in ")]" + and prediction[0] == reference[0] + and prediction[-1] == reference[-1] + ): pred_parts = prediction[1:-1].split(",") ref_parts = reference[1:-1].split(",") - if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]): + if len(pred_parts) == len(ref_parts) and all( + [ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts) + ] + ): return True if "," in prediction and "," in reference: @@ -215,13 +238,25 @@ def math_equal( ref_parts = [item.strip() for item in reference.split(",")] if len(pred_parts) == len(ref_parts): - return bool(all([math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) for i in range(len(pred_parts))])) + return bool( + all( + [ + math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) + for i in range(len(pred_parts)) + ] + ) + ) # if we have point == tuple of values if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": pred_parts = prediction[prediction.find("(") + 1 : -1].split(",") ref_parts = reference[1:-1].split(",") - if len(pred_parts) == len(ref_parts) and all([math_equal(pred_pt, ref_pt, include_percentage, tolerance) for pred_pt, ref_pt in zip(pred_parts, ref_parts)]): + if len(pred_parts) == len(ref_parts) and all( + [ + math_equal(pred_pt, ref_pt, include_percentage, tolerance) + for pred_pt, ref_pt in zip(pred_parts, ref_parts) + ] + ): return True # if reference is a matrix @@ -229,7 +264,12 @@ def math_equal( try: pred_matrix = parse_expr(prediction) ref_matrix_items = reference.split()[1:-1:2] - if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]): + if len(pred_matrix) == len(ref_matrix_items) and all( + [ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix) + ] + ): return True except Exception: pass @@ -238,32 +278,28 @@ def math_equal( try: pred_matrix = eval(prediction) # ref_matrix_items = reference.split()[1:-1:2] - ref_matrix_items = reference.lstrip("\\begin{pmatrix}").lstrip("\begin{pmatrix}").rstrip("\\end{pmatrix}").rstrip("\\end{pmatrix}") # noqa: B005 + ref_matrix_items = ( + reference.lstrip("\\begin{pmatrix}") + .lstrip("\begin{pmatrix}") + .rstrip("\\end{pmatrix}") + .rstrip("\\end{pmatrix}") + ) # noqa: B005 ref_matrix_items = ref_matrix_items.split("\\") - ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] - if len(pred_matrix) == len(ref_matrix_items) and all([math_equal(pred, ref, include_percentage, tolerance) for ref, pred in zip(ref_matrix_items, pred_matrix)]): + # ref_matrix_items = [ + # row.split("&") if "&" in row else row for row in ref_matrix_items + # ] + if len(pred_matrix) == len(ref_matrix_items) and all( + [ + math_equal(pred, ref, include_percentage, tolerance) + for ref, pred in zip(ref_matrix_items, pred_matrix) + ] + ): return True except Exception: pass return symbolic_equal(prediction, reference, tolerance, timeout) -class timeout: - - def __init__(self, seconds=1, error_message="Timeout"): - self.seconds = seconds - 
self.error_message = error_message - - def handle_timeout(self, signum, frame): - raise TimeoutError(self.error_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__(self, type, value, traceback): - signal.alarm(0) - # Constants for normalization SUBSTITUTIONS = [ @@ -367,43 +403,19 @@ def normalize_final_answer(final_answer: str) -> str: # sympy might hang -- we don't care about trying to be lenient in these cases BAD_SUBSTRINGS = ["^{", "^("] -BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] +BAD_REGEXES = [r"\^[0-9]+\^", r"\^[0-9][0-9]+"] TUPLE_CHARS = "()[]" -def timeout(timeout_seconds: int = 8): - if os.name == "posix": - import signal - - def decorator(func): - - def handler(signum, frame): - raise TimeoutError("Operation timed out!") - - def wrapper(*args, **kwargs): - old_handler = signal.getsignal(signal.SIGALRM) - signal.signal(signal.SIGALRM, handler) - signal.alarm(timeout_seconds) - - try: - return func(*args, **kwargs) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) - - return wrapper - - return decorator - else: - raise NotImplementedError(f"Unsupported OS: {os.name}") - - def _sympy_parse(expr: str): """Parses an expression with sympy.""" py_expr = expr.replace("^", "**") return sympy_parser.parse_expr( py_expr, - transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), + transformations=( + sympy_parser.standard_transformations + + (sympy_parser.implicit_multiplication_application,) + ), ) @@ -436,7 +448,7 @@ def _is_float(num: str) -> bool: def _is_int(x: float) -> bool: try: return abs(x - int(round(x))) <= 1e-7 - except: + except Exception: return False @@ -447,13 +459,12 @@ def _is_frac(expr: str) -> bool: def _str_is_int(x: str) -> bool: try: x = _strip_properly_formatted_commas(x) - x = float(x) - return abs(x - int(round(x))) <= 1e-7 - except: + return abs(float(x) - int(round(float(x)))) <= 1e-7 + except Exception: return False -def _str_to_int(x: str) -> bool: +def _str_to_int(x: str) -> int: x = x.replace(",", "") x = float(x) return int(x) @@ -465,13 +476,13 @@ def _inject_implicit_mixed_number(step: str): e.g. 7 3/4 => 7+3/4 """ p1 = re.compile("([0-9]) +([0-9])") - step = p1.sub("\\1+\\2", step) ## implicit mults + step = p1.sub("\\1+\\2", step) # implicit mults return step def _strip_properly_formatted_commas(expr: str): # We want to be careful because we don't want to strip tuple commas - p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") + p1 = re.compile(r"(\d)(,)(\d\d\d)($|\D)") while True: next_expr = p1.sub("\\1\\3\\4", expr) if next_expr == expr: @@ -486,7 +497,7 @@ def _normalize(expr: str) -> str: return None # Remove enclosing `\text{}`. - m = re.search("^\\\\text\{(?P.+?)\}$", expr) + m = re.search(r"^\\\\text\{(?P.+?)\}$", expr) if m is not None: expr = m.group("text") @@ -520,8 +531,8 @@ def _normalize(expr: str) -> str: "yard", "liter", ]: - expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) - expr = re.sub(f"\^ *\\\\circ", "", expr) + expr = re.sub(r"{}(es)?(s)? 
*(\^[0-9]+)?".format(unit), "", expr) + expr = re.sub(r"\^ *\\\\circ", "", expr) if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": expr = expr[1:-1] @@ -532,7 +543,7 @@ def _normalize(expr: str) -> str: if "\\" in expr: try: expr = _parse_latex(expr) - except: + except Exception: pass # edge case with mixed numbers and negative signs @@ -582,7 +593,7 @@ def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): simplified = sympy.simplify(sympy_diff) if simplified == 0: are_equal = True - except: + except Exception: pass return are_equal @@ -594,13 +605,18 @@ def split_tuple(expr: str): expr = _strip_properly_formatted_commas(expr) if len(expr) == 0: return [] - if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and - all([ch not in expr[1:-1] for ch in TUPLE_CHARS])): + if ( + len(expr) > 2 + and expr[0] in TUPLE_CHARS + and expr[-1] in TUPLE_CHARS + and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) + ): elems = [elem.strip() for elem in expr[1:-1].split(",")] else: elems = [expr] return elems + def _fix_fracs(string): substrs = string.split("\\frac") new_str = substrs[0] @@ -737,9 +753,9 @@ def _strip_string(string): return string -def normalize_answer(answer: Optional[str]) -> Optional[str]: +def normalize_answer(answer: Optional[str]) -> str: if answer is None: - return None + return "" answer = answer.strip() try: # Remove enclosing `\text{}`. @@ -750,6 +766,7 @@ def normalize_answer(answer: Optional[str]) -> Optional[str]: except: # noqa: E722 return answer + def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: """ The answer will be considered correct if: @@ -782,8 +799,10 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: ground_truth_elems = split_tuple(ground_truth_normalized) given_elems = split_tuple(given_normalized) - if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or - ground_truth_normalized[-1] != given_normalized[-1]): + if len(ground_truth_elems) > 1 and ( + ground_truth_normalized[0] != given_normalized[0] + or ground_truth_normalized[-1] != given_normalized[-1] + ): is_correct = False elif len(ground_truth_elems) != len(given_elems): is_correct = False @@ -831,7 +850,7 @@ def _last_boxed_only_string(string): if left_brace_idx is None or right_brace_idx is None: return None - return string[left_brace_idx + 1:right_brace_idx].strip() + return string[left_brace_idx + 1 : right_brace_idx].strip() def match_answer(response): @@ -847,11 +866,7 @@ def match_answer(response): return is_matched, response -import math - - -def compute_score(solution_str: str, - ground_truth: str) -> float: +def compute_score(solution_str: str, ground_truth: Optional[str]) -> float: """Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions Args: @@ -879,14 +894,16 @@ def compute_score(solution_str: str, if "\\pi" in extracted_model_output or "\\pi" in ground_truth: equivs = [] for pi in [math.pi, 3.14]: - equivs.append(math_equal(extracted_model_output, ground_truth, tiemout=True, pi=pi)) + equivs.append( + math_equal(extracted_model_output, ground_truth, timeout=True, pi=pi) + ) correct = any(equivs) else: correct = math_equal(extracted_model_output, ground_truth, timeout=True) - except: + except Exception: correct = False # reward = 1.0 if correct else -1.0 - reward = 1.0 if correct else 0. 
+ reward = 1.0 if correct else 0.0 return reward diff --git a/examples/bots/random.yaml b/examples/bots/random.yaml index e5ab442def..4fa2e3978a 100644 --- a/examples/bots/random.yaml +++ b/examples/bots/random.yaml @@ -64,4 +64,4 @@ trainer: grad_clip: 1.0 use_dynamic_bsz: true max_token_len_per_gpu: 24576 - ulysses_sequence_parallel_size: 1 \ No newline at end of file + ulysses_sequence_parallel_size: 1 From 36ee2ac1a0690cb71d2aa1cf23c7260bfb3d7022 Mon Sep 17 00:00:00 2001 From: ShenQianli Date: Thu, 6 Nov 2025 15:25:22 +0800 Subject: [PATCH 07/12] examples/bots --- README.md | 16 ++--- README_zh.md | 13 ++-- docs/sphinx_doc/source/main.md | 2 +- docs/sphinx_doc/source_zh/main.md | 12 ++-- examples/bots/README.md | 10 +-- examples/bots/README_zh.md | 68 +++++++++++++++++++ .../bots/plugins/bots_math_boxed_workflow.py | 17 ----- .../bots_math_boxed_reward.py | 0 .../bots/workflow/bots_math_boxed_workflow.py | 46 +++++++++++++ .../bots/{plugins => workflow}/bots_reward.py | 0 trinity/common/workflows/workflow.py | 29 -------- 11 files changed, 142 insertions(+), 71 deletions(-) create mode 100644 examples/bots/README_zh.md delete mode 100644 examples/bots/plugins/bots_math_boxed_workflow.py rename examples/bots/{plugins => workflow}/bots_math_boxed_reward.py (100%) create mode 100644 examples/bots/workflow/bots_math_boxed_workflow.py rename examples/bots/{plugins => workflow}/bots_reward.py (100%) diff --git a/README.md b/README.md index 302099e4df..5d90d647af 100644 --- a/README.md +++ b/README.md @@ -67,13 +67,13 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni ## 🔨 Tutorials and Guidelines -| Category | Tutorial / Guideline | -| --- | --- | -| Run diverse RFT modes | + [Quick example: GRPO on GSM8k](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [Fully asynchronous RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [Offline learning by DPO or SFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | -| Multi-step agentic scenarios | + [Concatenated multi-turn workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [General multi-step workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [ReAct workflow with an agent framework](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | -| Advanced data pipelines | + [Rollout task mixing and selection](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) | -| Algorithm development / research | + [RL algorithm development with Trinity-RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_mix_algo.html) ([paper](https://arxiv.org/pdf/2508.11408))
+ Non-verifiable domains: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [trainable RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [Research project: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([paper](https://arxiv.org/abs/2509.24203))| -| Going deeper into Trinity-RFT | + [Full configurations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [Benchmark toolkit for quick verification and experimentation](./benchmark/README.md)
+ [Understand the coordination between explorer and trainer](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | +| Category | Tutorial / Guideline | +| --- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Run diverse RFT modes | + [Quick example: GRPO on GSM8k](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [Fully asynchronous RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [Offline learning by DPO or SFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | +| Multi-step agentic scenarios | + [Concatenated multi-turn workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [General multi-step workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [ReAct workflow with an agent framework](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | +| Advanced data pipelines | + [Rollout task mixing and selection](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([paper](https://arxiv.org/pdf/2510.26374))
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html)
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay) | +| Algorithm development / research | + [RL algorithm development with Trinity-RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_mix_algo.html) ([paper](https://arxiv.org/pdf/2508.11408))
+ Non-verifiable domains: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [trainable RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [Research project: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([paper](https://arxiv.org/abs/2509.24203)) | +| Going deeper into Trinity-RFT | + [Full configurations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [Benchmark toolkit for quick verification and experimentation](./benchmark/README.md)
+ [Understand the coordination between explorer and trainer](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | > [!NOTE] @@ -82,7 +82,7 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni ## 🚀 News -* [2025-11] Introducing [BOTS](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots): dynamic RL task selection for efficient LLM fine-tuning ([paper](https://arxiv.org/pdf/2510.26374)). +* [2025-11] Introducing [BOTS](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots): online RL task selection for efficient LLM fine-tuning ([paper](https://arxiv.org/pdf/2510.26374)). * [2025-10] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.2)] Trinity-RFT v0.3.2 released: bug fixes and advanced task selection & scheduling. * [2025-10] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.1)] Trinity-RFT v0.3.1 released: multi-stage training support, improved agentic RL examples, LoRA support, debug mode and new RL algorithms. * [2025-09] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.0)] Trinity-RFT v0.3.0 released: enhanced Buffer, FSDP2 & Megatron support, multi-modal models, and new RL algorithms/examples. diff --git a/README_zh.md b/README_zh.md index a54b2389b6..670b58c739 100644 --- a/README_zh.md +++ b/README_zh.md @@ -67,13 +67,13 @@ Trinity-RFT 是一个灵活、通用的大语言模型(LLM)强化微调(RF ## 🔨 教程与指南 -| Category | Tutorial / Guideline | -| --- | --- | -| 运行各种 RFT 模式 | + [快速开始:在 GSM8k 上运行 GRPO](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [全异步 RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [通过 DPO 或 SFT 进行离线学习](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | -| 多轮智能体场景 | + [拼接多轮任务](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [通用多轮任务](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [调用智能体框架中的 ReAct 工作流](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | -| 数据流水线进阶能力 | + [Rollout 任务混合与选取](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) | +| Category | Tutorial / Guideline | +| --- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| 运行各种 RFT 模式 | + [快速开始:在 GSM8k 上运行 GRPO](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [全异步 RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [通过 DPO 或 SFT 进行离线学习](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | +| 多轮智能体场景 | + [拼接多轮任务](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [通用多轮任务](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [调用智能体框架中的 ReAct 工作流](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | +| 数据流水线进阶能力 | + [Rollout 任务混合与选取](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [在线任务课程](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([论文](https://arxiv.org/pdf/2510.26374))
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) | | RL 算法开发/研究 | + [使用 Trinity-RFT 进行 RL 算法开发](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_mix_algo.html) ([论文](https://arxiv.org/pdf/2508.11408))
+ 不可验证的领域:[RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [可训练 RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [研究项目: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([论文](https://arxiv.org/abs/2509.24203)) | -| 深入认识 Trinity-RFT | + [完整配置指南](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [用于快速验证和实验的 Benchmark 工具](./benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | +| 深入认识 Trinity-RFT | + [完整配置指南](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [用于快速验证和实验的 Benchmark 工具](./benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | > [!NOTE] @@ -83,6 +83,7 @@ Trinity-RFT 是一个灵活、通用的大语言模型(LLM)强化微调(RF ## 🚀 新闻 +* [2025-11] 推出 [BOTS](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots):在线RL任务选择,实现高效LLM微调([论文](https://arxiv.org/pdf/2510.26374))。 * [2025-10] [[发布说明](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.2)] Trinity-RFT v0.3.2 发布:修复若干 Bug 并支持进阶的任务选择和调度。 * [2025-10] [[发布说明](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.1)] Trinity-RFT v0.3.1 发布:多阶段训练支持、改进的智能体 RL 示例、LoRA 支持、调试模式和全新 RL 算法。 * [2025-09] [[发布说明](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.0)] Trinity-RFT v0.3.0 发布:增强的 Buffer、FSDP2 & Megatron 支持,多模态模型,以及全新 RL 算法/示例。 diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index 9c6857e237..ae2f91dc65 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -52,7 +52,7 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni | --- | --- | | Run diverse RFT modes | + [Quick example: GRPO on GSM8k](/tutorial/example_reasoning_basic.md)
+ [Off-policy RFT](/tutorial/example_reasoning_advanced.md)
+ [Fully asynchronous RFT](/tutorial/example_async_mode.md)
+ [Offline learning by DPO or SFT](/tutorial/example_dpo.md) | | Multi-step agentic scenarios | + [Concatenated multi-turn workflow](/tutorial/example_multi_turn.md)
+ [General multi-step workflow](/tutorial/example_step_wise.md)
+ [ReAct workflow with an agent framework](/tutorial/example_react.md) | -| Advanced data pipelines | + [Rollout task mixing and selection](/tutorial/develop_selector.md)
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](/tutorial/example_data_functionalities.md) | +| Advanced data pipelines | + [Rollout task mixing and selection](/tutorial/develop_selector.md)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([paper](https://arxiv.org/pdf/2510.26374))
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](/tutorial/example_data_functionalities.md) | | Algorithm development / research | + [RL algorithm development with Trinity-RFT](/tutorial/example_mix_algo.md) ([paper](https://arxiv.org/pdf/2508.11408))
+ Non-verifiable domains: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [trainable RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [Research project: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([paper](https://arxiv.org/abs/2509.24203))| | Going deeper into Trinity-RFT | + [Full configurations](/tutorial/trinity_configs.md)
+ [Benchmark toolkit for quick verification and experimentation](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [Understand the coordination between explorer and trainer](/tutorial/synchronizer.md) | diff --git a/docs/sphinx_doc/source_zh/main.md b/docs/sphinx_doc/source_zh/main.md index e982516020..44321cea9c 100644 --- a/docs/sphinx_doc/source_zh/main.md +++ b/docs/sphinx_doc/source_zh/main.md @@ -48,13 +48,13 @@ Trinity-RFT 是一个灵活、通用的大语言模型(LLM)强化微调(RF ## 🔨 教程与指南 -| Category | Tutorial / Guideline | -| --- | --- | -| 运行各种 RFT 模式 | + [快速开始:在 GSM8k 上运行 GRPO](/tutorial/example_reasoning_basic.md)
+ [Off-policy RFT](/tutorial/example_reasoning_advanced.md)
+ [全异步 RFT](/tutorial/example_async_mode.md)
+ [通过 DPO 或 SFT 进行离线学习](/tutorial/example_dpo.md) | -| 多轮智能体场景 | + [拼接多轮任务](/tutorial/example_multi_turn.md)
+ [通用多轮任务](/tutorial/example_step_wise.md)
+ [调用智能体框架中的 ReAct 工作流](/tutorial/example_react.md) | -| 数据流水线进阶能力 | + [Rollout 任务混合与选取](/tutorial/develop_selector.md)
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](/tutorial/example_data_functionalities.md) | +| Category | Tutorial / Guideline | +| --- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| 运行各种 RFT 模式 | + [快速开始:在 GSM8k 上运行 GRPO](/tutorial/example_reasoning_basic.md)
+ [Off-policy RFT](/tutorial/example_reasoning_advanced.md)
+ [全异步 RFT](/tutorial/example_async_mode.md)
+ [通过 DPO 或 SFT 进行离线学习](/tutorial/example_dpo.md) | +| 多轮智能体场景 | + [拼接多轮任务](/tutorial/example_multi_turn.md)
+ [通用多轮任务](/tutorial/example_step_wise.md)
+ [调用智能体框架中的 ReAct 工作流](/tutorial/example_react.md) | +| 数据流水线进阶能力 | + [Rollout 任务混合与选取](/tutorial/develop_selector.md)
+ [在线任务课程](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([论文](https://arxiv.org/pdf/2510.26374))
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](/tutorial/example_data_functionalities.md) | | RL 算法开发/研究 | + [使用 Trinity-RFT 进行 RL 算法开发](/tutorial/example_mix_algo.md) ([论文](https://arxiv.org/pdf/2508.11408))
+ 不可验证的领域:[RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [可训练 RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [研究项目: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([论文](https://arxiv.org/abs/2509.24203)) | -| 深入认识 Trinity-RFT | + [完整配置指南](/tutorial/trinity_configs.md)
+ [用于快速验证和实验的 Benchmark 工具](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](/tutorial/synchronizer.md) | +| 深入认识 Trinity-RFT | + [完整配置指南](/tutorial/trinity_configs.md)
+ [用于快速验证和实验的 Benchmark 工具](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](/tutorial/synchronizer.md) | diff --git a/examples/bots/README.md b/examples/bots/README.md index 728570f4a9..63fd2345ca 100644 --- a/examples/bots/README.md +++ b/examples/bots/README.md @@ -8,6 +8,8 @@ ### Overview +BOTS is a unified framework for **B**ayesian **O**nline **T**ask **S**election in LLM reinforcement finetuning. + Agentic workflows BOTS operates in a continuous loop of task selection, model training, and posterior updating. @@ -22,9 +24,9 @@ For unselected tasks, predicted counts (_implicit evidence_) are produced by a p Ensure Trinity-RFT is well installed ([Installation Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)). No extra dependence is required. -##### Step 2: Prepare Model & Dataset +##### Step 2: Model & Dataset Preparation -Download the model your want to train (e.g. [Qwen2.5-1.5B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct)). +Download the model your want to train (e.g., [Qwen2.5-1.5B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct)). Download the [GURU](https://huggingface.co/datasets/LLM360/guru-RL-92k) dataset. Also refer to the [Data Preparation Guide](https://github.com/LLM360/Reasoning360?tab=readme-ov-file#data-preparation) and the [Tech Report](https://www.arxiv.org/pdf/2506.14965) provided by the LLM360 team. @@ -34,11 +36,11 @@ Remember to modify the model/data path in `bots.yaml` and `random.yaml` accordin ##### Step 3: Training Launch training by executing: ```bash -trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/plugins +trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/workflow ``` The improvement over random selection baseline can be stably obtained 🤖🤖🤖. -Agentic workflows +Agentic workflows ### Complete Reproduction diff --git a/examples/bots/README_zh.md b/examples/bots/README_zh.md new file mode 100644 index 0000000000..5b5a0c56ec --- /dev/null +++ b/examples/bots/README_zh.md @@ -0,0 +1,68 @@ +# 🤖🤖🤖 BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning + +

+ + Paper + +

+ +### 概览 + +BOTS是一个统一的LLM强化微调的**贝叶斯在线任务选择**框架。 + +Agentic workflows + +BOTS 以任务选择、模型训练和后验概率更新的连续循环运行。 +(1) **任务选择**:从后验概率信念中采用汤普森采样选择一批估计成功概率接近目标难度(例如,$p^*=0.5$)的任务。 +(2) **模型训练和证据收集**:对 LLM 模型进行微调,从而获得所选任务批次的直接成功/失败计数(显式证据)。 +对于未选择的任务,预测计数(隐式证据)由插件生成;我们引入了一种基于插值的超轻量级变体,其开销可忽略不计。 +(3) **后验概率更新**:使用我们提出的广义贝叶斯更新规则融合显式和隐式证据。 +### 使用 + +##### 第一步:环境准备 + +确保Trinity-RFT安装好了([安装指南](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html))。不需要额外的依赖。 + +##### 第二步:模型和数据准备 + +下载你想要训练的模型(例如:[Qwen2.5-1.5B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct))。 +下载[GURU](https://huggingface.co/datasets/LLM360/guru-RL-92k)数据集, +请参考LLM360提供的[数据准备指南](https://github.com/LLM360/Reasoning360?tab=readme-ov-file#data-preparation)和[技术报告](https://www.arxiv.org/pdf/2506.14965)。 +请修改`bots.yaml`和`random.yaml`中相应的模型/数据路径。 + +##### 第三步:训练 +执行以下命令启动训练: +```bash +trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/workflow +``` +相比随机选择基线的提升可以被稳定地观察到🤖🤖🤖. + +Agentic workflows + +### 完整复现 + +想要完整复线我们论文中的结果,请从[这里](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip)下载verl版本的框架。 + +### 引用 +如果你觉得这个代码仓库有帮助,请引用: +``` +@misc{TrinityRFT, + title={Trinity-RFT: A General-Purpose and Unified Framework for Reinforcement Fine-Tuning of Large Language Models}, + author={Xuchen Pan and Yanxi Chen and Yushuo Chen and Yuchang Sun and Daoyuan Chen and Wenhao Zhang and Yuexiang Xie and Yilun Huang and Yilei Zhang and Dawei Gao and Weijie Shi and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2505.17826}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2505.17826}, +} + +@misc{BOTS, + title={BOTS: A Unified Framework for Bayesian Online Task Selection in LLM Reinforcement Finetuning}, + author={Qianli Shen and Daoyuan Chen and Yilun Huang and Zhenqing Ling and Yaliang Li and Bolin Ding and Jingren Zhou}, + year={2025}, + eprint={2510.26374}, + archivePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2510.26374}, +} +``` diff --git a/examples/bots/plugins/bots_math_boxed_workflow.py b/examples/bots/plugins/bots_math_boxed_workflow.py deleted file mode 100644 index 537b77ae15..0000000000 --- a/examples/bots/plugins/bots_math_boxed_workflow.py +++ /dev/null @@ -1,17 +0,0 @@ -from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, Task -from trinity.common.workflows.workflow import WORKFLOWS - -from .bots_math_boxed_reward import BOTSMathBoxedRewardFn - - -@WORKFLOWS.register_module("bots_math_boxed_workflow") -class BOTSMathBoxedWorkflow(MathBoxedWorkflow): - """A workflow for math tasks that give answers in boxed format for BOTS.""" - - def reset(self, task: Task): - super().reset(task) - self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args) - - def format_messages(self): - # the prompts are already in message format - return self.task_desc diff --git a/examples/bots/plugins/bots_math_boxed_reward.py b/examples/bots/workflow/bots_math_boxed_reward.py similarity index 100% rename from examples/bots/plugins/bots_math_boxed_reward.py rename to examples/bots/workflow/bots_math_boxed_reward.py diff --git a/examples/bots/workflow/bots_math_boxed_workflow.py b/examples/bots/workflow/bots_math_boxed_workflow.py new file mode 100644 index 0000000000..89f2f65042 --- /dev/null +++ b/examples/bots/workflow/bots_math_boxed_workflow.py @@ -0,0 +1,46 @@ +from typing import Union + +from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, 
Task +from trinity.common.workflows.workflow import WORKFLOWS + +from .bots_math_boxed_reward import BOTSMathBoxedRewardFn + + +@WORKFLOWS.register_module("bots_math_boxed_workflow") +class BOTSMathBoxedWorkflow(MathBoxedWorkflow): + """A workflow for math tasks that give answers in boxed format for BOTS.""" + + def reset(self, task: Task): + super().reset(task) + self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args) + + def format_messages(self): + # the prompts are already in message format + return self.task_desc + + @property + def task_desc(self) -> Union[str, None]: + prompt_key = self.format_args.prompt_key + return nested_query(prompt_key, self.raw_task) # type: ignore + + @property + def truth(self) -> Union[str, None]: + response_key = self.format_args.response_key + return nested_query(response_key, self.raw_task) + + +def nested_query(query_key: str, query_obj: Union[dict, None]): + # support nested query for a dict given query_keys split by '.' + if query_obj is None: + return None + if "." in query_key: + query_keys = query_key.split(".") + else: + query_keys = [query_key] + ret = query_obj + for key in query_keys: + if isinstance(ret, dict) and key in ret: + ret = ret[key] + else: + return None + return ret \ No newline at end of file diff --git a/examples/bots/plugins/bots_reward.py b/examples/bots/workflow/bots_reward.py similarity index 100% rename from examples/bots/plugins/bots_reward.py rename to examples/bots/workflow/bots_reward.py diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 90d19a8784..38881f22a7 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -19,23 +19,6 @@ WORKFLOWS = Registry("workflows") -def nested_query(query_key: str, query_obj: Union[dict, None]): - # support nested query for a dict given query_keys split by '.' - if query_obj is None: - return None - if "." 
in query_key: - query_keys = query_key.split(".") - else: - query_keys = [query_key] - ret = query_obj - for key in query_keys: - if isinstance(ret, dict) and key in ret: - ret = ret[key] - else: - return None - return ret - - @dataclass class Task(dict): """A Task class that defines a task and its associated reward function / workflow.""" @@ -77,18 +60,6 @@ def to_workflow( auxiliary_models=auxiliary_models, ) - # Deprecated property, will be removed in the future - @property - def task_desc(self) -> Union[str, None]: - prompt_key = self.format_args.prompt_key - return nested_query(prompt_key, self.raw_task) # type: ignore - - # Deprecated property, will be removed in the future - @property - def truth(self) -> Union[str, None]: - response_key = self.format_args.response_key - return nested_query(response_key, self.raw_task) - def to_dict(self) -> dict: return self.raw_task # type: ignore From 8c83c677be0b6d7465d9c5082accb8d6794e66c2 Mon Sep 17 00:00:00 2001 From: ShenQianli Date: Thu, 6 Nov 2025 16:05:53 +0800 Subject: [PATCH 08/12] examples/bots --- examples/bots/README_zh.md | 10 +++++----- examples/bots/workflow/bots_math_boxed_workflow.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/bots/README_zh.md b/examples/bots/README_zh.md index 5b5a0c56ec..7649c66a50 100644 --- a/examples/bots/README_zh.md +++ b/examples/bots/README_zh.md @@ -12,10 +12,10 @@ BOTS是一个统一的LLM强化微调的**贝叶斯在线任务选择**框架。 Agentic workflows -BOTS 以任务选择、模型训练和后验概率更新的连续循环运行。 -(1) **任务选择**:从后验概率信念中采用汤普森采样选择一批估计成功概率接近目标难度(例如,$p^*=0.5$)的任务。 +BOTS 以任务选择、模型训练和后验概率更新的连续循环运行。 +(1) **任务选择**:从后验概率信念中采用汤普森采样选择一批估计成功概率接近目标难度(例如,$p^*=0.5$)的任务。 (2) **模型训练和证据收集**:对 LLM 模型进行微调,从而获得所选任务批次的直接成功/失败计数(显式证据)。 -对于未选择的任务,预测计数(隐式证据)由插件生成;我们引入了一种基于插值的超轻量级变体,其开销可忽略不计。 +对于未选择的任务,预测计数(隐式证据)由插件生成;我们引入了一种基于插值的超轻量级变体,其开销可忽略不计。 (3) **后验概率更新**:使用我们提出的广义贝叶斯更新规则融合显式和隐式证据。 ### 使用 @@ -25,9 +25,9 @@ BOTS 以任务选择、模型训练和后验概率更新的连续循环运行。 ##### 第二步:模型和数据准备 -下载你想要训练的模型(例如:[Qwen2.5-1.5B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct))。 +下载你想要训练的模型(例如:[Qwen2.5-1.5B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct))。 下载[GURU](https://huggingface.co/datasets/LLM360/guru-RL-92k)数据集, -请参考LLM360提供的[数据准备指南](https://github.com/LLM360/Reasoning360?tab=readme-ov-file#data-preparation)和[技术报告](https://www.arxiv.org/pdf/2506.14965)。 +请参考LLM360提供的[数据准备指南](https://github.com/LLM360/Reasoning360?tab=readme-ov-file#data-preparation)和[技术报告](https://www.arxiv.org/pdf/2506.14965)。 请修改`bots.yaml`和`random.yaml`中相应的模型/数据路径。 ##### 第三步:训练 diff --git a/examples/bots/workflow/bots_math_boxed_workflow.py b/examples/bots/workflow/bots_math_boxed_workflow.py index 89f2f65042..1596aa87f6 100644 --- a/examples/bots/workflow/bots_math_boxed_workflow.py +++ b/examples/bots/workflow/bots_math_boxed_workflow.py @@ -43,4 +43,4 @@ def nested_query(query_key: str, query_obj: Union[dict, None]): ret = ret[key] else: return None - return ret \ No newline at end of file + return ret From d46a62d77bd0443648de12e714b7749de16fa976 Mon Sep 17 00:00:00 2001 From: ShenQianli Date: Thu, 6 Nov 2025 16:19:13 +0800 Subject: [PATCH 09/12] update README and doc --- README_zh.md | 2 +- docs/sphinx_doc/source_zh/main.md | 2 +- examples/bots/README_zh.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README_zh.md b/README_zh.md index 670b58c739..3644b3330b 100644 --- a/README_zh.md +++ b/README_zh.md @@ -71,7 +71,7 @@ Trinity-RFT 是一个灵活、通用的大语言模型(LLM)强化微调(RF | --- 
|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 运行各种 RFT 模式 | + [快速开始:在 GSM8k 上运行 GRPO](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [全异步 RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [通过 DPO 或 SFT 进行离线学习](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | | 多轮智能体场景 | + [拼接多轮任务](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [通用多轮任务](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [调用智能体框架中的 ReAct 工作流](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | -| 数据流水线进阶能力 | + [Rollout 任务混合与选取](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [在线任务课程](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([论文](https://arxiv.org/pdf/2510.26374))
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) | +| 数据流水线进阶能力 | + [Rollout 任务混合与选取](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [在线任务选择](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([论文](https://arxiv.org/pdf/2510.26374))
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) | | RL 算法开发/研究 | + [使用 Trinity-RFT 进行 RL 算法开发](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_mix_algo.html) ([论文](https://arxiv.org/pdf/2508.11408))
+ 不可验证的领域:[RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [可训练 RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [研究项目: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([论文](https://arxiv.org/abs/2509.24203)) | | 深入认识 Trinity-RFT | + [完整配置指南](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [用于快速验证和实验的 Benchmark 工具](./benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | diff --git a/docs/sphinx_doc/source_zh/main.md b/docs/sphinx_doc/source_zh/main.md index 44321cea9c..fa4daa40c7 100644 --- a/docs/sphinx_doc/source_zh/main.md +++ b/docs/sphinx_doc/source_zh/main.md @@ -52,7 +52,7 @@ Trinity-RFT 是一个灵活、通用的大语言模型(LLM)强化微调(RF | --- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 运行各种 RFT 模式 | + [快速开始:在 GSM8k 上运行 GRPO](/tutorial/example_reasoning_basic.md)
+ [Off-policy RFT](/tutorial/example_reasoning_advanced.md)
+ [全异步 RFT](/tutorial/example_async_mode.md)
+ [通过 DPO 或 SFT 进行离线学习](/tutorial/example_dpo.md) | | 多轮智能体场景 | + [拼接多轮任务](/tutorial/example_multi_turn.md)
+ [通用多轮任务](/tutorial/example_step_wise.md)
+ [调用智能体框架中的 ReAct 工作流](/tutorial/example_react.md) | -| 数据流水线进阶能力 | + [Rollout 任务混合与选取](/tutorial/develop_selector.md)
+ [在线任务课程](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([论文](https://arxiv.org/pdf/2510.26374))
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](/tutorial/example_data_functionalities.md) | +| 数据流水线进阶能力 | + [Rollout 任务混合与选取](/tutorial/develop_selector.md)
+ [在线任务选择](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([论文](https://arxiv.org/pdf/2510.26374))
+ [经验回放](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [高级数据处理能力 & Human-in-the-loop](/tutorial/example_data_functionalities.md) | | RL 算法开发/研究 | + [使用 Trinity-RFT 进行 RL 算法开发](/tutorial/example_mix_algo.md) ([论文](https://arxiv.org/pdf/2508.11408))
+ 不可验证的领域:[RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [可训练 RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [研究项目: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([论文](https://arxiv.org/abs/2509.24203)) | | 深入认识 Trinity-RFT | + [完整配置指南](/tutorial/trinity_configs.md)
+ [用于快速验证和实验的 Benchmark 工具](https://github.com/modelscope/Trinity-RFT/tree/main/benchmark/README.md)
+ [理解 explorer-trainer 同步逻辑](/tutorial/synchronizer.md) | diff --git a/examples/bots/README_zh.md b/examples/bots/README_zh.md index 7649c66a50..9f20ba0e4b 100644 --- a/examples/bots/README_zh.md +++ b/examples/bots/README_zh.md @@ -41,7 +41,7 @@ trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/workflow ### 完整复现 -想要完整复线我们论文中的结果,请从[这里](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip)下载verl版本的框架。 +想要完整复现我们论文中的结果,请从[这里](https://dail-wlcb.oss-cn-wulanchabu.aliyuncs.com/public/BOTS_verl_version.zip)下载verl版本的框架。 ### 引用 如果你觉得这个代码仓库有帮助,请引用: From b74ef8935afc54da745c582d7a7b6bb908f41f4b Mon Sep 17 00:00:00 2001 From: ShenQianli Date: Thu, 6 Nov 2025 16:23:48 +0800 Subject: [PATCH 10/12] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5d90d647af..124bfc2bb0 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni | --- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Run diverse RFT modes | + [Quick example: GRPO on GSM8k](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [Fully asynchronous RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [Offline learning by DPO or SFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | | Multi-step agentic scenarios | + [Concatenated multi-turn workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [General multi-step workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [ReAct workflow with an agent framework](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | -| Advanced data pipelines | + [Rollout task mixing and selection](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([paper](https://arxiv.org/pdf/2510.26374))
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html)
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay) | +| Advanced data pipelines | + [Rollout task mixing and selection](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([paper](https://arxiv.org/pdf/2510.26374))
+ [Advanced data processing & human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html)
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay) | | Algorithm development / research | + [RL algorithm development with Trinity-RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_mix_algo.html) ([paper](https://arxiv.org/pdf/2508.11408))
+ Non-verifiable domains: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [trainable RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [Research project: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([paper](https://arxiv.org/abs/2509.24203)) | | Going deeper into Trinity-RFT | + [Full configurations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [Benchmark toolkit for quick verification and experimentation](./benchmark/README.md)
+ [Understand the coordination between explorer and trainer](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | From f39ac488e0f7224975ff0161fe014fc532f7b395 Mon Sep 17 00:00:00 2001 From: ShenQianli Date: Thu, 6 Nov 2025 16:33:14 +0800 Subject: [PATCH 11/12] Update README --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 124bfc2bb0..37cbf82756 100644 --- a/README.md +++ b/README.md @@ -67,13 +67,13 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni ## 🔨 Tutorials and Guidelines -| Category | Tutorial / Guideline | -| --- |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Run diverse RFT modes | + [Quick example: GRPO on GSM8k](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [Fully asynchronous RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [Offline learning by DPO or SFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | -| Multi-step agentic scenarios | + [Concatenated multi-turn workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [General multi-step workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [ReAct workflow with an agent framework](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | -| Advanced data pipelines | + [Rollout task mixing and selection](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([paper](https://arxiv.org/pdf/2510.26374))
+ [Advanced data processing & human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html)
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay) | +| Category | Tutorial / Guideline | +| --- |------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Run diverse RFT modes | + [Quick example: GRPO on GSM8k](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_basic.html)
+ [Off-policy RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_reasoning_advanced.html)
+ [Fully asynchronous RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_async_mode.html)
+ [Offline learning by DPO or SFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_dpo.html) | +| Multi-step agentic scenarios | + [Concatenated multi-turn workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_multi_turn.html)
+ [General multi-step workflow](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_step_wise.html)
+ [ReAct workflow with an agent framework](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_react.html) | +| Advanced data pipelines | + [Rollout task mixing and selection](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)
+ [Online task curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ([paper](https://arxiv.org/pdf/2510.26374))
+ [Experience replay](https://github.com/modelscope/Trinity-RFT/tree/main/examples/ppo_countdown_exp_replay)
+ [Advanced data processing & human-in-the-loop](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) | | Algorithm development / research | + [RL algorithm development with Trinity-RFT](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_mix_algo.html) ([paper](https://arxiv.org/pdf/2508.11408))
+ Non-verifiable domains: [RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_ruler), [trainable RULER](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_gsm8k_trainable_ruler), [rubric-as-reward](https://github.com/modelscope/Trinity-RFT/tree/main/examples/grpo_rubric_as_reward)
+ [Research project: group-relative REINFORCE](https://github.com/modelscope/Trinity-RFT/tree/main/examples/rec_gsm8k) ([paper](https://arxiv.org/abs/2509.24203)) | -| Going deeper into Trinity-RFT | + [Full configurations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [Benchmark toolkit for quick verification and experimentation](./benchmark/README.md)
+ [Understand the coordination between explorer and trainer](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | +| Going deeper into Trinity-RFT | + [Full configurations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)
+ [Benchmark toolkit for quick verification and experimentation](./benchmark/README.md)
+ [Understand the coordination between explorer and trainer](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/synchronizer.html) | > [!NOTE] From 888cee9ca1c58e4323a5a9feca920bffa069997e Mon Sep 17 00:00:00 2001 From: ShenQianli Date: Thu, 6 Nov 2025 16:46:39 +0800 Subject: [PATCH 12/12] Update README --- README.md | 2 +- README_zh.md | 2 +- docs/sphinx_doc/source/main.md | 2 +- docs/sphinx_doc/source_zh/main.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 37cbf82756..2d1404bdb3 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni * 📊 For data engineers. [[tutorial]](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/develop_operator.html) - Create datasets and build data pipelines for cleaning, augmentation, and human-in-the-loop scenarios. - - Example: [Data Processing](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html) + - Example: [Data Processing Foundations](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/example_data_functionalities.html), [Online Task Curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ## 🌟 Key Features diff --git a/README_zh.md b/README_zh.md index 3644b3330b..e8700d83ea 100644 --- a/README_zh.md +++ b/README_zh.md @@ -32,7 +32,7 @@ Trinity-RFT 是一个灵活、通用的大语言模型(LLM)强化微调(RF * 📊 面向数据工程师。[[教程]](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/develop_operator.html) - 设计针对任务定制的数据集,构建处理流水线以支持数据清洗、增强以及人类参与场景 - - 示例:[数据处理](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/example_data_functionalities.html) + - 示例:[数据处理基础](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/example_data_functionalities.html),[在线任务选择](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) # 🌟 核心特性 diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index ae2f91dc65..c44ff90401 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -12,7 +12,7 @@ Trinity-RFT is a flexible, general-purpose framework for reinforcement fine-tuni * 📊 For data engineers. [[tutorial]](/tutorial/develop_operator.md) - Create datasets and build data pipelines for cleaning, augmentation, and human-in-the-loop scenarios. - - Example: [Data Processing](/tutorial/example_data_functionalities.md) + - Example: [Data Processing Foundations](/tutorial/example_data_functionalities.md), [Online Task Curriculum](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) ## 🌟 Key Features diff --git a/docs/sphinx_doc/source_zh/main.md b/docs/sphinx_doc/source_zh/main.md index fa4daa40c7..7f1c871998 100644 --- a/docs/sphinx_doc/source_zh/main.md +++ b/docs/sphinx_doc/source_zh/main.md @@ -12,7 +12,7 @@ Trinity-RFT 是一个灵活、通用的大语言模型(LLM)强化微调(RF * 📊 面向数据工程师。[[教程]](/tutorial/develop_operator.md) - 设计针对任务定制的数据集,构建处理流水线以支持数据清洗、增强以及人类参与场景 - - 示例:[数据处理](/tutorial/example_data_functionalities.md) + - 示例:[数据处理基础](/tutorial/example_data_functionalities.md),[在线任务选择](https://github.com/modelscope/Trinity-RFT/tree/main/examples/bots) # 🌟 核心特性