Merged
Changes from all commits
635 commits
ef2da68
vlm_sft_test
antoinegg1 Jun 30, 2025
8007f17
vlm_sft_test
antoinegg1 Jun 30, 2025
719081c
.
garrett4wade Jul 1, 2025
2ce1ece
.
garrett4wade Jul 1, 2025
1e51b2c
Fix unresolved issue in SFTTrainer PR (#139)
nuzant Jul 1, 2025
09f339f
Fix unresolved issue in SFTTrainer PR (#139)
nuzant Jul 1, 2025
0659174
Merge branch 'fw/refactor' of https://github.com/inclusionAI/AReaL in…
garrett4wade Jul 1, 2025
df5ee49
Merge branch 'fw/refactor' of https://github.com/inclusionAI/AReaL in…
garrett4wade Jul 1, 2025
eddfada
Merge branch 'fw/refactor2' of https://code.alipay.com/inclusionAI/AR…
garrett4wade Jul 1, 2025
d1f863c
Merge branch 'fw/refactor2' of https://code.alipay.com/inclusionAI/AR…
garrett4wade Jul 1, 2025
ae79dda
image_process0701
antoinegg1 Jul 1, 2025
3be8639
image_process0701
antoinegg1 Jul 1, 2025
3eba150
image_process0701_2
antoinegg1 Jul 1, 2025
e19a0dc
image_process0701_2
antoinegg1 Jul 1, 2025
3c36e4e
image_process0701_3
antoinegg1 Jul 1, 2025
bda8514
image_process0701_3
antoinegg1 Jul 1, 2025
a19b855
.
garrett4wade Jul 2, 2025
ab7503a
.
garrett4wade Jul 2, 2025
4b8f824
.
garrett4wade Jul 2, 2025
a5299b1
.
garrett4wade Jul 2, 2025
c4b4d90
.
garrett4wade Jul 2, 2025
3a8796b
.
garrett4wade Jul 2, 2025
c2fe048
imageprocess0702
antoinegg1 Jul 2, 2025
334a2b0
imageprocess0702
antoinegg1 Jul 2, 2025
14cad10
image_process0702_2
antoinegg1 Jul 2, 2025
e3929d1
image_process0702_2
antoinegg1 Jul 2, 2025
4139b22
image_process0702_3
antoinegg1 Jul 2, 2025
f9006ba
image_process0702_3
antoinegg1 Jul 2, 2025
40df511
image_process0702_4
antoinegg1 Jul 2, 2025
9583123
image_process0702_4
antoinegg1 Jul 2, 2025
ae463cc
image_process0702_5
antoinegg1 Jul 2, 2025
109be6a
image_process0702_5
antoinegg1 Jul 2, 2025
685045f
image_process0703_1
antoinegg1 Jul 3, 2025
618826e
image_process0703_1
antoinegg1 Jul 3, 2025
2bb0be3
0703_2
antoinegg1 Jul 3, 2025
ea0c65c
0703_2
antoinegg1 Jul 3, 2025
dd647e1
0703_3
antoinegg1 Jul 3, 2025
838774b
0703_3
antoinegg1 Jul 3, 2025
0e293e5
0703_4
antoinegg1 Jul 3, 2025
9669c85
0703_4
antoinegg1 Jul 3, 2025
53657b1
0703_4
antoinegg1 Jul 3, 2025
2900f8c
0703_4
antoinegg1 Jul 3, 2025
c12cc5e
0703_5
antoinegg1 Jul 3, 2025
c8d6d4c
0703_5
antoinegg1 Jul 3, 2025
a8e7a99
0703_6
antoinegg1 Jul 3, 2025
942a39b
0703_6
antoinegg1 Jul 3, 2025
dae6bec
0703_7
antoinegg1 Jul 3, 2025
a68f931
0703_7
antoinegg1 Jul 3, 2025
640a6be
0703_8
antoinegg1 Jul 3, 2025
20f7605
0703_8
antoinegg1 Jul 3, 2025
9b8a2d1
0703_9
antoinegg1 Jul 3, 2025
2c9f014
0703_9
antoinegg1 Jul 3, 2025
17af243
0703_11
antoinegg1 Jul 3, 2025
573cb7e
0703_11
antoinegg1 Jul 3, 2025
0766d19
0703_12
antoinegg1 Jul 3, 2025
a9a37f6
0703_12
antoinegg1 Jul 3, 2025
35505a2
0703_13
antoinegg1 Jul 3, 2025
f00c545
0703_13
antoinegg1 Jul 3, 2025
080f637
0703_14
antoinegg1 Jul 3, 2025
5db1b68
0703_14
antoinegg1 Jul 3, 2025
59dd80c
0703_15
antoinegg1 Jul 3, 2025
193d052
0703_15
antoinegg1 Jul 3, 2025
99633db
0703_16
antoinegg1 Jul 3, 2025
c75230e
0703_16
antoinegg1 Jul 3, 2025
5e2923b
0703-17
antoinegg1 Jul 3, 2025
5f0061d
0703-17
antoinegg1 Jul 3, 2025
f8d1211
0703_18
antoinegg1 Jul 3, 2025
03f54ab
0703_18
antoinegg1 Jul 3, 2025
19d7f94
0703_18
antoinegg1 Jul 3, 2025
1a6eb4a
0703_18
antoinegg1 Jul 3, 2025
86cbe43
0703_19
antoinegg1 Jul 3, 2025
6e13ea6
0703_19
antoinegg1 Jul 3, 2025
c04ee32
0704_1
antoinegg1 Jul 4, 2025
c96ceb9
0704_1
antoinegg1 Jul 4, 2025
aa1de1f
0704_2
antoinegg1 Jul 4, 2025
164c957
0704_2
antoinegg1 Jul 4, 2025
43ced6f
0704_3
antoinegg1 Jul 4, 2025
84258a2
0704_3
antoinegg1 Jul 4, 2025
8f62c9d
.
garrett4wade Jul 4, 2025
89a8d8c
.
garrett4wade Jul 4, 2025
6d3073b
0707_1
antoinegg1 Jul 7, 2025
32100c6
0707_1
antoinegg1 Jul 7, 2025
79af776
0707_2
antoinegg1 Jul 7, 2025
5f6bdcc
0707_2
antoinegg1 Jul 7, 2025
17ed423
Merge branch 'lcy/refactor' into fw/refactor
antoinegg1 Jul 7, 2025
9cdf903
Merge branch 'lcy/refactor' into fw/refactor
antoinegg1 Jul 7, 2025
7d9f41b
0703_3
antoinegg1 Jul 7, 2025
3132862
0703_3
antoinegg1 Jul 7, 2025
d15f131
r
antoinegg1 Jul 7, 2025
db590ea
p
antoinegg1 Jul 7, 2025
fe3c27f
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 7, 2025
5010107
Merge branch 'fw/refactor' of https://code.alipay.com/inclusionAI/ARe…
garrett4wade Jul 7, 2025
43b07b3
fix
antoinegg1 Jul 7, 2025
5409498
fix
antoinegg1 Jul 7, 2025
132b755
fix
antoinegg1 Jul 7, 2025
9c0c094
refactor
antoinegg1 Jul 7, 2025
34a64a9
0707_6
antoinegg1 Jul 7, 2025
9dd893c
0707_7
antoinegg1 Jul 7, 2025
645b58c
refactor1
antoinegg1 Jul 7, 2025
90f4cf0
0707_undone
antoinegg1 Jul 7, 2025
b006b31
f
antoinegg1 Jul 7, 2025
aced39b
0708_1
antoinegg1 Jul 8, 2025
6018376
Merge remote-tracking branch 'origin/lcy/refactor' into lcy/refactor
antoinegg1 Jul 8, 2025
74a2eba
0708_2
antoinegg1 Jul 8, 2025
fcfa067
0708_3
antoinegg1 Jul 8, 2025
b584cd2
0708_7
antoinegg1 Jul 8, 2025
3d3f682
0708_4
antoinegg1 Jul 8, 2025
184f9e8
0709_1
antoinegg1 Jul 9, 2025
2b6f962
0709_2
antoinegg1 Jul 9, 2025
e7991fc
0709_3
antoinegg1 Jul 9, 2025
223cafd
0709_4
antoinegg1 Jul 9, 2025
c01052a
0709_5
antoinegg1 Jul 9, 2025
605342d
0709_
antoinegg1 Jul 9, 2025
7379a9d
0709_6
antoinegg1 Jul 9, 2025
8a7d656
0709_7
antoinegg1 Jul 9, 2025
92f144e
0709_7
antoinegg1 Jul 9, 2025
3eaf620
0709_8
antoinegg1 Jul 9, 2025
2edcd2a
0709_9
antoinegg1 Jul 9, 2025
0cd58b5
0710_1
antoinegg1 Jul 10, 2025
496413f
0710_2
antoinegg1 Jul 10, 2025
e57cb20
0710_2
antoinegg1 Jul 10, 2025
3122d90
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 10, 2025
50cf951
0710_3
antoinegg1 Jul 10, 2025
27c06b9
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 10, 2025
622781d
0710_3
antoinegg1 Jul 10, 2025
0d3c579
0710_3
antoinegg1 Jul 10, 2025
0640d5a
0710_5
antoinegg1 Jul 10, 2025
e1f2853
0710_4
antoinegg1 Jul 10, 2025
8c7affe
merge_lite
antoinegg1 Jul 10, 2025
395affa
merge_2
antoinegg1 Jul 10, 2025
4c0cd02
merge_3
antoinegg1 Jul 10, 2025
16e087d
0711_1
antoinegg1 Jul 11, 2025
fc51cd5
0711_2
antoinegg1 Jul 11, 2025
2af8cd5
0711_3
antoinegg1 Jul 11, 2025
b3fed3c
0711_4
antoinegg1 Jul 11, 2025
cad4488
0711_6
antoinegg1 Jul 11, 2025
437f7a7
0711_7
antoinegg1 Jul 11, 2025
35ab78b
0711_8
antoinegg1 Jul 11, 2025
04e432e
0711_8
antoinegg1 Jul 11, 2025
d6ff9e7
0711_9
antoinegg1 Jul 11, 2025
036aa9a
0711_10
antoinegg1 Jul 11, 2025
10a3731
0711-11
antoinegg1 Jul 11, 2025
434d2f5
PullRequest: 353 [Lite] Add gradient checkpointing to FSDPEngine
nuzant Jul 14, 2025
d8038b2
PullRequest: 354 [lite] GRPO pre-commit: minor changes in FSDP engine
garrett4wade Jul 14, 2025
d79f0dc
0714_1
antoinegg1 Jul 14, 2025
eb524e8
0714_2
antoinegg1 Jul 14, 2025
84cd936
0714_3
antoinegg1 Jul 14, 2025
b74f240
0714_3
antoinegg1 Jul 14, 2025
101c4e9
0714_5
antoinegg1 Jul 14, 2025
724628e
PullRequest: 355 [Lite] GRPO pre-commit 2: Refactor RemoteSGLangEngin…
garrett4wade Jul 14, 2025
8a15551
PullRequest: 357 [lite] GRPO pre-commit 3: Fix typos and experiment u…
garrett4wade Jul 14, 2025
3f95968
PullRequest: 358 [lite] Support GRPO training locally with the GSM8k …
garrett4wade Jul 15, 2025
69f5450
merge1
antoinegg1 Jul 15, 2025
0435aa5
merge2
antoinegg1 Jul 15, 2025
e960c17
0715_1
antoinegg1 Jul 15, 2025
ec60071
0715_2
antoinegg1 Jul 15, 2025
325ef6e
0715_2
antoinegg1 Jul 15, 2025
c75dcaf
merge
garrett4wade Jul 16, 2025
712a4ab
0716_1
antoinegg1 Jul 16, 2025
5efd861
0716_2
antoinegg1 Jul 16, 2025
b2bd639
PullRequest: 368 [lite] Refactor train engine after merging contribut…
garrett4wade Jul 16, 2025
b56f599
PullRequest: 371 [lite] [fix] fix misc bugs in GRPO implementation
garrett4wade Jul 16, 2025
74fcc38
0716_3
antoinegg1 Jul 16, 2025
1cbb642
0716_4
antoinegg1 Jul 16, 2025
2419d44
0716_4
antoinegg1 Jul 16, 2025
8596ef4
0716_5
antoinegg1 Jul 16, 2025
3c7c739
0717_1
antoinegg1 Jul 17, 2025
871b25a
0717_3
antoinegg1 Jul 17, 2025
0a2b9db
0717_3
antoinegg1 Jul 17, 2025
510313b
0717_4
antoinegg1 Jul 17, 2025
ce796f2
0717_5
antoinegg1 Jul 17, 2025
e9dc112
0717_6
antoinegg1 Jul 17, 2025
587544b
0717_6
antoinegg1 Jul 17, 2025
a032333
0717_6
antoinegg1 Jul 17, 2025
c0176b5
0718_1
antoinegg1 Jul 18, 2025
0e27a10
0718_2
antoinegg1 Jul 18, 2025
a08043e
0718_4
antoinegg1 Jul 18, 2025
090850a
0718_5
antoinegg1 Jul 18, 2025
ddabd9c
PullRequest: 370 [lite] Add Slurm Launcher and Ray Launcher
nuzant Jul 21, 2025
ade6a1d
Merge remote-tracking branch 'origin/lite' into lcy/refactor
antoinegg1 Jul 21, 2025
c8952f0
merge_0721
antoinegg1 Jul 21, 2025
25b65a2
0721_1
antoinegg1 Jul 21, 2025
2f1b679
PullRequest: 392 [lite] Fix several bugs regarding RL learning and ad…
garrett4wade Jul 21, 2025
588ffd2
0721_2
antoinegg1 Jul 21, 2025
a157510
0721_3
antoinegg1 Jul 21, 2025
8f26371
merge_0721_2
antoinegg1 Jul 21, 2025
9c4da33
Merge branch 'lite' of https://github.com/inclusionAI/AReaL into lite
garrett4wade Jul 21, 2025
9fcc177
0721_4
antoinegg1 Jul 21, 2025
ab5db3f
.
garrett4wade Jul 21, 2025
4dd4a22
.
garrett4wade Jul 21, 2025
339e87a
0721_formal
antoinegg1 Jul 21, 2025
67760d3
0721_formal
antoinegg1 Jul 21, 2025
60ac722
0721_merge3
antoinegg1 Jul 21, 2025
a2d6d21
0721_merge4
antoinegg1 Jul 21, 2025
b4e8215
0721_merge5
antoinegg1 Jul 21, 2025
475c35c
0721_6
antoinegg1 Jul 21, 2025
aed6a90
Merge remote-tracking branch 'backup/lite' into lcy/refactor
antoinegg1 Jul 21, 2025
c295614
0721_merge6
antoinegg1 Jul 21, 2025
f451dbd
0721_merge7
antoinegg1 Jul 21, 2025
80862b7
0721_8
antoinegg1 Jul 21, 2025
79e2a81
0722_1
antoinegg1 Jul 22, 2025
3d2f7a9
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 22, 2025
7199ce2
0722_2
antoinegg1 Jul 22, 2025
eba0b5f
0722_3
antoinegg1 Jul 22, 2025
229f101
0722_4
antoinegg1 Jul 22, 2025
ea12141
0722_4
antoinegg1 Jul 22, 2025
c27a51b
0722_5
antoinegg1 Jul 22, 2025
5c0662f
0722_6
antoinegg1 Jul 22, 2025
af2f80c
0722_7
antoinegg1 Jul 22, 2025
8815be6
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 22, 2025
eff8f09
0723_1
antoinegg1 Jul 23, 2025
6bde86a
reformatted
antoinegg1 Jul 23, 2025
52c9447
clang-reformatted
antoinegg1 Jul 23, 2025
25884f5
clang-reformatted2
antoinegg1 Jul 23, 2025
391bd85
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 23, 2025
dd14838
0723_1
antoinegg1 Jul 23, 2025
9ec2c3f
0723_1
antoinegg1 Jul 23, 2025
4041afb
0723_1
antoinegg1 Jul 23, 2025
2a2e2fe
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 23, 2025
8e82c59
0723_merge3
antoinegg1 Jul 23, 2025
d12dec2
0723_4
antoinegg1 Jul 23, 2025
82442b8
Merge branch 'lcy/refactor' of https://code.alipay.com/inclusionAI/AR…
antoinegg1 Jul 23, 2025
00b5d87
0723_reformatted_5
antoinegg1 Jul 23, 2025
9a16605
0724_1
antoinegg1 Jul 24, 2025
6c28d52
0724_1
antoinegg1 Jul 24, 2025
c816a3c
0724_merge1
antoinegg1 Jul 24, 2025
e97e33f
0724_merge2
antoinegg1 Jul 24, 2025
176ec4b
0724_merge3
antoinegg1 Jul 24, 2025
5118cfa
0724_merge3
antoinegg1 Jul 24, 2025
5690b52
0724_merge4
antoinegg1 Jul 24, 2025
84e2d75
Merge remote-tracking branch 'backup/lite' into lcy/refactor
antoinegg1 Jul 24, 2025
1bc9310
0724_merge5
antoinegg1 Jul 24, 2025
13fc236
0724_merge6
antoinegg1 Jul 24, 2025
27fd51a
0724_merge7
antoinegg1 Jul 24, 2025
e705db1
0724_merge8
antoinegg1 Jul 24, 2025
6aeeabf
0724_4
antoinegg1 Jul 24, 2025
f5924b1
0724_merge7
antoinegg1 Jul 24, 2025
84be9c9
Merge remote-tracking branch 'backup/lite' into lcy/refactor
antoinegg1 Jul 24, 2025
6255ad5
0724-merge8
antoinegg1 Jul 24, 2025
b8549ac
0724_merge8
antoinegg1 Jul 24, 2025
4198cd6
0725_1
antoinegg1 Jul 25, 2025
3c272ff
0725_6
antoinegg1 Jul 25, 2025
8eaced4
0725_7
antoinegg1 Jul 25, 2025
4f8b17f
0725_4padded_image
antoinegg1 Jul 25, 2025
cc3c6bb
0725_9padded_image
antoinegg1 Jul 25, 2025
60ac19a
0725_10padded_image
antoinegg1 Jul 25, 2025
fb1796d
0725_11
antoinegg1 Jul 25, 2025
a4ad671
0725
antoinegg1 Jul 25, 2025
6b8bfcf
0725_12
antoinegg1 Jul 25, 2025
4ff813a
0725_format
antoinegg1 Jul 25, 2025
11 changes: 9 additions & 2 deletions arealite/api/cli_args.py
@@ -611,8 +611,15 @@ class ClusterSpecConfig:

@dataclass
class DatasetConfig:
path: str = field(
default=MISSING,
metadata={
"help": "Path to the dataset. Can be a local path or a HuggingFace dataset name."
},
)
type: Optional[str] = field(
default=None, metadata={"help": "Type of implemented dataset"}
default=None,
metadata={"help": "Type of training method.e.g., 'sft', 'rl', etc."},
)
batch_size: int = field(
default=1, metadata={"help": "Batch size of the dataloader"}
@@ -743,7 +750,7 @@ class BaseExperimentConfig:
tokenizer_path: str = field(default="")

train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
valid_dataset: DatasetConfig = field(default_factory=DatasetConfig)
valid_dataset: Optional[DatasetConfig] = field(default=None)

saver: SaverConfig = field(default_factory=SaverConfig)
checkpointer: SaverConfig = field(default_factory=SaverConfig)
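The `DatasetConfig` change above makes `path` a required field and turns `valid_dataset` into an optional entry. A minimal usage sketch follows; the gsm8k path is only illustrative, and only the fields visible in this diff are assumed:

```python
from arealite.api.cli_args import DatasetConfig

train_dataset = DatasetConfig(
    path="openai/gsm8k",  # required: local path or HuggingFace dataset name
    type="sft",           # training method, e.g. 'sft' or 'rl'
    batch_size=8,
)
# BaseExperimentConfig.valid_dataset is now Optional[DatasetConfig] and defaults
# to None, so experiments without a validation set can simply omit it.
```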
18 changes: 16 additions & 2 deletions arealite/api/io_struct.py
@@ -8,7 +8,10 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple

from transformers import PreTrainedTokenizerFast
import torch
from gymnasium.core import ActType, ObsType
from PIL.Image import Image as ImageObject
from transformers import AutoProcessor, PreTrainedTokenizerFast

from arealite.api.cli_args import GenerationHyperparameters, SaverConfig
from arealite.utils.network import find_free_ports, gethostip
@@ -51,6 +54,16 @@ def output_len(self) -> int:
return len(self.output_tokens)


@dataclass
class VLMRequest(LLMRequest):
image_data: Optional[List[ImageObject | str]] = field(default_factory=list)


@dataclass
class VLMResponse(LLMResponse):
input_images: List[ImageObject | str] = field(default_factory=list)


@dataclass
class FinetuneSpec:
total_train_epochs: int
@@ -216,7 +229,8 @@ class SaveLoadMeta:
path: str
weight_format: str
with_optim: bool
tokenizer: Optional[PreTrainedTokenizerFast]
tokenizer: PreTrainedTokenizerFast | None
processor: AutoProcessor | None
base_model_path: str | None
naive_distributed: bool = False

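`VLMRequest` and `VLMResponse` extend the existing request/response structs with image payloads. A sketch of building a request, assuming the inherited `LLMRequest` fields (not shown in this diff) keep their existing meaning and are filled as before:

```python
from PIL import Image

from arealite.api.io_struct import VLMRequest

img = Image.open("example.jpg")  # PIL images or path/URL strings are both accepted

req = VLMRequest(
    # ... the usual LLMRequest fields (prompt tokens, sampling config, ...) go here ...
    image_data=[img],
)
```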
47 changes: 47 additions & 0 deletions arealite/dataset/__init__.py
@@ -0,0 +1,47 @@
from typing import Optional

import transformers

VALID_DATASETS = ["gsm8k", "clevr_count_70k"]


def get_custom_dataset(
path: str,
rank: int,
world_size: int,
type: str = "sft",
split: Optional[str] = None,
tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
processor: Optional[transformers.AutoProcessor] = None,
**kwargs,
):

if "gsm8k" in path and type == "sft":
from examples.arealite.dataset.gsm8k import get_gsm8k_sft_dataset

return get_gsm8k_sft_dataset(path, split, tokenizer, rank, world_size, **kwargs)
elif "gsm8k" in path and type == "rl":
from examples.arealite.dataset.gsm8k import get_gsm8k_rl_dataset

return get_gsm8k_rl_dataset(path, split, rank, world_size, **kwargs)
elif "clevr_count_70k" in path and type == "sft":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_sft_dataset,
)

return get_clevr_count_70k_sft_dataset(
path, split, processor, rank, world_size, **kwargs
)
elif "clevr_count_70k" in path and type == "rl":
from examples.arealite.dataset.clevr_count_70k import (
get_clevr_count_70k_rl_dataset,
)

return get_clevr_count_70k_rl_dataset(
path, split, processor, rank, world_size, **kwargs
)
else:
raise ValueError(
f"Dataset {path} with split {split} and training type {type} is not supported. "
f"Supported datasets are: {VALID_DATASETS}. "
)
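The new `get_custom_dataset` helper dispatches on the dataset path and training type. A hedged usage sketch; the checkpoint name and local dataset path are illustrative, and the processor/tokenizer loader is the one imported elsewhere in this PR:

```python
from realhf.api.core.data_api import load_hf_processor_and_tokenizer

from arealite.dataset import get_custom_dataset

# Illustrative model path; any VLM checkpoint with a processor would do.
processor, tokenizer = load_hf_processor_and_tokenizer("Qwen/Qwen2.5-VL-3B-Instruct")

train_ds = get_custom_dataset(
    path="/data/clevr_count_70k",  # illustrative local path; must contain "clevr_count_70k"
    rank=0,
    world_size=1,
    type="sft",
    split="train",
    processor=processor,
)
```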
83 changes: 62 additions & 21 deletions arealite/engine/base_hf_engine.py
@@ -9,6 +9,8 @@
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoProcessor,
PretrainedConfig,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
@@ -29,8 +31,8 @@
unsqueeze_mb_list,
)
from arealite.utils.fsdp import get_cosine_schedule_with_warmup
from arealite.utils.model import disable_dropout_in_model
from realhf.api.core.data_api import load_hf_tokenizer
from arealite.utils.model import VALID_VISION_MODELS, disable_dropout_in_model
from realhf.api.core.data_api import load_hf_processor_and_tokenizer, load_hf_tokenizer
from realhf.base import constants, logging

logger = logging.getLogger("Base HF Engine")
@@ -44,6 +46,7 @@ def __init__(self, config: TrainEngineConfig):
self.model: torch.nn.Module
self.optimizer: torch.optim.Optimizer
self.tokenizer: PreTrainedTokenizerFast
self.processor: AutoProcessor | None = None
# huggingface model config
self.model_config: PretrainedConfig
self._version: int = 0
@@ -54,6 +57,12 @@ def __init__(self, config: TrainEngineConfig):
self._parallelism_group: dist.ProcessGroup
self.weight_update_group_initialized = False

self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.is_vision_model = self.model_config.model_type in VALID_VISION_MODELS

self.world_size = int(os.environ["WORLD_SIZE"])

def set_version(self, version: int):
@@ -92,32 +101,54 @@ def create_device_model(self):
self.device = torch.device(int(os.environ["LOCAL_RANK"]))

dtype = getattr(torch, self.config.dtype)
self.model_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
)
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):

if self.is_vision_model:
if dtype == torch.float16:
raise ValueError(
"Vision models do not support float16 dtype. Please use bfloat16."
)
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: a VLM cannot directly load a state dict into a model that is
# randomly initialized from config, so for VLMs we call from_pretrained
# instead of loading weights into this randomly initialized model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
raise ValueError(
"Vision models do not support initialization from scratch. Please use a pretrained model."
)
else:
model = AutoModelForCausalLM.from_pretrained(
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
self.config.path
)

tik = time.perf_counter()
with torch.device("cuda"):
model = AutoModelForImageTextToText.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)
if self.config.disable_dropout:
disable_dropout_in_model(model)
else:
self.tokenizer = load_hf_tokenizer(self.config.path)
tik = time.perf_counter()
with torch.device("cuda"):
if self.config.init_from_scratch:
# initialize scratch model from config
# NOTE: a VLM cannot directly load a state dict into a model that is
# randomly initialized from config, so for VLMs we call from_pretrained
# instead of loading weights into this randomly initialized model.
model = AutoModelForCausalLM.from_config(
self.model_config,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
torch_dtype=dtype,
attn_implementation=self.config.attn_impl,
)
if self.config.disable_dropout:
disable_dropout_in_model(model)

if self.config.gradient_checkpointing:
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
@@ -218,9 +249,15 @@ def step_lr_scheduler(self):

def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
assert "attention_mask" in input_ and "input_ids" in input_
if self.is_vision_model:
assert (
"pixel_values" in input_ and "image_grid_thw" in input_
), "For vision-language models, pixel_values and image_grid_thw must be present in input_"

if isinstance(input_, dict):
input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]])
input_ = amend_position_ids(input_)

mb_list = split_padded_tensor_dict_into_mb_list(input_, self.config.mb_spec)
logger.info(
f"Microbatch #tokens (rank {dist.get_rank()}): {mb_list.group_lens}"
@@ -230,13 +267,15 @@ def prepare_mb_list(self, input_: TensorDict) -> MicroBatchList:
# NOTE: We unsqueeze here because huggingface transformer models requires
# packed input to be of shape [1, total_seqlen].
mb_list = unsqueeze_mb_list(mb_list)

# FIXME: the resulting max_seqlen is a tensor rather than an integer
for mb in mb_list.mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False
for mb in mb_list.padded_mbs:
mb["max_seqlen"] = int(mb["max_seqlen"])
mb["use_cache"] = False

return mb_list

def train_batch(
@@ -264,11 +303,13 @@ def train_batch(
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):

outputs = self.model(**padded_mb_input)

logits = outputs.logits.squeeze(0)
logits = logits[:-pad_length] if pad_length > 0 else logits
loss = loss_fn(logits, mb_input)

loss_scale = loss_weight_fn(mb_input) / total_loss_weight

# Scale loss for accumulation
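The engine now decides between the text-only and the vision path by reading the HuggingFace config once in `__init__`. A sketch of that routing in isolation; the checkpoint name is illustrative, and `VALID_VISION_MODELS` is the allow-list imported from `arealite.utils.model`:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText

from arealite.utils.model import VALID_VISION_MODELS

path = "Qwen/Qwen2.5-VL-3B-Instruct"  # illustrative checkpoint
model_config = AutoConfig.from_pretrained(path, trust_remote_code=True)

if model_config.model_type in VALID_VISION_MODELS:
    # vision path: bfloat16 only, always loaded from pretrained weights
    model = AutoModelForImageTextToText.from_pretrained(path, torch_dtype=torch.bfloat16)
else:
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
```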
18 changes: 13 additions & 5 deletions arealite/engine/fsdp_engine.py
@@ -11,7 +11,7 @@
StateDictOptions,
get_model_state_dict,
)
from transformers import PreTrainedTokenizerFast
from transformers import AutoProcessor, PreTrainedTokenizerFast

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec
@@ -27,6 +27,7 @@
fsdp2_load_full_state_dict,
)
from arealite.utils.save_load import get_state_dict_from_repo_id_or_path
from realhf.api.core.data_api import load_hf_processor_and_tokenizer
from realhf.base import logging, name_resolve, names, pkg_version

logger = logging.getLogger("FSDPEngine")
@@ -77,7 +78,7 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):

def save(self, meta: SaveLoadMeta):
if meta.weight_format == "hf":
self._save_model_to_hf(meta.path, meta.tokenizer)
self._save_model_to_hf(meta.path, meta.tokenizer, meta.processor)
elif meta.weight_format == "dcp":
# TODO: implement DCP save/load for FSDP
raise NotImplementedError("DCP format saving is not implemented yet. ")
@@ -100,7 +101,10 @@ def load(self, meta: SaveLoadMeta):
self.load_optimizer_state(meta.path)

def _save_model_to_hf(
self, path: str, tokenizer: Optional[PreTrainedTokenizerFast]
self,
path: str,
tokenizer: Optional[PreTrainedTokenizerFast],
processor: Optional[AutoProcessor],
):
"""Save model in HuggingFace format."""
if self.model is None:
@@ -119,6 +123,8 @@ def _save_model_to_hf(
self.model_config.save_pretrained(path)
if tokenizer is not None:
tokenizer.save_pretrained(path)
if processor is not None:
processor.save_pretrained(path)

dist.barrier(device_ids=[self.device.index])

@@ -144,13 +150,13 @@ def upload_weights(self, meta: WeightUpdateMeta):
dist.barrier(device_ids=[self.device.index])
torch.cuda.synchronize()
elif meta.type == "disk":
self._save_model_to_hf(meta.path, self.tokenizer)
self._save_model_to_hf(meta.path, self.tokenizer, self.processor)
# dist.barrier() is called when _save_model_to_hf finishes
if dist.get_rank() == 0:
update_name = names.update_weights_from_disk(
self.config.experiment_name,
self.config.trial_name,
self.model_version,
self.get_version(),
)
name_resolve.add(
update_name, str(datetime.now().timestamp()), keepalive_ttl=120
@@ -247,9 +253,11 @@ def train_batch(
loss.backward()

# NOTE: grad norm clip function is different

grad_norm = fsdp2_clip_grad_norm_(
self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping
)

if not torch.isfinite(grad_norm):
self.optimizer.zero_grad()
update_successful = False
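`_save_model_to_hf` and disk-based weight updates now also persist the processor. A hypothetical save call; the engine object, output path, and surrounding training script are assumed, while the field names follow the `SaveLoadMeta` change earlier in this diff:

```python
from arealite.api.io_struct import SaveLoadMeta

meta = SaveLoadMeta(
    path="/tmp/ckpt/step_100",     # illustrative output directory
    weight_format="hf",
    with_optim=False,
    tokenizer=engine.tokenizer,
    processor=engine.processor,    # None for text-only models; saved next to the tokenizer for VLMs
    base_model_path=None,
)
engine.save(meta)                  # engine is an initialized FSDPEngine from the training script
```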
1 change: 0 additions & 1 deletion arealite/engine/sft/lm_engine.py
@@ -52,7 +52,6 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
logprobs = torch.where(loss_mask, logprobs, 0)

loss = -logprobs.sum() / loss_mask.count_nonzero()

with torch.no_grad():
seqlogp = torch.zeros(
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
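For reference, a tiny numeric sketch of the masked loss computed by `compute_packed_sft_loss` above: masked-out positions are zeroed before summation, and the sum is divided by the count of unmasked tokens.

```python
import torch

logprobs = torch.tensor([-0.5, -1.0, -2.0, -0.1])
loss_mask = torch.tensor([False, True, True, False])  # e.g. prompt tokens are excluded

logprobs = torch.where(loss_mask, logprobs, torch.zeros_like(logprobs))
loss = -logprobs.sum() / loss_mask.count_nonzero()
print(loss)  # tensor(1.5000) == (1.0 + 2.0) / 2
```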