87 changes: 67 additions & 20 deletions arealite/api/cli_args.py
@@ -676,20 +676,6 @@ class ClusterSpecConfig:
"help": "Root for logs and checkpoints. Should be available to all nodes."
},
)
gpu_type: str = field(
default="tesla", metadata={"help": "GPU type of the cluster. Used by slurm."}
)
mount: str = field(
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
)
gpu_image: str = field(default="", metadata={"help": "slurm image for trainers."})
cpu_image: str = field(default="", metadata={"help": "slurm image for CPU jobs."})
gpu_infer_image: str = field(
default="", metadata={"help": "slurm image for LLM inference."}
)
node_name_prefix: str = field(
default="slurmd-", metadata={"help": "Node prefix for a slurm cluster."}
)
n_nodes: int = field(
default=32,
metadata={
@@ -725,6 +711,72 @@ class DatasetConfig:
drop_last: bool = field(default=True)


@dataclass
class SlurmLauncherConfig:
"""Configuration for launching the SGLang server with Slurm."""

srun_additional_args: str = field(
default="--overlap --mpi=pmi2 -K --chdir $PWD",
metadata={"help": "Additional arguments to pass to the srun command."},
)
container_type: str = field(
default="apptainer",
metadata={
"help": "Type of containers used in slurm",
"choices": ["apptainer", "none"],
},
)
mount: str = field(
default="/storage:/storage", metadata={"help": "Mount path for slurm."}
)
trainer_image: str = field(
default="", metadata={"help": "slurm image for trainers."}
)
inference_server_image: str = field(
default="", metadata={"help": "slurm image for LLM inference."}
)


@dataclass
class LauncherConfig:
"""Configuration for launching the SGLang server."""

inference_server_cpus_per_gpu: int = field(
default=4,
metadata={"help": "Number of CPUs allocated per GPU for inference server. "},
)
inference_server_mem_per_gpu: int = field(
default=32 * 1024,
metadata={"help": "Memory allocated per GPU for inference server in MB. "},
)
trainer_cpus_per_gpu: int = field(
default=4,
metadata={"help": "Number of CPUs allocated per GPU for training. "},
)
trainer_mem_per_gpu: int = field(
default=32 * 1024,
metadata={"help": "Memory allocated per GPU for training in MB. "},
)
inference_server_env_vars: str = field(
default="",
metadata={
"help": "Environment variables for inference server, seperated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
trainer_env_vars: str = field(
default="",
metadata={
"help": "Environment variables for training, seperated by commas. "
"Example: 'ENV1=val1,ENV2=val2'. "
},
)
slurm: SlurmLauncherConfig = field(
default_factory=SlurmLauncherConfig,
metadata={"help": "Slurm launcher configuration."},
)


@dataclass
class BaseExperimentConfig:
# NOTE: we need this unified config class because different experiments
@@ -742,12 +794,6 @@ class BaseExperimentConfig:
default_factory=ClusterSpecConfig,
metadata={"help": "Cluster specification. Mainly used by slurm."},
)
n_nodes: int = field(
default=1, metadata={"help": "Number of nodes for experiment."}
)
n_gpus_per_node: int = field(
default=8, metadata={"help": "Number of GPUs per node for this experiment."}
)
allocation_mode: str = field(
default="",
metadata={
@@ -785,6 +831,7 @@ class BaseExperimentConfig:

server_only: bool = False
sglang: SGLangConfig = field(default_factory=SGLangConfig)
launcher: LauncherConfig = field(default_factory=LauncherConfig)


@dataclass
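Note: the new LauncherConfig documents its *_env_vars fields as comma-separated KEY=VALUE strings. A minimal sketch of how such a string might be parsed into an environment dict; parse_env_vars is a hypothetical helper for illustration and not part of this PR.

from typing import Dict

def parse_env_vars(spec: str) -> Dict[str, str]:
    """Parse 'ENV1=val1,ENV2=val2' (the LauncherConfig format) into a dict."""
    env: Dict[str, str] = {}
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue  # tolerate empty segments, e.g. a trailing comma
        key, _, value = item.partition("=")
        env[key.strip()] = value.strip()
    return env

# parse_env_vars("NCCL_DEBUG=INFO,TOKENIZERS_PARALLELISM=false")
# -> {"NCCL_DEBUG": "INFO", "TOKENIZERS_PARALLELISM": "false"}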
3 changes: 1 addition & 2 deletions arealite/engine/base_hf_engine.py
@@ -266,6 +266,7 @@ def train_batch(

# Scale loss for accumulation
# Revert gradient averaging across dp ranks
# FIXME: should be DP size
loss_scale *= self.world_size

loss *= loss_scale
@@ -286,8 +287,6 @@
update_successful = True

current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
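Note: the FIXME added above says the reverted gradient-averaging factor should be the data-parallel size rather than the global world size (they differ once model parallelism is used). A minimal sketch of that correction, assuming a DP process group handle is available; the dp_group argument is hypothetical and not from this PR.

from typing import Optional

import torch.distributed as dist

def dp_loss_scale(dp_group: Optional[dist.ProcessGroup] = None) -> float:
    """Factor that reverts gradient averaging across data-parallel ranks only."""
    # With pure data parallelism this equals the world size; with model
    # parallelism it is strictly smaller, which is what the FIXME points at.
    return float(dist.get_world_size(group=dp_group))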
17 changes: 3 additions & 14 deletions arealite/engine/fsdp_engine.py
@@ -1,10 +1,7 @@
import dis
import gc
import os
import threading
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.distributed as dist
@@ -14,14 +11,7 @@
StateDictOptions,
get_model_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PreTrainedTokenizerFast,
get_constant_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from transformers import PreTrainedTokenizerFast

from arealite.api.cli_args import TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta, WeightUpdateMeta
@@ -232,6 +222,7 @@ def train_batch(
for i, (pad_length, padded_mb_input, mb_input) in enumerate(
zip(mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs)
):
self.model.set_requires_gradient_sync(i == len(mb_list.mbs) - 1)
outputs = self.model(**padded_mb_input)

logits = outputs.logits.squeeze(0)
@@ -258,8 +249,6 @@
update_successful = True

current_lr = self.lr_scheduler.get_last_lr()[0]
# Optimizer step
self.optimizer.step()
return dict(
update_successful=float(update_successful),
grad_norm=float(grad_norm) if grad_norm is not None else float("nan"),
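Note: the set_requires_gradient_sync(i == len(mb_list.mbs) - 1) call added above defers the inter-rank gradient sync until the last micro-batch of an accumulation step. A standalone sketch of the same pattern, assuming model was sharded with the FSDP2 fully_shard API (which exposes set_requires_gradient_sync); the function and its arguments are illustrative only.

from typing import Dict, List

import torch

def accumulate_microbatches(
    model,  # FSDP2-sharded module exposing set_requires_gradient_sync
    optimizer: torch.optim.Optimizer,
    micro_batches: List[Dict[str, torch.Tensor]],
) -> None:
    """Run backward over micro-batches, syncing gradients only on the last one."""
    for i, mb in enumerate(micro_batches):
        # Skip the gradient reduce-scatter for all but the final micro-batch.
        model.set_requires_gradient_sync(i == len(micro_batches) - 1)
        loss = model(**mb).loss / len(micro_batches)
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()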
12 changes: 3 additions & 9 deletions arealite/engine/sft/lm_engine.py
@@ -58,15 +58,9 @@ def compute_packed_sft_loss(logits: torch.Tensor, input_: TensorDict) -> torch.T
cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64
)
for i in range(cu_seqlens.shape[0] - 1):
m = loss_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1]
assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], (
cu_seqlens,
logprobs.shape,
)
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (
m.numel() - m.count_nonzero()
)
m = loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]]
logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]]
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / (m.count_nonzero())

## Logging stats
stats_tracker.denominator(
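Note: the rewritten loop drops the off-by-i shift in the slice bounds and divides by the number of unmasked tokens (m.count_nonzero()) rather than the number of masked-out ones. A self-contained sketch of the corrected per-sequence masked mean, with synthetic values standing in for logprobs, loss_mask, and cu_seqlens.

import torch

# Packed batch of two sequences (lengths 3 and 2); boundaries come from cu_seqlens.
logprobs = torch.tensor([-0.1, -0.2, -0.3, -0.4, -0.5])
loss_mask = torch.tensor([True, True, False, True, True])
cu_seqlens = torch.tensor([0, 3, 5])

seqlogp = torch.zeros(cu_seqlens.shape[0] - 1, dtype=torch.float64)
for i in range(cu_seqlens.shape[0] - 1):
    m = loss_mask[cu_seqlens[i] : cu_seqlens[i + 1]]
    logp = logprobs[cu_seqlens[i] : cu_seqlens[i + 1]]
    # Mean log-prob over the unmasked tokens of sequence i.
    seqlogp[i] = torch.where(m, logp, 0.0).sum() / m.count_nonzero()

# seqlogp -> tensor([-0.1500, -0.4500], dtype=torch.float64)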