diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py
index 2f41b5c0a..4763d5134 100644
--- a/arealite/api/cli_args.py
+++ b/arealite/api/cli_args.py
@@ -104,6 +104,14 @@ class FSDPEngineConfig:
     )
 
 
+@dataclass
+class HFEngineConfig:
+    autotp_size: Optional[int] = field(
+        default=1,
+        metadata={"help": "DeepSpeed AutoTP size"},
+    )
+
+
 @dataclass
 class TrainEngineConfig:
     experiment_name: str = MISSING
@@ -136,6 +144,7 @@ class TrainEngineConfig:
     )
     backend: str = ""
     fsdp: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
+    hf: HFEngineConfig = field(default_factory=HFEngineConfig)
 
 
 @dataclass
diff --git a/arealite/api/io_struct.py b/arealite/api/io_struct.py
index 3033af8c3..40ff81f67 100644
--- a/arealite/api/io_struct.py
+++ b/arealite/api/io_struct.py
@@ -175,6 +175,7 @@ class SaveLoadMeta:
     with_optim: bool
     tokenizer: PreTrainedTokenizerFast | None
     base_model_path: str | None
+    naive_distributed: bool = False
 
 
 @dataclass
diff --git a/arealite/engine/hf_engine.py b/arealite/engine/hf_engine.py
index cd685d0c4..b3c0b6532 100644
--- a/arealite/engine/hf_engine.py
+++ b/arealite/engine/hf_engine.py
@@ -1,123 +1,131 @@
-import asyncio
-import functools
-import math
+import gc
 import os
+import time
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
 import torch.distributed as dist
 import transformers
-from transformers import AutoConfig, AutoModelForCausalLM
+from safetensors.torch import save_file
+from tensordict import TensorDict
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    get_constant_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+)
 
-from arealite.api.cli_args import EngineConfig, ParallelismConfig, TrainingArgs
-from arealite.api.engine_api import TrainEngine
-from arealite.api.io_struct import FinetuneSpec
-from arealite.api.llm_client_api import LLMClient
-from arealite.utils import (
-    get_state_dict_from_repo_id_or_path,
-    recorder_list,
-    split_dict_tensor_with_cu_seqlens,
+from arealite.api.cli_args import MicroBatchSpec, TrainEngineConfig
+from arealite.api.engine_api import (
+    FinetuneSpec,
+    SaveLoadMeta,
+    TrainEngine,
+    WeightUpdateMeta,
+)
+from arealite.utils.data import (
+    MicroBatchList,
+    amend_position_ids,
+    pack_tensor_dict,
+    pad_and_stack_tensors_along_first_dim,
+    pad_mb_list,
+    reorder_list,
+    split_packed_tensor_dict_into_mb_list,
     unpack_sequence,
+    unsqueeze_mb_list,
 )
-from realhf.base import constants
-
-
-def get_cosine_schedule_with_warmup(
-    optimizer: torch.optim.Optimizer,
-    num_warmup_steps: int,
-    num_training_steps: int,
-    min_lr_ratio: float = 0.0,
-    num_cycles: float = 0.5,
-    last_epoch: int = -1,
-):
-    """
-    Create a schedule with a learning rate that decreases following the values of the cosine function between the
-    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
-    initial lr set in the optimizer.
-    Args:
-        optimizer (:class:`~torch.optim.Optimizer`):
-            The optimizer for which to schedule the learning rate.
-        num_warmup_steps (:obj:`int`):
-            The number of steps for the warmup phase.
-        num_training_steps (:obj:`int`):
-            The total number of training steps.
-        min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0):
-            The minimum lr ratio w.r.t the maximum.
-        num_cycles (:obj:`float`, `optional`, defaults to 0.5):
-            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
-            following a half-cosine).
-        last_epoch (:obj:`int`, `optional`, defaults to -1):
-            The index of the last epoch when resuming training.
-    Return:
-        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-    assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0
-    coef = (1 - min_lr_ratio) * 0.5
-    intercept = (1 + min_lr_ratio) * 0.5
-
-    def lr_lambda(current_step):
-        if current_step < num_warmup_steps:
-            return float(current_step) / float(max(1, num_warmup_steps))
-        progress = float(current_step - num_warmup_steps) / float(
-            max(1, num_training_steps - num_warmup_steps)
-        )
-        x = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
-        return max(0.0, x * coef + intercept)
+from arealite.utils.fsdp import get_cosine_schedule_with_warmup
+from arealite.utils.save_load import (
+    get_state_dict_from_repo_id_or_path,
+    is_existing_local_path,
+)
+from realhf.api.core.data_api import load_hf_tokenizer
+from realhf.base import logging, name_resolve, names
 
-    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
+logger = logging.getLogger("HFEngine")
 
 
 class HFEngine(TrainEngine):
-    """Simplified HF engine for transformer models."""
-
-    def __init__(self, args: TrainingArgs, engine_config: EngineConfig):
-        super().__init__(args, engine_config)
+    def __init__(self, config: TrainEngineConfig):
+        self.config = config
+        self.optimizer_config = config.optimizer
         self.model = None
         self.optimizer = None
+        self.lr_scheduler = None
+        self.tokenizer = None
+        # huggingface model config
         self.model_config = None
-
+        # initialization
+        self.initialized = False
         self.weight_update_group_initialized = False
 
-    def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
-        """Initialize model in single node."""
-        if not dist.is_initialized():
-            dist.init_process_group(backend="nccl")
-        if dist.get_world_size() > 1:
-            raise RuntimeError(
-                "Distributed training is not supported in this engine. "
-                "Please use FSDP for distributed training."
-            )
-        torch.cuda.set_device("cuda:0")
-
-        dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16
+    def train(self, mode: bool = True):
+        assert self.model is not None
+        self.model.train(mode=mode)
+        return self
+
+    def initialize(self, addr: str | None, ft_spec: FinetuneSpec | None):
+        """Initialize distributed communication and model."""
+        assert addr is None, "HFEngine does not support remote initialization."
+
+        world_size = int(os.environ.get("WORLD_SIZE", "1"))
+        if world_size > 1:
+            # DeepSpeed is only needed for multi-GPU (AutoTP) runs, so fail
+            # fast instead of hitting a NameError later.
+            try:
+                import deepspeed
+            except ImportError as e:
+                raise RuntimeError(
+                    "deepspeed is not installed but is required for "
+                    "multi-GPU (AutoTP) training with HFEngine."
+                ) from e
+        if not dist.is_initialized() and world_size > 1:
+            deepspeed.init_distributed(dist_backend="nccl", world_size=world_size)
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        torch.cuda.set_device(local_rank)
+        self.device = torch.device(f"cuda:{local_rank}")
+
+        dtype = torch.bfloat16 if self.config.bf16 else torch.float16
         self.model_config = AutoConfig.from_pretrained(
-            pretrained_model_name_or_path=self.engine_config.path,
+            pretrained_model_name_or_path=self.config.path,
             trust_remote_code=True,
         )
-        with torch.device("cuda"):
-            # initialize scratch model from config
-            model = AutoModelForCausalLM.from_config(
-                self.model_config,
-                torch_dtype=dtype,
-                attn_implementation="flash_attention_2",
+        self.tokenizer = load_hf_tokenizer(self.config.path)
+
+        self.model = AutoModelForCausalLM.from_config(
+            self.model_config,
+            torch_dtype=dtype,
+            attn_implementation=self.config.attn_impl,
+        ).to(f"cuda:{local_rank}")
+
+        if not self.config.init_from_scratch:
+            # Load weights from an initial checkpoint path, which must be a
+            # huggingface checkpoint.
+            load_meta = SaveLoadMeta(
+                path=self.config.path,
+                weight_format="hf",
+                with_optim=False,
+                tokenizer=None,
+                base_model_path=self.config.path,
+                naive_distributed=False,
             )
-            model = model.cuda()
+            self.load(load_meta)
 
-        self.model = model
+        if world_size > 1:
+            # Shard the dense layers with DeepSpeed AutoTP.
+            if self._check_autotp():
+                self.model = deepspeed.tp_model_init(
+                    self.model, tp_size=self.config.hf.autotp_size, dtype=dtype
+                )
+            else:
+                raise RuntimeError(
+                    "Invalid DeepSpeed AutoTP setup for HFEngine: attention heads, "
+                    "KV heads, hidden size, and intermediate size must all be "
+                    f"divisible by autotp_size={self.config.hf.autotp_size}."
+                )
 
         # Set up optimizer
-        optimizer_config = self.engine_config.optimizer
-        if optimizer_config is not None:
+        if self.optimizer_config is not None:
             assert (
-                optimizer_config.type == "adam"
+                self.optimizer_config.type == "adam"
             ), "Only AdamW optimizer is supported in this engine."
-            lr = optimizer_config.lr
-            weight_decay = optimizer_config.weight_decay
-            beta1 = optimizer_config.beta1
-            beta2 = optimizer_config.beta2
-            eps = optimizer_config.eps
+            lr = self.optimizer_config.lr
+            weight_decay = self.optimizer_config.weight_decay
+            beta1 = self.optimizer_config.beta1
+            beta2 = self.optimizer_config.beta2
+            eps = self.optimizer_config.eps
 
             self.optimizer = torch.optim.AdamW(
                 self.model.parameters(),
@@ -128,80 +136,293 @@ def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec):
             )
             total_train_steps = ft_spec.total_train_steps
             num_warmup_steps = int(
-                optimizer_config.warmup_steps_proportion * total_train_steps
+                self.optimizer_config.warmup_steps_proportion * total_train_steps
             )
-            self.lr_scheduler = get_cosine_schedule_with_warmup(
-                self.optimizer,
-                num_warmup_steps,
-                total_train_steps,
-                min_lr_ratio=optimizer_config.min_lr_ratio,
+            if self.optimizer_config.lr_scheduler_type == "cosine":
+                self.lr_scheduler = get_cosine_schedule_with_warmup(
+                    self.optimizer,
+                    num_warmup_steps,
+                    total_train_steps,
+                    min_lr_ratio=self.optimizer_config.min_lr_ratio,
+                )
+            elif self.optimizer_config.lr_scheduler_type == "linear":
+                self.lr_scheduler = get_linear_schedule_with_warmup(
+                    self.optimizer,
+                    num_warmup_steps,
+                    total_train_steps,
+                )
+            elif self.optimizer_config.lr_scheduler_type == "constant":
+                self.lr_scheduler = get_constant_schedule_with_warmup(
+                    self.optimizer,
+                    num_warmup_steps,
+                )
+            else:
+                raise ValueError(
+                    f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}"
+                )
+
+        self.initialized = True
+
+    def _check_autotp(self):
+        """Check that every dimension AutoTP shards is divisible by the TP size."""
+        tp_size = self.config.hf.autotp_size
+        config = self.model_config
+        num_attention_heads = config.num_attention_heads
+        num_key_value_heads = config.num_key_value_heads
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+
+        return (
+            num_attention_heads % tp_size == 0
+            and num_key_value_heads % tp_size == 0
+            and hidden_size % tp_size == 0
+            and intermediate_size % tp_size == 0
+        )
+
+    def destroy(self):
+        """Destroy the engine and release GPU memory."""
+        self.model = None
+        self.optimizer = None
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.initialized = False
+
+    def save(self, meta: SaveLoadMeta):
+        if meta.weight_format == "hf":
+            self._save_model_to_hf(meta.path, meta.tokenizer, meta.naive_distributed)
+        elif meta.weight_format == "dcp":
+            # TODO: implement DCP save/load for HF
+            raise NotImplementedError("DCP format saving is not implemented yet.")
+        else:
+            raise ValueError(f"Unknown weight format {meta.weight_format}.")
+
+        if meta.with_optim:
+            self._save_optimizer_state(meta.path)
+
+    def load(self, meta: SaveLoadMeta):
+        if meta.weight_format == "hf":
+            self._load_model_from_hf(meta.path, meta.naive_distributed)
+        elif meta.weight_format == "dcp":
+            # TODO: implement DCP save/load for HF
+            raise NotImplementedError("DCP format loading is not implemented yet.")
+        else:
+            raise ValueError(f"Unknown weight format {meta.weight_format}.")
") + + if meta.with_optim: + self._load_optimizer_state(meta.path) + + def _save_optimizer_state(self, path: str): + assert self.optimizer is not None + os.makedirs(path, exist_ok=True) + torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) + + def _load_optimizer_state(self, path: str): + assert self.optimizer is not None + path = os.path.join(path, "optim.pt") + optimizer_state_dict = torch.load(path, weights_only=False) + self.optimizer.load_state_dict(optimizer_state_dict) + + def _save_model_to_hf( + self, + path: str, + tokenizer: Optional[transformers.PreTrainedTokenizerFast], + naive_distributed: bool, + ): + """Save model in HuggingFace format.""" + if self.model is None: + raise RuntimeError("Model not initialized") + + rank = dist.get_rank() + world_size = dist.get_world_size() + if rank == 0: + os.makedirs(path, exist_ok=True) + self.model_config.save_pretrained(path) + if tokenizer is not None: + tokenizer.save_pretrained(path) + + if world_size > 1: + dist.barrier() + + state_dict = self.model.state_dict() + + if hasattr(self.model, "module"): + state_dict = { + k.replace("module.", "", 1) if k.startswith("module.") else k: v.cpu() + for k, v in state_dict.items() + } + else: + state_dict = {k: v.cpu() for k, v in state_dict.items()} + + if world_size > 1 and naive_distributed: + # Only support store parameters from model partitions respectively + gathered_state_dicts = None + if rank == 0: + gathered_state_dicts = [None for _ in range(world_size)] + + dist.gather_object( + obj=state_dict, object_gather_list=gathered_state_dicts, dst=0 ) - def train(self, mode: bool = True): - """Set the module in training mode.""" - return self.model.train(mode) + if rank == 0: + for i, state_dict in enumerate(gathered_state_dicts): + save_file(state_dict, f"{path}/rank_{i:02d}_model.safetensors") + else: + self.model.save_pretrained(path, state_dict=state_dict) + + if world_size > 1: + dist.barrier() + + def _load_model_from_hf(self, path: str, naive_distributed: bool): + """Load model from HuggingFace format.""" + + rank = dist.get_rank() + # Only support load full model parameters from huggingface + # and load model partition locally + if rank == 0 or is_existing_local_path(path): + if naive_distributed: + path = f"{path}/rank_{rank:02d}_model.safetensors" + full_state = get_state_dict_from_repo_id_or_path(path) + + if hasattr(self.model, "module") and not hasattr(full_state): + full_state = { + f"module.{k}" if not k.startswith("module.") else k: v + for k, v in full_state.items() + } + self.model.load_state_dict( + full_state, strict=not self.model_config.tie_word_embeddings + ) + + if self.model_config.tie_word_embeddings: + self.model.tie_weights() + + def upload_weights(self, meta: WeightUpdateMeta): + if meta.type == "nccl": + if not self.weight_update_group_initialized: + self._init_distributed_weight_update(meta) + self._update_weights_from_distributed() + elif meta.type == "disk": + self._save_model_to_hf(meta.path, self.tokenizer, meta.naive_distributed) + update_name = names.update_weights_from_disk( + self.config.experiment_name, + self.config.trial_name, + meta.model_version, + ) + name_resolve.add(update_name, str(time.time_ns()), keepalive_ttl=120) + else: + raise ValueError(f"Unknown weight update type {meta.type}") + + def _init_distributed_weight_update(self, meta: WeightUpdateMeta): + raise NotImplementedError( + "Distributed weight update is not implemented for HFEngine yet. 
" + ) + + def _update_weights_from_distributed(self): + raise NotImplementedError( + "Distributed weight update is not implemented for HFEngine yet. " + ) + + def step_lr_scheduler(self): + assert self.lr_scheduler is not None + return self.lr_scheduler.step() + + def _prepare_mb_list(self, input_: TensorDict) -> MicroBatchList: + assert "attention_mask" in input_ and "input_ids" in input_ + if isinstance(input_, dict): + input_ = TensorDict(input_, batch_size=[input_["input_ids"].shape[0]]) + input_ = amend_position_ids(input_) + packed_input = pack_tensor_dict(input_) + mb_list = split_packed_tensor_dict_into_mb_list( + packed_input, + self.config.mb_spec, + ) + mb_list = pad_mb_list(mb_list, pad_value=0.0) + # NOTE: We unsqueeze here because huggingface transformer models requires + # packed input to be of shape [1, total_seqlen]. + mb_list = unsqueeze_mb_list(mb_list) + return mb_list def train_batch( self, - input_: Dict, + input_: TensorDict, loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], loss_weight_fn: Callable[[Dict], float], - ) -> Dict: + ) -> Dict[str, float]: """Train on a batch using gradient accumulation.""" + input_ = input_.to(self.device) assert self.optimizer is not None + assert self.optimizer_config is not None assert self.lr_scheduler is not None self.optimizer.zero_grad() - mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec) + mb_list = self._prepare_mb_list(input_) + total_loss_weight = torch.tensor( - sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32 + sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 ) assert total_loss_weight != 0 - for mb_input in mb_splits.mbs: - outputs = self.model(**mb_input) - loss = loss_fn(outputs.logits, mb_input) + # Process microbatches with gradient accumulation + for pad_length, padded_mb_input, mb_input in zip( + mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs + ): + outputs = self.model(**padded_mb_input) + + logits = outputs.logits.squeeze(0) + logits = logits[:-pad_length] if pad_length > 0 else logits + loss = loss_fn(logits, mb_input) loss_scale = loss_weight_fn(mb_input) / total_loss_weight + loss *= loss_scale loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( self.model.parameters(), - self.engine_config.optimizer.gradient_clipping, + self.optimizer_config.gradient_clipping, norm_type=2.0, error_if_nonfinite=False, foreach=None, ) + if not torch.isfinite(grad_norm): + self.optimizer.zero_grad() + update_successful = False + else: + self.optimizer.step() + update_successful = True + current_lr = self.lr_scheduler.get_last_lr()[0] # Optimizer step self.optimizer.step() - - return { - "grad_norm": grad_norm, - "lr": current_lr, - } + return dict( + update_successful=float(update_successful), + grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), + lr=current_lr, + ) @torch.no_grad() def eval_batch( self, - input_: Dict, + input_: TensorDict, loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], loss_weight_fn: Callable[[Dict], float], ) -> torch.Tensor | None: """Evaluate on a batch.""" - mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec) + mb_list = self._prepare_mb_list(input_) total_loss_weight = torch.tensor( - sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32 + sum([loss_weight_fn(mb) for mb in mb_list.mbs]), dtype=torch.float32 ) assert total_loss_weight != 0 total_loss = 0.0 total_weight = 0.0 - for mb_input in mb_splits.mbs: - outputs = self.model(**mb_input) - loss = loss_fn(outputs.logits, 
+        for pad_length, padded_mb_input, mb_input in zip(
+            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        ):
+            outputs = self.model(**padded_mb_input)
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits
+            loss = loss_fn(logits, mb_input)
 
             # Simple weight calculation (could be improved)
             loss_scale = loss_weight_fn(mb_input) / total_loss_weight
@@ -213,95 +434,34 @@ def eval_batch(
     @torch.no_grad()
     def forward(
         self,
-        input_: Dict,
+        input_: TensorDict,
         output_seqlens: List[int] | None = None,
         post_hook: Callable[[torch.Tensor, Dict], Any] | None = None,
-        aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1),
+        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
     ) -> Any | None:
         """Forward pass with optional post-processing."""
-        mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec)
+        cu_seqlens = pack_tensor_dict(input_)["cu_seqlens"]
+        mb_list = self._prepare_mb_list(input_)
+
         if output_seqlens is None:
-            cu_seqlens = input_["cu_seqlens"]
             output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist()
 
         results = []
-        for mb_input in mb_splits.mbs:
-            outputs = self.model(**mb_input)
+        for pad_length, padded_mb_input, mb_input in zip(
+            mb_list.padding_lengths, mb_list.padded_mbs, mb_list.mbs
+        ):
+            outputs = self.model(**padded_mb_input)
+            logits = outputs.logits.squeeze(0)
+            logits = logits[:-pad_length] if pad_length > 0 else logits
+
             if post_hook:
-                result = post_hook(outputs.logits, mb_input)
+                result = post_hook(logits, mb_input)
                 results.append(result)
             else:
-                results.append(outputs.logits)
+                results.append(logits)
 
         res = aggregate_fn(results)
-        output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices]
-        unpacked = unpack_sequence(res, lens=output_seqlens, dim=1)
-        return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices))
-
-    def step_lr_scheduler(self):
-        """Step the learning rate scheduler."""
-        return self.lr_scheduler.step()
-
-    def save_model_to_hf(
-        self,
-        path: str,
-        tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None,
-        base_model_path: Optional[str] = None,
-    ):
-        """Save model in HuggingFace format."""
-        if self.model is None:
-            raise RuntimeError("Model not initialized")
-
-        os.makedirs(path, exist_ok=True)
-
-        state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()}
-        self.model.save_pretrained(path, state_dict=state_dict)
-        self.model_config.save_pretrained(path)
-        if tokenizer is not None:
-            tokenizer.save_pretrained(path)
-
-    def load_model_from_hf(self, path: str):
-        """Load model from HuggingFace format."""
-        full_state = get_state_dict_from_repo_id_or_path(path)
-        self.model.load_state_dict(
-            full_state, strict=not self.model_config.tie_word_embeddings
-        )
-        if self.model_config.tie_word_embeddings:
-            self.model.tie_weights()
-
-    def save_optimizer_state(self, path: str):
-        """Save optimizer state."""
-        if self.optimizer is None:
-            raise RuntimeError("Optimizer not initialized")
-
-        os.makedirs(path, exist_ok=True)
-        torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt"))
-
-    def load_optimizer_state(self, path: str):
-        """Load optimizer state."""
-        if self.optimizer is None:
-            raise RuntimeError("Optimizer not initialized")
-
-        optimizer_path = os.path.join(path, "optimizer.pt")
-        if os.path.exists(optimizer_path):
-            self.optimizer.load_state_dict(
-                torch.load(optimizer_path, map_location="cpu")
-            )
-        else:
-            raise RuntimeError(f"Optimizer state file not found: {optimizer_path}")
-
-    async def aupdate_weights_to(self, llm_client: LLMClient):
-        path = constants.get_param_realloc_path(self.args)
-        self.save_model_to_hf(path)
-        tasks = [
-            llm_client.aupdate_weights_from_disk(server_info=server_info, path=path)
-            for server_info in llm_client.get_healthy_servers()
-        ]
-        await asyncio.gather(*tasks)
-
-    def update_weights_to(self, llm_client: LLMClient):
-        loop = asyncio.new_event_loop()
-        try:
-            loop.run_until_complete(self.aupdate_weights_to(llm_client))
-        finally:
-            loop.close()
+        output_seqlens = [output_seqlens[i] for i in mb_list.forward_indices]
+        unpacked = unpack_sequence(res, lens=output_seqlens, dim=0)
+        reordered = reorder_list(unpacked, mb_list.backward_indices)
+        return pad_and_stack_tensors_along_first_dim(reordered)
diff --git a/arealite/tests/test_fsdp_engine.py b/arealite/tests/test_engine.py
similarity index 84%
rename from arealite/tests/test_fsdp_engine.py
rename to arealite/tests/test_engine.py
index d295408d3..6c5d07a3f 100644
--- a/arealite/tests/test_fsdp_engine.py
+++ b/arealite/tests/test_engine.py
@@ -1,7 +1,7 @@
 # Copyright 2025 Ant Group Inc.
 # Licensed under the Apache License, Version 2.0
 
-"""Test script for HF Engine implementation."""
+"""Test script for Engine implementation."""
 
 import os
 from typing import Dict
@@ -52,29 +52,43 @@ def mock_input(
     )
 
 
-def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
-    """Mock loss function for testing."""
-    return torch.mean(logits)
+def get_engine(engine_type: str, model_path: str):
+    from arealite.engine.fsdp_engine import FSDPEngine
+    from arealite.engine.hf_engine import HFEngine
 
-
-@pytest.fixture(scope="module")
-def engine():
-    os.environ["WORLD_SIZE"] = "1"
-    os.environ["RANK"] = "0"
-    os.environ["LOCAL_RANK"] = "0"
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "7777"
+    engine_cls = {"hf": HFEngine, "fsdp": FSDPEngine}[engine_type]
     engine_config = TrainEngineConfig(
-        experiment_name="test-fsdp-engine",
+        experiment_name=f"test-{engine_type}-engine",
         trial_name="test0",
-        path=MODEL_PATH,
+        path=model_path,
         optimizer=OptimizerConfig(),
     )
-    engine = FSDPEngine(engine_config)
+    engine = engine_cls(engine_config)
     ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
     engine.initialize(None, ft_spec)
-    print("✓ Engine created successfully")
+    return engine
+
+
+def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor:
+    """Mock loss function for testing."""
+    return torch.mean(logits)
+
+
+@pytest.fixture(scope="module", params=["fsdp", "hf"])
+def engine(request):
+    os.environ.update(
+        {
+            "WORLD_SIZE": "1",
+            "RANK": "0",
+            "LOCAL_RANK": "0",
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": "7777",
+        }
+    )
+
+    engine = get_engine(request.param, MODEL_PATH)
+    print(f"✓ {request.param.upper()} Engine created successfully")
 
     yield engine
diff --git a/arealite/utils/save_load.py b/arealite/utils/save_load.py
index 7df9a3f7b..19bc9850d 100644
--- a/arealite/utils/save_load.py
+++ b/arealite/utils/save_load.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Dict
 
 import torch
@@ -41,18 +42,21 @@ def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict:
     else:
         # Assume it's a local path
         local_path = repo_id_or_path
-        if not os.path.isdir(local_path):
-            raise ValueError(
-                f"Local path {local_path} does not exist or is not a directory, "
-                f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
-            )
 
     # Step 3: Load all .safetensors and .bin files
     file_paths_to_load = []
-    for filename in os.listdir(local_path):
-        filepath = os.path.join(local_path, filename)
-        if filename.endswith(".safetensors") or filename.endswith(".bin"):
-            file_paths_to_load.append(filepath)
+    if os.path.isdir(local_path):
+        for filename in os.listdir(local_path):
+            filepath = os.path.join(local_path, filename)
+            if filename.endswith(".safetensors") or filename.endswith(".bin"):
+                file_paths_to_load.append(filepath)
+    elif os.path.isfile(local_path):
+        file_paths_to_load.append(local_path)
+    else:
+        raise ValueError(
+            f"Local path {local_path} does not exist or is not a valid path, "
+            f"or {local_path} is a huggingface repo id but huggingface_hub is not installed."
+        )
 
     def _load(filepath: str):
         if filepath.endswith(".safetensors"):
@@ -82,3 +86,11 @@ def _load(filepath: str):
     except Exception as e:
         raise RuntimeError(f"Error loading checkpoint from {path}: {e}")
     return state_dict
+
+
+def is_existing_local_path(path: str) -> bool:
+    try:
+        path_obj = Path(path)
+        return path_obj.exists() and (path_obj.is_file() or path_obj.is_dir())
+    except (ValueError, OSError):
+        return False
diff --git a/pyproject.toml b/pyproject.toml
index 4f752ebc5..cd4e26dae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -86,6 +86,7 @@ dependencies = [
     # Distributed computing
     "ray",
     "redis",
+    "deepspeed>=0.17.2",
 
     # Web frameworks
     "fastapi>=0.115.12",
diff --git a/requirements.txt b/requirements.txt
index 5318511cd..178e95835 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -74,3 +74,4 @@ swanlab[dashboard]
 torchdata
 autoflake
 tensordict
+deepspeed>=0.17.2
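
Editor's note: below is a minimal single-GPU usage sketch of the HFEngine API introduced by this patch; it is not part of the diff. The checkpoint path, experiment/trial names, and output directory are placeholders, the toy batch is random, and the loss functions mirror mock_loss_fn from arealite/tests/test_engine.py rather than a real token-level loss. It also assumes OptimizerConfig defaults to the supported "adam" optimizer and a valid lr_scheduler_type.

import torch
from tensordict import TensorDict

from arealite.api.cli_args import OptimizerConfig, TrainEngineConfig
from arealite.api.engine_api import FinetuneSpec, SaveLoadMeta
from arealite.engine.hf_engine import HFEngine

# Build the engine from a local HuggingFace checkpoint. Single process here;
# multi-GPU runs go through DeepSpeed AutoTP inside initialize().
config = TrainEngineConfig(
    experiment_name="hf-engine-demo",   # placeholder
    trial_name="trial0",                # placeholder
    path="/path/to/hf/checkpoint",      # placeholder
    optimizer=OptimizerConfig(),
)
engine = HFEngine(config)
ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2)
engine.initialize(None, ft_spec)
engine.train()

# A toy padded batch; _prepare_mb_list() packs it, amends position ids, and
# splits it into microbatches according to config.mb_spec.
batch = TensorDict(
    {
        "input_ids": torch.randint(0, 1000, (2, 16)),
        "attention_mask": torch.ones(2, 16, dtype=torch.long),
    },
    batch_size=[2],
)
stats = engine.train_batch(
    batch,
    loss_fn=lambda logits, mb: logits.mean(),           # mock loss, as in the tests
    loss_weight_fn=lambda mb: mb["input_ids"].numel(),  # weight microbatches by token count
)
engine.step_lr_scheduler()
print(stats)  # {"update_successful": ..., "grad_norm": ..., "lr": ...}

# Round-trip the weights in HF format (optimizer state goes to optim.pt).
engine.save(
    SaveLoadMeta(
        path="/tmp/hf-engine-demo",  # placeholder
        weight_format="hf",
        with_optim=True,
        tokenizer=engine.tokenizer,
        base_model_path=config.path,
        naive_distributed=False,
    )
)
engine.destroy()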