Merged
Commits
718 commits
6d3073b
0707_1
antoinegg1 Jul 7, 2025
32100c6
0707_1
antoinegg1 Jul 7, 2025
79af776
0707_2
antoinegg1 Jul 7, 2025
5f6bdcc
0707_2
antoinegg1 Jul 7, 2025
17ed423
Merge branch 'lcy/refactor' into fw/refactor
antoinegg1 Jul 7, 2025
9cdf903
Merge branch 'lcy/refactor' into fw/refactor
antoinegg1 Jul 7, 2025
7d9f41b
0703_3
antoinegg1 Jul 7, 2025
3132862
0703_3
antoinegg1 Jul 7, 2025
d15f131
r
antoinegg1 Jul 7, 2025
1dfe91c
add api
garrett4wade Jul 7, 2025
1006be8
add directory structure
garrett4wade Jul 7, 2025
28c9479
add tests template
garrett4wade Jul 7, 2025
db590ea
p
antoinegg1 Jul 7, 2025
fe3c27f
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 7, 2025
b6e19db
format
garrett4wade Jul 7, 2025
5010107
Merge branch 'fw/refactor' of https://code.alipay.com/inclusionAI/ARe…
garrett4wade Jul 7, 2025
43b07b3
fix
antoinegg1 Jul 7, 2025
5409498
fix
antoinegg1 Jul 7, 2025
132b755
fix
antoinegg1 Jul 7, 2025
6710d5f
Merge branch 'lite' of https://code.alipay.com/inclusionAI/AReaL into…
garrett4wade Jul 7, 2025
3a0f1e5
checkout previous impl
garrett4wade Jul 7, 2025
95c315e
checkout previous implementations
garrett4wade Jul 7, 2025
3b2f43a
checkout prev impl
garrett4wade Jul 7, 2025
9c0c094
refactor
antoinegg1 Jul 7, 2025
e251abb
add remote sglang engine
garrett4wade Jul 7, 2025
7ab6755
Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite
garrett4wade Jul 7, 2025
cf0db6a
format
garrett4wade Jul 7, 2025
34a64a9
0707_6
antoinegg1 Jul 7, 2025
9dd893c
0707_7
antoinegg1 Jul 7, 2025
57b9b94
add readme
garrett4wade Jul 7, 2025
645b58c
refactor1
antoinegg1 Jul 7, 2025
90f4cf0
0707_undone
antoinegg1 Jul 7, 2025
b006b31
f
antoinegg1 Jul 7, 2025
aced39b
0708_1
antoinegg1 Jul 8, 2025
6018376
Merge remote-tracking branch 'origin/lcy/refactor' into lcy/refactor
antoinegg1 Jul 8, 2025
74a2eba
0708_2
antoinegg1 Jul 8, 2025
fcfa067
0708_3
antoinegg1 Jul 8, 2025
b584cd2
0708_7
antoinegg1 Jul 8, 2025
3d3f682
0708_4
antoinegg1 Jul 8, 2025
184f9e8
0709_1
antoinegg1 Jul 9, 2025
2b6f962
0709_2
antoinegg1 Jul 9, 2025
e7991fc
0709_3
antoinegg1 Jul 9, 2025
8771778
PullRequest: 331 [lite] Support remote sglang engine with correspondi…
garrett4wade Jul 9, 2025
223cafd
0709_4
antoinegg1 Jul 9, 2025
c01052a
0709_5
antoinegg1 Jul 9, 2025
605342d
0709_
antoinegg1 Jul 9, 2025
7379a9d
0709_6
antoinegg1 Jul 9, 2025
8a7d656
0709_7
antoinegg1 Jul 9, 2025
7a438c0
PullRequest: 336 add wrapper
kdada Jul 9, 2025
92f144e
0709_7
antoinegg1 Jul 9, 2025
3eaf620
0709_8
antoinegg1 Jul 9, 2025
2edcd2a
0709_9
antoinegg1 Jul 9, 2025
15dfbe8
PullRequest: 332 [lite] Support FSDP engines
garrett4wade Jul 9, 2025
7be4ab0
PullRequest: 339 [Fix] Fix some minor issues to pass all tests.
garrett4wade Jul 9, 2025
ee6f5a8
chore: empty commit
futrime Jul 9, 2025
8e201ef
ci: build images on demand
futrime Jul 9, 2025
a70cd28
ci: fix on demand condition
futrime Jul 9, 2025
a203c7c
ci: fix env sha
futrime Jul 9, 2025
c38cffc
PullRequest: 340 [lite] Refactor trainer API into utilities and remov…
garrett4wade Jul 10, 2025
42c717b
Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite
garrett4wade Jul 10, 2025
0cd58b5
0710_1
antoinegg1 Jul 10, 2025
d48bf00
Merge branch 'main' of https://github.com/inclusionAI/AReaL into lite
garrett4wade Jul 10, 2025
496413f
0710_2
antoinegg1 Jul 10, 2025
3bf9c85
[Fix] Merge previous contributions from fw/refactor to lite (#163)
garrett4wade Jul 10, 2025
e57cb20
0710_2
antoinegg1 Jul 10, 2025
3122d90
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 10, 2025
50cf951
0710_3
antoinegg1 Jul 10, 2025
27c06b9
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 10, 2025
8d4b8dc
[Doc] Add an instruction about how to run the SFT example. (#164)
garrett4wade Jul 10, 2025
622781d
0710_3
antoinegg1 Jul 10, 2025
0d3c579
0710_3
antoinegg1 Jul 10, 2025
0640d5a
0710_5
antoinegg1 Jul 10, 2025
e1f2853
0710_4
antoinegg1 Jul 10, 2025
8c7affe
merge_lite
antoinegg1 Jul 10, 2025
395affa
merge_2
antoinegg1 Jul 10, 2025
4c0cd02
merge_3
antoinegg1 Jul 10, 2025
16e087d
0711_1
antoinegg1 Jul 11, 2025
fc51cd5
0711_2
antoinegg1 Jul 11, 2025
2af8cd5
0711_3
antoinegg1 Jul 11, 2025
b3fed3c
0711_4
antoinegg1 Jul 11, 2025
cad4488
0711_6
antoinegg1 Jul 11, 2025
437f7a7
0711_7
antoinegg1 Jul 11, 2025
35ab78b
0711_8
antoinegg1 Jul 11, 2025
04e432e
0711_8
antoinegg1 Jul 11, 2025
d6ff9e7
0711_9
antoinegg1 Jul 11, 2025
036aa9a
0711_10
antoinegg1 Jul 11, 2025
10a3731
0711-11
antoinegg1 Jul 11, 2025
c5f0235
[Fix] Fix CI running condition for lite. (#172)
garrett4wade Jul 12, 2025
434d2f5
PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
nuzant Jul 14, 2025
d8038b2
PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
garrett4wade Jul 14, 2025
d79f0dc
0714_1
antoinegg1 Jul 14, 2025
eb524e8
0714_2
antoinegg1 Jul 14, 2025
84cd936
0714_3
antoinegg1 Jul 14, 2025
b74f240
0714_3
antoinegg1 Jul 14, 2025
101c4e9
0714_5
antoinegg1 Jul 14, 2025
724628e
PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngin…
garrett4wade Jul 14, 2025
8a15551
PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment u…
garrett4wade Jul 14, 2025
8cc9b1f
added LocalSGlangEngine and test (#170)
PrinsYin Jul 15, 2025
9ed043f
format (#174)
garrett4wade Jul 15, 2025
3f95968
PullRequest: 358 [lite] Support GRPO training locally with the GSM8k …
garrett4wade Jul 15, 2025
69f5450
merge1
antoinegg1 Jul 15, 2025
0435aa5
merge2
antoinegg1 Jul 15, 2025
e960c17
0715_1
antoinegg1 Jul 15, 2025
ec60071
0715_2
antoinegg1 Jul 15, 2025
325ef6e
0715_2
antoinegg1 Jul 15, 2025
ef4215d
[Feat][Refactor]Support DeepSpeed AutoTP; Refactor hf_engine.py and u…
Jayon02 Jul 16, 2025
517353c
fix ci (#175)
garrett4wade Jul 16, 2025
4490b11
[Feature] Add pre-commit (#178)
garrett4wade Jul 16, 2025
e13db01
[lite] [refactor] Add GSM8k GRPO example. (#179)
garrett4wade Jul 16, 2025
c75dcaf
merge
garrett4wade Jul 16, 2025
712a4ab
0716_1
antoinegg1 Jul 16, 2025
5efd861
0716_2
antoinegg1 Jul 16, 2025
b2bd639
PullRequest: 368 [lite] Refactor train engine after merging contribut…
garrett4wade Jul 16, 2025
b56f599
PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation
garrett4wade Jul 16, 2025
0283cfa
change doc (#180)
garrett4wade Jul 16, 2025
29e164a
[Fix] [lite] Merge from the internal repo to fix GRPO bugs and refact…
garrett4wade Jul 16, 2025
74fcc38
0716_3
antoinegg1 Jul 16, 2025
1cbb642
0716_4
antoinegg1 Jul 16, 2025
2419d44
0716_4
antoinegg1 Jul 16, 2025
8596ef4
0716_5
antoinegg1 Jul 16, 2025
3c7c739
0717_1
antoinegg1 Jul 17, 2025
871b25a
0717_3
antoinegg1 Jul 17, 2025
0a2b9db
0717_3
antoinegg1 Jul 17, 2025
510313b
0717_4
antoinegg1 Jul 17, 2025
ce796f2
0717_5
antoinegg1 Jul 17, 2025
e9dc112
0717_6
antoinegg1 Jul 17, 2025
587544b
0717_6
antoinegg1 Jul 17, 2025
a032333
0717_6
antoinegg1 Jul 17, 2025
c0176b5
0718_1
antoinegg1 Jul 18, 2025
0e27a10
0718_2
antoinegg1 Jul 18, 2025
a08043e
0718_4
antoinegg1 Jul 18, 2025
090850a
0718_5
antoinegg1 Jul 18, 2025
ddabd9c
PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher
nuzant Jul 21, 2025
ade6a1d
Merge remote-tracking branch 'origin/lite' into lcy/refactor
antoinegg1 Jul 21, 2025
c8952f0
merge_0721
antoinegg1 Jul 21, 2025
25b65a2
0721_1
antoinegg1 Jul 21, 2025
2f1b679
PullRequest: 392 [lite] Fix several bugs regarding RL learning and ad…
garrett4wade Jul 21, 2025
588ffd2
0721_2
antoinegg1 Jul 21, 2025
a157510
0721_3
antoinegg1 Jul 21, 2025
8f26371
merge_0721_2
antoinegg1 Jul 21, 2025
f68a4f6
Implement fsdp distributed update (#183)
PrinsYin Jul 21, 2025
9c4da33
Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite
garrett4wade Jul 21, 2025
18f8a05
[Feature] [lite] Merge from internal dev repo (#189)
garrett4wade Jul 21, 2025
4804b05
[Refactor] Rename files in arealite before release. (#190)
garrett4wade Jul 21, 2025
9fcc177
0721_4
antoinegg1 Jul 21, 2025
ab5db3f
.
garrett4wade Jul 21, 2025
4dd4a22
.
garrett4wade Jul 21, 2025
339e87a
0721_formal
antoinegg1 Jul 21, 2025
67760d3
0721_formal
antoinegg1 Jul 21, 2025
60ac722
0721_merge3
antoinegg1 Jul 21, 2025
a2d6d21
0721_merge4
antoinegg1 Jul 21, 2025
b4e8215
0721_merge5
antoinegg1 Jul 21, 2025
475c35c
0721_6
antoinegg1 Jul 21, 2025
aed6a90
Merge remote-tracking branch 'backup/lite' into lcy/refactor
antoinegg1 Jul 21, 2025
c295614
0721_merge6
antoinegg1 Jul 21, 2025
f451dbd
0721_merge7
antoinegg1 Jul 21, 2025
80862b7
0721_8
antoinegg1 Jul 21, 2025
79e2a81
0722_1
antoinegg1 Jul 22, 2025
3d2f7a9
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 22, 2025
7199ce2
0722_2
antoinegg1 Jul 22, 2025
eba0b5f
0722_3
antoinegg1 Jul 22, 2025
ba16d4e
add quickstart (#194)
nuzant Jul 22, 2025
6239633
[doc] [lite] Add customization docs for AReaLite. (#191)
garrett4wade Jul 22, 2025
229f101
0722_4
antoinegg1 Jul 22, 2025
ea12141
0722_4
antoinegg1 Jul 22, 2025
c27a51b
0722_5
antoinegg1 Jul 22, 2025
5c0662f
0722_6
antoinegg1 Jul 22, 2025
af2f80c
0722_7
antoinegg1 Jul 22, 2025
8815be6
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 22, 2025
eff8f09
0723_1
antoinegg1 Jul 23, 2025
6bde86a
reformatted
antoinegg1 Jul 23, 2025
52c9447
clang-reformatted
antoinegg1 Jul 23, 2025
25884f5
clang-reformatted2
antoinegg1 Jul 23, 2025
391bd85
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 23, 2025
dd14838
0723_1
antoinegg1 Jul 23, 2025
9ec2c3f
0723_1
antoinegg1 Jul 23, 2025
4041afb
0723_1
antoinegg1 Jul 23, 2025
2a2e2fe
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 23, 2025
8e82c59
0723_merge3
antoinegg1 Jul 23, 2025
d12dec2
0723_4
antoinegg1 Jul 23, 2025
82442b8
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 23, 2025
00b5d87
0723_reformatted_5
antoinegg1 Jul 23, 2025
9a16605
0724_1
antoinegg1 Jul 24, 2025
6c28d52
0724_1
antoinegg1 Jul 24, 2025
c816a3c
0724_merge1
antoinegg1 Jul 24, 2025
e97e33f
0724_merge2
antoinegg1 Jul 24, 2025
176ec4b
0724_merge3
antoinegg1 Jul 24, 2025
5118cfa
0724_merge3
antoinegg1 Jul 24, 2025
5690b52
0724_merge4
antoinegg1 Jul 24, 2025
311bcd7
[lite] [feature] Bump to SGLang v0.4.9.post2 and use NCCL to update w…
garrett4wade Jul 24, 2025
84e2d75
Merge remote-tracking branch 'backup/lite' into lcy/refactor
antoinegg1 Jul 24, 2025
1bc9310
0724_merge5
antoinegg1 Jul 24, 2025
13fc236
0724_merge6
antoinegg1 Jul 24, 2025
27fd51a
0724_merge7
antoinegg1 Jul 24, 2025
e705db1
0724_merge8
antoinegg1 Jul 24, 2025
e26a43a
[Docs] [lite] Add example code walkthrough documentation. (#197)
nuzant Jul 24, 2025
f299740
[lite] [doc] Add AReaLite design doc as README (#198)
garrett4wade Jul 24, 2025
aa6c28e
Merge branch 'main' into lite
garrett4wade Jul 24, 2025
6aeeabf
0724_4
antoinegg1 Jul 24, 2025
f5924b1
0724_merge7
antoinegg1 Jul 24, 2025
84be9c9
Merge remote-tracking branch 'backup/lite' into lcy/refactor
antoinegg1 Jul 24, 2025
6255ad5
0724-merge8
antoinegg1 Jul 24, 2025
b8549ac
0724_merge8
antoinegg1 Jul 24, 2025
4198cd6
0725_1
antoinegg1 Jul 25, 2025
3c272ff
0725_6
antoinegg1 Jul 25, 2025
8eaced4
0725_7
antoinegg1 Jul 25, 2025
4f8b17f
0725_4padded_image
antoinegg1 Jul 25, 2025
cc3c6bb
0725_9padded_image
antoinegg1 Jul 25, 2025
60ac19a
0725_10padded_image
antoinegg1 Jul 25, 2025
fb1796d
0725_11
antoinegg1 Jul 25, 2025
a4ad671
0725
antoinegg1 Jul 25, 2025
6b8bfcf
0725_12
antoinegg1 Jul 25, 2025
4ff813a
0725_format
antoinegg1 Jul 25, 2025
e2a3579
Add self-hosted runner support (#199)
futrime Jul 28, 2025
7fb6a80
[WIP][feat] Initial support for VLMs, add Qwen2VL SFT test and Qwen2.…
antoinegg1 Jul 28, 2025
68b4b02
0731
antoinegg1 Jul 31, 2025
1ae006c
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 31, 2025
d6a6240
0731_1
antoinegg1 Jul 31, 2025
c5cd21d
0731_2
antoinegg1 Jul 31, 2025
78d0367
0731_2
antoinegg1 Jul 31, 2025
2e0af5d
0731_3
antoinegg1 Jul 31, 2025
c3c986a
0731_4
antoinegg1 Jul 31, 2025
47972dd
0801_1
antoinegg1 Aug 1, 2025
d3084c3
0801_2
antoinegg1 Aug 1, 2025
6724134
0804_1
antoinegg1 Aug 4, 2025
c5ec8d1
0804_2
antoinegg1 Aug 4, 2025
866905e
0804_2
antoinegg1 Aug 4, 2025
3062f95
0804_merge
antoinegg1 Aug 4, 2025
6bb9082
0804_merge2
antoinegg1 Aug 4, 2025
be294c6
0804_5
antoinegg1 Aug 4, 2025
82646dc
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Aug 4, 2025
7be94c0
0805_1
antoinegg1 Aug 5, 2025
3f741ff
0805_2
antoinegg1 Aug 5, 2025
3368fe6
0805_3
antoinegg1 Aug 5, 2025
0761668
0805_2
antoinegg1 Aug 5, 2025
78c62c1
0806
antoinegg1 Aug 6, 2025
6d6ddc9
0806_2
antoinegg1 Aug 6, 2025
6224db6
0806_merge1
antoinegg1 Aug 6, 2025
506b28c
0806_merge2
antoinegg1 Aug 6, 2025
818fa4f
0806_format1
antoinegg1 Aug 6, 2025
00f49f6
0806_merge3
antoinegg1 Aug 6, 2025
77210b5
0806_4
antoinegg1 Aug 6, 2025
94a7360
0806_6
antoinegg1 Aug 6, 2025
975b09b
0806_7
antoinegg1 Aug 6, 2025
96a9ea8
0806_formatted2
antoinegg1 Aug 6, 2025
2ca7fb7
Merge branch 'main' of https://github.com/inclusionAI/AReaL into lcy/…
garrett4wade Aug 7, 2025
06990e9
fix
garrett4wade Aug 7, 2025
847fdd8
revert examples
garrett4wade Aug 7, 2025
6bc6ef3
.
garrett4wade Aug 7, 2025
63d780d
.
garrett4wade Aug 7, 2025
2 changes: 1 addition & 1 deletion areal/README.md
@@ -558,4 +558,4 @@ class MyRolloutWorkflow:
`TrainEngine` respectively. Controllers handle engine deployment across the cluster and
manage data distribution, invoking engine methods through remote procedure calls (RPCs).
This architecture enables distributed operation while maintaining familiar interfaces
for users.
for users.
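
As a rough illustration of the controller-to-engine RPC pattern described above, here is a minimal sketch; the names (EngineClient, Controller, train_batch) and the sharding scheme are illustrative assumptions, not AReaL's actual API.

# Hypothetical sketch of a controller fanning one engine call out over RPC.
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List


class EngineClient:
    """Stand-in for an RPC client bound to one remote TrainEngine replica."""

    def call(self, method: str, payload: Any) -> Any:
        raise NotImplementedError  # a real client would issue the remote procedure call


class Controller:
    def __init__(self, clients: List[EngineClient]):
        self.clients = clients

    def train_batch(self, batch: List[Any]) -> List[Any]:
        # Shard the batch across engines, then invoke the same method on every replica.
        shards = [batch[i :: len(self.clients)] for i in range(len(self.clients))]
        with ThreadPoolExecutor(max_workers=len(self.clients)) as pool:
            futures = [
                pool.submit(client.call, "train_batch", shard)
                for client, shard in zip(self.clients, shards)
            ]
            return [f.result() for f in futures]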
3 changes: 3 additions & 0 deletions areal/api/cli_args.py
@@ -633,6 +633,9 @@ class DatasetConfig:
default=0, metadata={"help": "Number of worker processes for data loading"}
)
drop_last: bool = field(default=True)
reward_fn: Optional[str] = field(
default=None,
)


@dataclass
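A small usage sketch for the new reward_fn field. Resolving the string mirrors areal/reward/__init__.py added below; how the value reaches the config (CLI flag, YAML, etc.) is an assumption here, and other DatasetConfig fields are not shown in this diff.

# Hypothetical wiring of DatasetConfig.reward_fn to the reward registry.
from areal.reward import get_custom_reward_fn

reward_fn_name = "geometry3k"  # e.g. taken from cfg.train_dataset.reward_fn
reward_fn = get_custom_reward_fn(reward_fn_name) if reward_fn_name else None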
22 changes: 17 additions & 5 deletions areal/dataset/__init__.py
@@ -2,7 +2,7 @@

import transformers

VALID_DATASETS = ["gsm8k", "clevr_count_70k"]
VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k"]


def get_custom_dataset(
@@ -17,25 +17,37 @@
):

if "gsm8k" in path and type == "sft":
from areal.dataset.gsm8k import get_gsm8k_sft_dataset
from .gsm8k import get_gsm8k_sft_dataset

return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size, **kwargs)
elif "gsm8k" in path and type == "rl":
from areal.dataset.gsm8k import get_gsm8k_rl_dataset
from .gsm8k import get_gsm8k_rl_dataset

return get_gsm8k_rl_dataset(path, split, rank, world_size, **kwargs)
elif "clevr_count_70k" in path and type == "sft":
from areal.dataset.clevr_count_70k import get_clevr_count_70k_sft_dataset
from .clevr_count_70k import get_clevr_count_70k_sft_dataset

return get_clevr_count_70k_sft_dataset(
path, split, processor, rank, world_size, **kwargs
)
elif "clevr_count_70k" in path and type == "rl":
from areal.dataset.clevr_count_70k import get_clevr_count_70k_rl_dataset
from .clevr_count_70k import get_clevr_count_70k_rl_dataset

return get_clevr_count_70k_rl_dataset(
path, split, processor, rank, world_size, **kwargs
)
elif "geometry3k" in path and type == "sft":
from .geometry3k import get_geometry3k_sft_dataset

return get_geometry3k_sft_dataset(
path, split, processor, rank, world_size, **kwargs
)
elif "geometry3k" in path and type == "rl":
from .geometry3k import get_geometry3k_rl_dataset

return get_geometry3k_rl_dataset(
path, split, processor, rank, world_size, **kwargs
)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {type} is not supported. "
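For reference, a sketch of how the new geometry3k RL branch might be reached. The keyword names follow the dispatch body above (the full signature is collapsed in this view), and the HF dataset path and processor checkpoint are only examples, not fixed by this PR.

# Illustrative call into the new geometry3k RL branch.
from transformers import AutoProcessor

from areal.dataset import get_custom_dataset

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
dataset = get_custom_dataset(
    path="hiyouga/geometry3k",  # any path containing "geometry3k"
    split="train",
    type="rl",
    processor=processor,
    rank=0,
    world_size=1,
)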
141 changes: 141 additions & 0 deletions areal/dataset/geometry3k.py
@@ -0,0 +1,141 @@
import math
from io import BytesIO
from typing import Any, Dict, Optional, Union

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from PIL import Image
from PIL.Image import Image as ImageObject
from torchvision import transforms


def pad_to_square(img: Image.Image, fill=(0, 0, 0)) -> Image.Image:

w, h = img.size
side = max(w, h)
new_img = Image.new(img.mode, (side, side), color=fill)
offset = ((side - w) // 2, (side - h) // 2)
new_img.paste(img, offset)
return new_img


def convert_image(
image: Union[Dict[str, Any], ImageObject, str],
fixed_width: Optional[int] = None,
fixed_height: Optional[int] = None,
) -> ImageObject:
if (
fixed_width is not None
and fixed_height is not None
and (image.width != fixed_width or image.height != fixed_height)
):
preprocess = transforms.Compose(
[
transforms.CenterCrop((fixed_width, fixed_height)),  # core operation: center-crop to the fixed size
]
)
image = preprocess(image)
if image.mode != "RGB":
image = image.convert("RGB")
with BytesIO() as output:
image.save(output, format="JPEG")
return output.getvalue()


def get_geometry3k_sft_dataset(path, split, processor, rank, world_size):
"""
"geometry3k": {
"image_key": "images",
"question_key": "problem",
"answer_key": "answer"
},
"""
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
tokenizer = processor.tokenizer

def process_example(example, idx):
# Replace <image> placeholders, resize images, and build the SFT target sequence
images = example["images"]
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
example["problem"] = (
example["problem"].replace("<image>", image_token).replace("different", "")
)
processed_images = []
for image in images:
processed_images.append(convert_image(image, 512, 512))
example["images"] = processed_images
example["seq"] = example["problem"] + example["answer"] + tokenizer.eos_token

return example

dataset = dataset.map(
lambda example, idx: process_example(example, idx),
with_indices=True,
)

def _process(example):
text = example["seq"]
processed_input = processor(
text=[text],
images=example["images"],
padding=False,
return_tensors="pt",
return_length=True,
return_attention_mask=False,
)

example["input_ids"] = processed_input["input_ids"].squeeze(0)
example["pixel_values"] = processed_input["pixel_values"]
example["image_grid_thw"] = processed_input["image_grid_thw"]
answer_token = tokenizer.encode(example["answer"])
loss_mask = [0] * (len(example["input_ids"]) - len(answer_token)) + [1] * len(
answer_token
)
example["loss_mask"] = loss_mask
return example

dataset = dataset.map(
lambda x: _process(x), remove_columns=["images", "seq", "problem", "answer"]
)
return dataset


def get_geometry3k_rl_dataset(path, split, processor, rank, world_size):
dataset = load_dataset(path=path, split=split)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)

def process(sample):
processed_images = [
convert_image(image, 448, 448) for image in sample["images"]
]
if "qwen" in processor.image_processor.image_processor_type.lower():
image_token = "<|vision_start|><|image_pad|><|vision_end|>"
else:
image_token = processor.image_token if processor is not None else "<image>"
system_prompt = {
"role": "system",
"content": (
"Solve the following geometric problem based on the image. You may explain your reasoning before providing the final answer. The answer should be enclosed in [ ] and can be a number, decimal, or LaTeX format (e.g. \frac { 4 }{ 9 } \sqrt { 3 }).\n"
),
}

messages = [
{
"role": "user",
"content": sample["problem"]
.replace("<image>", image_token)
.replace("different", ""),
}
]
messages.insert(0, system_prompt)
messages = processor.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
return {"messages": messages, "images": processed_images}

dataset = dataset.map(process).remove_columns(["problem"])
return dataset
34 changes: 30 additions & 4 deletions areal/engine/base_hf_engine.py
@@ -31,7 +31,11 @@
unsqueeze_mb_list,
)
from areal.utils.fsdp import get_cosine_schedule_with_warmup
from areal.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
from areal.utils.model import (
VALID_VISION_MODELS,
disable_dropout_in_model,
is_qwen2_vl_model,
)
from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
from realhf.base import constants, logging

@@ -253,10 +257,26 @@ def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert (
"pixel_values" in input_ and "image_grid_thw" in input_
), "For vision-language models, pixel_values and image_grid_thw must be present in input_"

if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)
if is_qwen2_vl_model(self.model_config.model_type):
# Create the special t,h,w position IDs for qwen 2.5 VL
attn_mask = input_["attention_mask"]
input_ids = input_["input_ids"]
image_grid_thw = input_.get("image_grid_thw", None)
video_grid_thw = input_.get("video_grid_thw", None)
if image_grid_thw is not None:
image_grid_thw = image_grid_thw.squeeze(1)
if video_grid_thw is not None:
video_grid_thw = video_grid_thw.squeeze(1)
position_ids, _ = self.model.model.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attn_mask
)
# [3, bs, seqlen] -> [bs, seqlen, 3]
position_ids = torch.einsum("ijk->jki", position_ids)
input_["position_ids"] = position_ids
else:
input_ = amend_position_ids(input_)

mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
mb_list.mbs = [pack_tensor_dict(mb) for mb in mb_list.mbs]
@@ -272,6 +292,10 @@ def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)
if is_qwen2_vl_model(self.model_config.model_type):
for mb in mb_list.padded_mbs:
# [1, total_seqlen, 3] -> [3, 1, total_seqlen]
mb["position_ids"] = torch.einsum("ijk->kij", mb["position_ids"])

# FIXME: the resulting max_seqlen is a tensor rather than an integer
# TODO: remove the usage of tensordict
@@ -283,10 +307,12 @@ def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
mb_list.padded_mbs[i] = dict(**mb)
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["cu_seqlens_q"] = mb["cu_seqlens_k"] = mb["cu_seqlens"]
mb["use_cache"] = False
mb["attention_mask"] = dict(full_attention=None)
for mb in mb_list.padded_mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["cu_seqlens_q"] = mb["cu_seqlens_k"] = mb["cu_seqlens"]
mb["use_cache"] = False
mb["attention_mask"] = dict(full_attention=None)

@@ -317,7 +343,6 @@ def train_batch(
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):

outputs = self.model(**padded_mb_input)

logits = outputs.logits.squeeze(0)
@@ -408,6 +433,7 @@ def forward(
for pad_length, padded_mb_input, mb_input in zip(
mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
):

outputs = self.model(**padded_mb_input)
logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
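To make the two einsum permutations in prepare_mb_list concrete, here is a standalone shape check; the batch size and sequence length are illustrative.

# Stand-alone illustration of the position-id layout changes above.
import torch

bs, seqlen = 2, 7
# get_rope_index-style output for Qwen2-VL models: [3, bs, seqlen] (t, h, w axes first)
position_ids = torch.zeros(3, bs, seqlen, dtype=torch.long)

# [3, bs, seqlen] -> [bs, seqlen, 3], so the ids can be split and packed per sequence
per_seq = torch.einsum("ijk->jki", position_ids)
assert per_seq.shape == (bs, seqlen, 3)

# After packing into [1, total_seqlen, 3], restore the layout the model forward expects
packed = per_seq.reshape(1, bs * seqlen, 3)
restored = torch.einsum("ijk->kij", packed)
assert restored.shape == (3, 1, bs * seqlen)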
17 changes: 17 additions & 0 deletions areal/reward/__init__.py
@@ -0,0 +1,17 @@
VALID_REWARD_FN = ["clevr_count_70k", "geometry3k"]


def get_custom_reward_fn(path: str, **kwargs):
if "clevr_count_70k" in path:
from .clevr_count_70k import clevr_count_70k_reward_fn

return clevr_count_70k_reward_fn
elif "geometry3k" in path:
from .geometry3k import geometry3k_reward_fn

return geometry3k_reward_fn
else:
raise ValueError(
f"Reward function {path} is not supported. "
f"Supported reward functions are: {VALID_REWARD_FN}. "
)
27 changes: 27 additions & 0 deletions areal/reward/clevr_count_70k.py
@@ -0,0 +1,27 @@
import re


def extract_answer(pred_str, data_name, use_last_number=True):
match = re.findall(r"\[([0-9\.]+)\]", pred_str)
if match:
return match[-1]

return ""


def clevr_count_70k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
sol = extract_answer(completions, data_name="") # str number
ans = answer

if sol is None:
return 0
if ans is None:
return 0

if sol.strip() == ans.strip():
print(f"completions: {completions}, answer: {answer}")
return 1

return 0
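
A quick illustration of the bracket-extraction contract this reward function relies on; the prompt and completion strings are made up.

# Example of the expected completion format for clevr_count_70k_reward_fn.
from areal.reward.clevr_count_70k import clevr_count_70k_reward_fn

completion = "I count the objects one by one. The final answer is [3]."
score = clevr_count_70k_reward_fn(
    prompt="How many objects are there?",
    completions=completion,
    prompt_ids=None,
    completion_ids=None,
    answer="3",
)
assert score == 1  # "[3]" is extracted and matches the reference answer "3"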
29 changes: 29 additions & 0 deletions areal/reward/geometry3k.py
@@ -0,0 +1,29 @@
import re


def extract_answer(pred_str, data_name, use_last_number=True):
matches = re.findall(r"\[([^\]]+)\]", pred_str)
if matches:
return matches[-1]

return ""


def geometry3k_reward_fn(
prompt, completions, prompt_ids, completion_ids, answer, **kwargs
):
sol = extract_answer(completions, data_name="")  # bracketed answer string, "" if no match
ans = answer
if sol is None:
return 0
if ans is None:
return 0
sol = sol.replace(" ", "")
ans = ans.replace(" ", "")
# print(f"sol: {sol}, ans: {ans}")
from realhf.impl.dataset.math_parser import math_equal

if math_equal(sol, ans):
# print(f"completions: {completions}, answer: {answer}")
return 1
return 0
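
A similar sketch for the geometry3k reward; whether math_equal accepts a given LaTeX or numeric form depends on realhf's math parser, so the strings below are only assumptions.

# Example call for geometry3k_reward_fn; assumes math_equal treats identical
# numeric strings as equal.
from areal.reward.geometry3k import geometry3k_reward_fn

completion = "Using the Pythagorean theorem, the side length is [7.5]."
score = geometry3k_reward_fn(
    prompt="Find the side length based on the image.",
    completions=completion,
    prompt_ids=None,
    completion_ids=None,
    answer="7.5",
)
# Expected to be 1 if math_equal("7.5", "7.5") holds.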