Merged
Changes from all commits (20 commits)
434d2f5  PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine (nuzant, Jul 14, 2025)
d8038b2  PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine (garrett4wade, Jul 14, 2025)
724628e  PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngin… (garrett4wade, Jul 14, 2025)
8a15551  PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment u… (garrett4wade, Jul 14, 2025)
3f95968  PullRequest: 358 [lite] Support GRPO training locally with the GSM8k … (garrett4wade, Jul 15, 2025)
c75dcaf  merge (garrett4wade, Jul 16, 2025)
b2bd639  PullRequest: 368 [lite] Refactor train engine after merging contribut… (garrett4wade, Jul 16, 2025)
b56f599  PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation (garrett4wade, Jul 16, 2025)
ddabd9c  PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher (nuzant, Jul 21, 2025)
2f1b679  PullRequest: 392 [lite] Fix several bugs regarding RL learning and ad… (garrett4wade, Jul 21, 2025)
9c4da33  Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite (garrett4wade, Jul 21, 2025)
ab5db3f  . (garrett4wade, Jul 21, 2025)
4dd4a22  . (garrett4wade, Jul 21, 2025)
415cc6b  . (garrett4wade, Jul 21, 2025)
9db11d1  . (garrett4wade, Jul 21, 2025)
8a610f3  . (garrett4wade, Jul 21, 2025)
2d3c6ac  . (garrett4wade, Jul 22, 2025)
8e96646  . (garrett4wade, Jul 22, 2025)
5d58cb0  . (garrett4wade, Jul 22, 2025)
fd47cb4  Merge branch 'lite' of https://github.com/inclusionAI/AReaL into fw/l… (garrett4wade, Jul 22, 2025)
87 changes: 87 additions & 0 deletions arealite/workflow/multi_turn.py
@@ -0,0 +1,87 @@
import uuid

import torch
from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import GenerationHyperparameters
from arealite.api.engine_api import InferenceEngine
from arealite.api.io_struct import LLMRequest
from arealite.api.workflow_api import RolloutWorkflow
from arealite.utils.data import concat_padded_tensors


class MultiTurnWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn,
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        max_turns: int,
        turn_discount: float,
    ):
        self.reward_fn = reward_fn
        self.gconfig = gconfig
        self.tokenizer = tokenizer
        self.max_turns = max_turns
        self.turn_discount = turn_discount

    async def arun_episode(self, engine: InferenceEngine, data):
        # Placeholders for the results
        seq, logprobs, loss_mask, versions = [], [], [], []
        messages = data["messages"]
        # Run multi-turn rollout until the answer is correct or max_turns is reached
        t = reward = 0
        discount = 1.0  # must start at 1.0, otherwise the reward is always zeroed out
        rid = uuid.uuid4().hex
        while reward == 0 and t < self.max_turns:
            # Append a retry prompt if the previous answer was incorrect
            if t > 0:
                messages += [
                    {"role": "assistant", "content": completions_str},
                    {
                        "role": "user",
                        "content": "Your answer is not correct. Please try to answer it again.",
                    },
                ]
            # Convert the conversation into input_ids
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
            )
            # Send a generate request to get the response
            req = LLMRequest(
                rid=rid,
                input_ids=input_ids,
                gconfig=self.gconfig.new(n_samples=1),
            )
            resp = await engine.agenerate(req)
            # Compute the reward: 1 for correct and 0 otherwise
            prompt_str = self.tokenizer.decode(input_ids)
            completions_str = self.tokenizer.decode(resp.output_tokens)
            reward = self.reward_fn(
                prompt=prompt_str,
                completions=completions_str,
                prompt_ids=resp.input_tokens,
                completion_ids=resp.output_tokens,
                **data,
            )
            # Append this turn's results; only the newly added prompt tokens are masked out
            input_len = len(resp.input_tokens) - len(seq)
            seq += resp.input_tokens[-input_len:] + resp.output_tokens
            logprobs += [0.0] * input_len + resp.output_logprobs
            loss_mask += [0] * input_len + [1] * resp.output_len
            versions += [-1] * input_len + resp.output_versions
            # Advance the turn counter and decay the reward
            t += 1
            discount *= self.turn_discount
        res = dict(
            seq=torch.tensor(seq),
            logprobs=torch.tensor(logprobs),
            loss_mask=torch.tensor(loss_mask),
            versions=torch.tensor(versions),
            rewards=torch.tensor([float(reward * discount)]),
            attention_mask=torch.ones(len(seq), dtype=torch.bool),
        )
        res = {k: v.unsqueeze(0) for k, v in res.items()}
        return concat_padded_tensors([res])
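For reviewers who want to exercise this class, below is a minimal, untested sketch of how the workflow could be driven for one episode. The `gsm8k_reward_fn`, `tokenizer`, `engine`, and `data` objects are placeholders rather than part of this PR, and the `GenerationHyperparameters` field names are assumptions inferred from how `gconfig` is used above.

```python
import asyncio

from arealite.api.cli_args import GenerationHyperparameters
from arealite.workflow.multi_turn import MultiTurnWorkflow

# Hypothetical reward function matching the keyword arguments that
# arun_episode passes; extra dataset fields (e.g. "answer") arrive via **data.
def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, **data):
    # Placeholder check; a real implementation would parse the final number.
    return float(data["answer"].strip() in completions)

workflow = MultiTurnWorkflow(
    reward_fn=gsm8k_reward_fn,
    gconfig=GenerationHyperparameters(max_new_tokens=512),
    tokenizer=tokenizer,  # a PreTrainedTokenizerFast loaded elsewhere
    max_turns=3,
    turn_discount=0.9,
)

# `engine` is an InferenceEngine (e.g. the remote SGLang engine from this PR
# series); `data` is one dataset row with "messages" and "answer" keys.
result = asyncio.run(workflow.arun_episode(engine, data))
```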
5 changes: 5 additions & 0 deletions docs/_toc.yml
@@ -39,3 +39,8 @@ parts:
- caption: Contributing
chapters:
- file: contrib
- caption: Customization (Legacy)
chapters:
- file: legacy/customization/dataset
- file: legacy/customization/agent
- file: legacy/customization/algorithm