87 changes: 84 additions & 3 deletions arealite/api/cli_args.py
@@ -18,7 +18,7 @@
class MicroBatchSpec:
"""Specification for splitting micro-batches during training."""

-    n_mbs: int = field(
+    n_mbs: Optional[int] = field(
default=1,
metadata={
"help": "Number of micro-batches (or minimum number if max_tokens_per_mb is set). Used when max_tokens_per_mb is None or as minimum count",
@@ -161,7 +161,7 @@ class FSDPEngineConfig:


@dataclass
-class HFEngineConfig:
+class DeepSpeedAutoTPEngineConfig:
autotp_size: Optional[int] = field(
default=1,
metadata={"help": "DeepSpeed AutoTP size"},
@@ -201,7 +201,88 @@ class TrainEngineConfig:
)
backend: str = ""
fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
-    hf: HFEngineConfig = field(default_factory=HFEngineConfig)
+    ds_auto_tp: DeepSpeedAutoTPEngineConfig = field(
+        default_factory=DeepSpeedAutoTPEngineConfig
+    )


@dataclass
class PPOActorConfig(TrainEngineConfig):
# Core PPO/GRPO Parameters
group_size: int = field(
default=1, metadata={"help": "Number of sequences in each group"}
)
group_adv_norm: bool = field(
default=False,
metadata={
"help": "Normalize advantages within each prompt group rather than globally"
},
)
ppo_n_minibatches: int = field(
default=4, metadata={"help": "Number of minibatches for each PPO update"}
)
eps_clip: float = field(
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
)
c_clip: Optional[float] = field(
default=None,
metadata={
"help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping."
},
)
temperature: float = field(
default=1.0, metadata={"help": "Temperature during generation."}
)
# Reward
group_reward_norm: bool = field(
default=False,
metadata={
"help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias"
},
)
reward_scaling: float = field(
default=1.0, metadata={"help": "Reward scaling factor"}
)
reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"})
reward_clip: float = field(
default=20.0, metadata={"help": "Maximum absolute value for reward clipping"}
)
mask_no_eos_with_zero: bool = field(
default=False,
metadata={
"help": "Mask truncated generations (no EOS token) and exclude from training"
},
)

# Advantage Estimation
discount: float = field(
default=1.0, metadata={"help": "Discount factor for future rewards"}
)
gae_lambda: float = field(
default=1.0, metadata={"help": "Lambda parameter for GAE"}
)
adv_norm: bool = field(
default=True, metadata={"help": "Enable advantage normalization"}
)

# KL Control
kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"})

# Asynchronous RL
recompute_logprob: bool = field(
default=False,
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
behav_imp_weight_cap: Optional[float] = field(
default=None,
metadata={
"help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true"
},
)
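A minimal, illustrative `PPOActorConfig` for a GRPO-style run, using only fields defined above (values are placeholders, not recommendations):

```python
from arealite.api.cli_args import PPOActorConfig

actor_cfg = PPOActorConfig(
    group_size=8,              # sequences per prompt group
    group_adv_norm=True,       # normalize advantages within each group
    group_reward_norm=True,    # GRPO-style reward normalization
    c_clip=3.0,                # dual clip; must be > 1.0
    recompute_logprob=True,    # required by use_decoupled_loss
    use_decoupled_loss=True,
    behav_imp_weight_cap=5.0,  # must be > 1.0; needs decoupled loss
)
```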


10 changes: 5 additions & 5 deletions arealite/api/engine_api.py
@@ -25,10 +25,10 @@ class Scheduling:
cpu: int
gpu: int
mem: int
-    nodelist: str = None
-    exclude: str = None
-    partition: str = None
-    container_image: str = None
+    nodelist: Optional[str] = None
+    exclude: Optional[str] = None
+    partition: Optional[str] = None
+    container_image: Optional[str] = None
env_vars: Dict[str, str] = field(default_factory=dict)
# time utils from "https://slurm.schedmd.com/sbatch.html"
time_limit: Optional[str] = None # see "--time" option for format
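With the `Optional[str]` annotations now matching the `None` defaults, a hedged example of a Slurm-oriented `Scheduling` value (memory units are an assumption):

```python
from arealite.api.engine_api import Scheduling

sched = Scheduling(
    cpu=8,
    gpu=1,
    mem=65536,                       # assumed to be MB
    partition="a100",                # optional Slurm partition
    env_vars={"NCCL_DEBUG": "WARN"},
    time_limit="01:00:00",           # Slurm --time format
)
```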
@@ -105,7 +105,7 @@ def eval_batch(
def forward(
self,
input_: TensorDict,
-        output_seqlens: List[List[int]] | None = None,
+        output_seqlens: List[int] | None = None,
post_hook: Callable[[torch.Tensor, TensorDict], Any] | None = None,
aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
) -> Any | None:
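A hedged usage sketch of the narrowed `output_seqlens` type, now one length per sequence rather than a nested list (the engine instance and input layout are assumptions):

```python
import torch
from tensordict import TensorDict

def run_forward(engine):
    # `engine` is any TrainEngine implementation with this forward signature.
    batch = TensorDict({"input_ids": torch.randint(0, 1000, (2, 16))}, batch_size=[2])
    return engine.forward(
        input_=batch,
        output_seqlens=[16, 12],                 # List[int]: one length per sequence
        post_hook=lambda out, inp: out.float(),  # applied before aggregation
        aggregate_fn=torch.cat,
    )
```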
25 changes: 12 additions & 13 deletions arealite/api/io_struct.py
@@ -71,7 +71,7 @@ class AllocationType(enum.Enum):
@dataclass
class AllocationMode:
type_: AllocationType
-    parallel_strat: None | Dict[str, Dict[str, int]]
+    parallel_strat: Dict[str, Dict[str, int]]

@property
def gen_tp_size(self) -> int:
@@ -115,7 +115,7 @@ def from_str(cls, allocation_mode: str):
raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}")

@staticmethod
-    def extract_3d_alloc(allocation_mode: str) -> Dict | None:
+    def extract_parallelism_strategy(allocation_mode: str) -> Dict:
for x, y, z in itertools.permutations(["d", "t", "p"]):
pattern = rf"{x}(\d+){y}(\d+){z}(\d+)"
m = re.match(pattern, allocation_mode)
@@ -130,29 +130,28 @@ def extract_3d_alloc(allocation_mode: str) -> Dict | None:
z: c,
}
}
+        raise ValueError(
+            f"Cannot resolve parallelism strategy: {allocation_mode}"
+        )

@staticmethod
-    def extract_decoupled_alloc(allocation_mode: str) -> Dict | None:
+    def extract_decoupled_alloc(allocation_mode: str) -> Dict:
pattern = re.compile(
r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))"
)
m = pattern.match(allocation_mode)
if not m:
-            return
+            raise ValueError(
+                f"Cannot resolve decoupled allocation: {allocation_mode}"
+            )
if m.group(1):
gen_alloc = m.group(1)
other_alloc = m.group(2)
else:
gen_alloc = m.group(4)
other_alloc = m.group(3)
-        gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc)
-        if not gen_alloc:
-            return
-        other_alloc = AllocationMode.extract_3d_alloc(
-            other_alloc
-        ) or AllocationMode.extract_key_value_alloc(other_alloc)
-        if not other_alloc:
-            return
+        gen_alloc = AllocationMode.extract_parallelism_strategy(gen_alloc)
+        other_alloc = AllocationMode.extract_parallelism_strategy(other_alloc)
other_alloc.update({"gen": gen_alloc["*"]})
return other_alloc
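A sketch of what the two parsers return, derived from the regexes above (allocation strings are illustrative):

```python
from arealite.api.io_struct import AllocationMode

# "d4t2p1" -> {"*": {"d": 4, "t": 2, "p": 1}}
AllocationMode.extract_parallelism_strategy("d4t2p1")

# SGLang generation with d4t2p1 plus a d2t4p1 trainer:
# -> {"*": {"d": 2, "t": 4, "p": 1}, "gen": {"d": 4, "t": 2, "p": 1}}
AllocationMode.extract_decoupled_alloc("sglang.d4t2p1+d2t4p1")
```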

@@ -171,7 +170,7 @@ class SaveLoadMeta:
path: str
weight_format: str
with_optim: bool
-    tokenizer: PreTrainedTokenizerFast | None
+    tokenizer: Optional[PreTrainedTokenizerFast]
base_model_path: str | None
naive_distributed: bool = False

128 changes: 128 additions & 0 deletions arealite/engine/autotp_engine.py
@@ -0,0 +1,128 @@
import os

import torch
import torch.distributed as dist
from safetensors.torch import save_file

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
from arealite.engine.base_hf_engine import BaseHFEngine
from arealite.utils.save_load import (
get_state_dict_from_repo_id_or_path,
is_existing_local_path,
)
from realhf.base import constants, logging

logger = logging.getLogger("DeepSpeedAutoTPEngine")


class DeepSpeedAutoTPEngine(BaseHFEngine):
def __init__(self, config: TrainEngineConfig):
super().__init__(config)

def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
"""Initialize distributed communication and model."""
assert (
addr is None
), "DeepSpeedAutoTPEngine does not support remote initialization."
import deepspeed

self.create_process_group()

world_size = int(os.environ.get("WORLD_SIZE"))
deepspeed.init_distributed(
dist_backend="nccl",
world_size=world_size,
timeout=constants.NCCL_DEFAULT_TIMEOUT,
)
self.create_device_model()
# NOTE: the device context manager does not work here.
self.model = deepspeed.tp_model_init(
self.model,
tp_size=self.config.ds_auto_tp.autotp_size,
dtype=getattr(torch, self.config.dtype),
).to(self.device)
self.create_optimizer(ft_spec)
self.initialized = True

def _check_autotp(self):
tp_size = self.config.ds_auto_tp.autotp_size
config = self.model_config
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size

return (
num_attention_heads % tp_size == 0
and num_key_value_heads % tp_size == 0
and hidden_size % tp_size == 0
and intermediate_size % tp_size == 0
)
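    # Example (hypothetical Llama/Qwen-style config): 32 attention heads,
    # 8 KV heads, hidden_size 4096, intermediate_size 11008 -> autotp_size=2
    # divides all four and passes; autotp_size=3 would fail (32 % 3 != 0).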

def save(self, meta: SaveLoadMeta):
if meta.weight_format != "naive_distributed":
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
if self.model is None:
raise RuntimeError("Model not initialized")

rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
os.makedirs(meta.path, exist_ok=True)
self.model_config.save_pretrained(
meta.path,
)
if meta.tokenizer is not None:
meta.tokenizer.save_pretrained(
meta.path,
)

state_dict = self.model.state_dict()
if hasattr(self.model, "module"):
state_dict = {
k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu()
for k, v in state_dict.items()
}
else:
state_dict = {k: v.cpu() for k, v in state_dict.items()}

        # Gather every rank's TP partition on rank 0 and store each one separately.
gathered_state_dicts = None
if rank == 0:
gathered_state_dicts = [None for _ in range(world_size)]
dist.gather_object(
obj=state_dict, object_gather_list=gathered_state_dicts, dst=0
)
if rank == 0:
for i, state_dict in enumerate(gathered_state_dicts):
save_file(state_dict, f"{meta.path}/rank_{i:02d}_model.safetensors")
if meta.with_optim:
self.save_optimizer_state(meta.path)

def load(self, meta: SaveLoadMeta):
if meta.weight_format != "naive_distributed":
raise ValueError(f"Unknown weight format {meta.weight_format}. ")
rank = dist.get_rank()
        # Supports loading either the full model weights from a HuggingFace
        # repo (rank 0 only) or this rank's local model partition.
        if rank == 0 or is_existing_local_path(meta.path):
            if is_existing_local_path(meta.path):
                # Local checkpoint: load the partition saved for this rank.
                path = f"{meta.path}/rank_{rank:02d}_model.safetensors"
                full_state = get_state_dict_from_repo_id_or_path(path)
            else:
                # Remote repo id: load the full state dict.
                full_state = get_state_dict_from_repo_id_or_path(meta.path)

            # Re-add the "module." prefix when the wrapped model expects it
            # but the checkpoint keys lack it.
            if hasattr(self.model, "module") and not any(
                k.startswith("module.") for k in full_state
            ):
full_state = {
f"module.{k}" if not k.startswith("module.") else k: v
for k, v in full_state.items()
}
self.model.load_state_dict(
full_state, strict=not self.model_config.tie_word_embeddings
)
if self.model_config.tie_word_embeddings:
self.model.tie_weights()

if meta.with_optim:
self.load_optimizer_state(meta.path)

def upload_weights(self, meta: WeightUpdateMeta):
raise ValueError(f"update weight not implemented {meta.type}")