diff --git a/3rdparty/Automodel-workspace/Automodel b/3rdparty/Automodel-workspace/Automodel index 910f4e0402..1d42deb981 160000 --- a/3rdparty/Automodel-workspace/Automodel +++ b/3rdparty/Automodel-workspace/Automodel @@ -1 +1 @@ -Subproject commit 910f4e0402ec3af0c3b8642639f0347732067630 +Subproject commit 1d42deb98169fd94b54c714c0fe4bf308fe7115a diff --git a/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml new file mode 100644 index 0000000000..b82b13f540 --- /dev/null +++ b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml @@ -0,0 +1,22 @@ +defaults: ../../sft.yaml +cluster: + gpus_per_node: 8 +policy: + model_name: openai/gpt-oss-20b + train_global_batch_size: 128 + train_micro_batch_size: 8 + max_total_sequence_length: 512 + dequantize_base_checkpoint: true + dtensor_cfg: + expert_parallel_size: 8 + automodel_kwargs: + backend: + _target_: nemo_automodel.components.moe.utils.BackendConfig + attn: flex + linear: te + rms_norm: te + enable_deepep: true + fake_balanced_gate: false + enable_hf_state_dict_adapter: true +checkpointing: + checkpoint_dir: results/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index 69323f4896..781c8cdf17 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -34,6 +34,34 @@ class LoRAConfig(TypedDict): use_triton: NotRequired[bool] +class AutomodelBackendConfig(TypedDict): + # Hydra target class path (e.g., "nemo_automodel.components.moe.utils.BackendConfig") + _target_: str + # Attention implementation: "te" (Transformer Engine), "flex" (FlexAttention), etc. + attn: NotRequired[str] + # Linear layer implementation: "te" (Transformer Engine), etc. + linear: NotRequired[str] + # RMSNorm implementation: "te" (Transformer Engine), etc. 
+ rms_norm: NotRequired[str] + # Enable DeepEP (Deep Expert Parallelism) for MoE models + enable_deepep: NotRequired[bool] + # Use fake balanced gate for testing/debugging MoE + fake_balanced_gate: NotRequired[bool] + # Enable HuggingFace state dict adapter for checkpoint loading + enable_hf_state_dict_adapter: NotRequired[bool] + # Enable FSDP-specific optimizations + enable_fsdp_optimizations: NotRequired[bool] + # Precision for the MoE gate computation (e.g., "float64", "float32") + gate_precision: NotRequired[str] + + +class AutomodelKwargs(TypedDict): + # Whether to use Liger kernel optimizations (default: false) + use_liger_kernel: NotRequired[bool] + # Backend configuration for MoE models + backend: NotRequired[AutomodelBackendConfig] + + class DTensorConfigDisabled(TypedDict): enabled: Literal[False] @@ -50,6 +78,7 @@ class DTensorConfig(TypedDict): custom_parallel_plan: str | None clear_cache_every_n_steps: NotRequired[int | None] lora_cfg: NotRequired[LoRAConfig | LoRAConfigDisabled] + automodel_kwargs: NotRequired[AutomodelKwargs] class SequencePackingConfigDisabled(TypedDict): diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 21558768b4..144683c95c 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -111,6 +111,12 @@ def __init__( use_v2 = config.get("dtensor_cfg", {}).get("_v2", False) if use_v2: worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2" + + if "TORCH_CUDA_ARCH_LIST" not in os.environ: + warnings.warn( + "TORCH_CUDA_ARCH_LIST is not set. This is needed if using DeepEP in DTensorPolicyWorker V2. This variable is set in our container, but " + "if you are running a custom container or baremetal, you may need to set this variable manually. 
Example: export TORCH_CUDA_ARCH_LIST='9.0 10.0'" + ) else: assert ( config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False) diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 283b980e72..7ad33708a2 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -29,7 +29,7 @@ # Try to import nemo_automodel classes, fallback to None if not available try: - from nemo_automodel.components._transformers.auto_model import ( + from nemo_automodel._transformers.auto_model import ( NeMoAutoModelForCausalLM, NeMoAutoModelForImageTextToText, NeMoAutoModelForTextToWaveform, diff --git a/nemo_rl/models/policy/workers/__init__.py b/nemo_rl/models/policy/workers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 2232f87c6f..c90c161b75 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -14,14 +14,12 @@ import gc import itertools -import math import os import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from typing import Any, Generator, Optional, cast -import nemo_automodel.components._peft.lora as _lora_mod import ray import torch from accelerate import init_empty_weights @@ -29,40 +27,38 @@ from nemo_automodel import ( NeMoAutoModelForSequenceClassification, ) +from nemo_automodel._transformers.registry import ModelRegistry from nemo_automodel.components._peft.lora import ( PeftConfig, apply_lora_to_linear_modules, ) +from nemo_automodel.components.config.loader import _resolve_target from nemo_automodel.components.distributed.cp_utils import ( create_context_parallel_ctx, get_train_context, ) -from nemo_automodel.components.distributed.grad_utils import ( - clip_grad_by_total_norm_, - get_grad_norm, -) -from 
nemo_automodel.components.distributed.parallelizer import ( - fsdp2_strategy_parallelize, +from nemo_automodel.components.distributed.fsdp2 import ( + FSDP2Manager, ) from nemo_automodel.components.distributed.tensor_utils import ( get_cpu_state_dict, to_local_if_dtensor, ) -from torch import nn -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - set_model_state_dict, +from nemo_automodel.components.moe.parallelizer import ( + parallelize_model as moe_parallelize_model, ) +from nemo_automodel.components.training.utils import scale_grads_and_clip_grad_norm +from torch import nn from torch.distributed.fsdp import ( CPUOffloadPolicy, MixedPrecisionPolicy, - OffloadPolicy, ) from torch.distributed.tensor import DTensor, Shard from transformers import ( AutoConfig, AutoProcessor, AutoTokenizer, + PreTrainedModel, ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM @@ -90,22 +86,17 @@ resolve_model_class, ) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker -from nemo_rl.utils.automodel_checkpoint import ( - load_checkpoint, - save_checkpoint, -) +from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch +from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer - -# TODO: @ruit remove this once the bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc (https://github.com/NVIDIA-NeMo/RL/issues/1586) -def _patched_init_lora_weights(self, init_method: str): - if init_method == "xavier": - nn.init.xavier_normal_(self.lora_A.weight.data) - else: - nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5)) - self.lora_B.weight.data.zero_() +STRING_TO_DTYPE = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} @ray.remote( @@ -134,6 +125,9 
@@ def __init__( **kwargs: Any, ): """Initialize the DTensorPolicyWorkerV2.""" + # Apply TE patch until TE is upgraded to 2.10.0 + apply_transformer_engine_patch() + self.tokenizer = tokenizer self.processor = processor self.is_vlm = processor is not None @@ -154,26 +148,25 @@ def __init__( configure_dynamo_cache() self.cfg = config + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call - torch.distributed.init_process_group(backend="nccl") + backend = "nccl" if not self.cpu_offload else "cuda:nccl,cpu:gloo" + torch.distributed.init_process_group(backend=backend) self.rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] + self.checkpoint_manager: Optional[AutomodelCheckpointManager] = None + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] - if self.cfg["precision"] == "float32": - self.dtype = torch.float32 - elif self.cfg["precision"] == "bfloat16": - self.dtype = torch.bfloat16 - elif self.cfg["precision"] == "float16": - self.dtype = torch.float16 - else: + try: + self.dtype = STRING_TO_DTYPE[self.cfg["precision"]] + except KeyError: raise ValueError(f"Unknown precision: {self.cfg['precision']}") - print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"] if self.enable_seq_packing: assert not self.is_vlm, ( @@ -186,6 +179,16 @@ def __init__( hf_config_overrides = self.cfg.get("hf_config_overrides", {}) or {} + # Choose attention implementation on the following basis: + # - Packed sequence requires FA2 and CP must be 1 + # - CP > 1 requires SDPA + cp_size_cfg = self.cfg["dtensor_cfg"]["context_parallel_size"] + attn_impl = ( + "flash_attention_2" + if 
(self.enable_seq_packing and cp_size_cfg == 1) + else ("sdpa" if cp_size_cfg > 1 else None) + ) + model_config = AutoConfig.from_pretrained( model_name, # Always load the model in float32 to keep master weights in float32. @@ -235,15 +238,10 @@ def __init__( # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc. model_class = resolve_model_class(model_config.model_type) - full_state_dict = None - model_state_dict_keys = None - # lora config lora_cfg = self.cfg["dtensor_cfg"].get("lora_cfg", None) self.peft_config = None self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"] - # patch the init_lora_weights method to use the xavier initialization - _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights if self.lora_enabled: if self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1: assert not lora_cfg["use_triton"], ( @@ -254,30 +252,26 @@ def __init__( cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"} self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype) - if self.rank == 0: - print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") - model = model_class.from_pretrained( - model_name, - device_map="cpu", # load weights onto CPU initially - trust_remote_code=True, - config=model_config, - use_liger_kernel=False, - torch_dtype=str(model_config.torch_dtype), - ) - - if self.lora_enabled: - apply_lora_to_linear_modules(model, self.peft_config) - - full_state_dict = model.state_dict() - # Store the original model state dict keys before any parallelization - model_state_dict_keys = list(full_state_dict.keys()) - del model - print(f"[Rank {self.rank}] Initializing empty model for FSDP...") # All ranks initialize model on meta device, so FSDP can shard it. # The actual weights will be broadcast from rank 0. 
cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"] + automodel_kwargs = self.cfg["dtensor_cfg"].get("automodel_kwargs", {}) + if automodel_kwargs.get("backend", None) is not None: + backend_class = _resolve_target( + automodel_kwargs.get("backend", None)["_target_"] + ) + backend_kwargs = automodel_kwargs.get("backend") + backend_kwargs.pop("_target_") + backend = backend_class( + **backend_kwargs, + ) + automodel_kwargs["backend"] = backend + + if "use_liger_kernel" not in automodel_kwargs: + automodel_kwargs["use_liger_kernel"] = False + with init_empty_weights(): # NeMoAutoModelForCausalLM uses flash_attention_2 by default # so we need to set it to None if sequence packing is disabled @@ -299,32 +293,32 @@ def __init__( else: sdpa_method = None - self.model = model_class.from_config( - model_config, - attn_implementation="flash_attention_2" - if self.enable_seq_packing - else None, - use_liger_kernel=False, - trust_remote_code=True, + self.model = model_class.from_pretrained( + model_name, + attn_implementation=attn_impl, torch_dtype=str(model_config.torch_dtype), + trust_remote_code=True, + config=model_config, sdpa_method=sdpa_method, + **automodel_kwargs, ) if self.lora_enabled: apply_lora_to_linear_modules(self.model, self.peft_config) + # Hold a copy of model state_dict keys before any parallelization + self.model_state_dict_keys = list(self.model.state_dict().keys()) + if self.model.config.pad_token_id is None: self.model.config.pad_token_id = tokenizer.pad_token_id tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"] + ep_size = self.cfg["dtensor_cfg"].get("expert_parallel_size", 1) + dp_size = None # will be inferred if cp_size > 1 and self.enable_seq_packing: raise ValueError( "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." 
) - dp_size = world_size // tp_size // cp_size sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"] - assert world_size == dp_size * tp_size * cp_size, ( - f"World size({world_size}) must equal to dp_size({dp_size}) * tp_size({tp_size}) * cp_size({cp_size}) to use DTensor" - ) if sequence_parallel_enabled and tp_size == 1: print( @@ -351,53 +345,18 @@ def __init__( "Context parallel is yet not supported for VLM models. Please set cp_size = 1 to train VLM models." ) - # For FSDP2 compatibility, we need to support HSDP structure - # For now, we use dp_replicate_size = 1 (no hybrid sharding) - dp_replicate_size = 1 - dp_shard_size = dp_size - - # torch==2.8 uses LOCAL_RANK to set the device here (https://github.com/pytorch/pytorch/blob/ba56102387ef21a3b04b357e5b183d48f0afefc7/torch/distributed/device_mesh.py#L500), - # but CUDA_VISIBLE_DEVICES is set to only 1 gpu, so we need to temporarily set LOCAL_RANK to 0. - # TODO: consider changing the default LOCAL_RANK set in worker_groups.py - prev_local_rank = os.environ["LOCAL_RANK"] - os.environ["LOCAL_RANK"] = "0" - - # Create device mesh with HSDP structure for FSDP2 compatibility - device_mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - (dp_replicate_size, dp_shard_size, cp_size, tp_size), - mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"), - ) - os.environ["LOCAL_RANK"] = prev_local_rank - - # Create flattened submeshes for different use cases - # Flatten dp_replicate + dp_shard for the "dp" dimension (backward compatibility) - device_mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp") - - # Flatten dp_shard + cp for FSDP2 sharding - device_mesh[("dp_shard", "cp")]._flatten(mesh_dim_name="dp_shard_cp") - - # Flatten dp_replicate + dp_shard + cp for gradient operations - device_mesh[("dp_replicate", "dp_shard", "cp")]._flatten(mesh_dim_name="dp_cp") - - # Store mesh references for backward compatibility - self.dp_cp_mesh = device_mesh["dp_cp"] - self.dp_mesh = 
device_mesh["dp"] - self.tp_mesh = device_mesh["tp"] - self.cp_mesh = device_mesh["cp"] - - self.dp_size = dp_size - self.tp_size = tp_size - self.cp_size = cp_size - self.device_mesh = device_mesh - # ------------------------------------------------ - # 3) Move to GPU + Composable FSDP - # (Initialize device mesh, shard submodules, then shard entire model) + # Build device mesh and parallelize # ------------------------------------------------ - self.model = fsdp2_strategy_parallelize( - self.model, - device_mesh=self.device_mesh, + manager = FSDP2Manager( + dp_size=dp_size, + dp_replicate_size=1, + tp_size=tp_size, + cp_size=cp_size, + ep_size=ep_size, + pp_size=1, + sequence_parallel=sequence_parallel_enabled, + use_hf_tp_plan=self.cfg["dtensor_cfg"].get("use_hf_tp_plan", False), mp_policy=MixedPrecisionPolicy( param_dtype=self.dtype, reduce_dtype=torch.float32, @@ -405,34 +364,94 @@ def __init__( ), offload_policy=CPUOffloadPolicy(pin_memory=False) if self.cpu_offload - else OffloadPolicy(), - sequence_parallel=sequence_parallel_enabled, + else None, + backend="nccl", + world_size=world_size, activation_checkpointing=self.cfg["dtensor_cfg"][ "activation_checkpointing" ], - tp_shard_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], - dp_replicate_mesh_name="dp_replicate", - dp_shard_cp_mesh_name="dp_shard_cp", - tp_mesh_name="tp", + custom_tp_plan=self.cfg["dtensor_cfg"].get("custom_parallel_plan", None), ) - print(f"[Rank {self.rank}] Loading state dict from rank 0...") - # This will broadcast the state dict from rank 0 to all other ranks - # and load it into the FSDP model. - set_model_state_dict( + # Force setup distributed for world size 1 as FSDP2Manager skips it. 
+ if world_size == 1: + manager._setup_distributed() + + # Store mesh references for downstream usage + self.device_mesh = manager.device_mesh + self.dp_cp_mesh = self.device_mesh["dp_cp"] + self.dp_mesh = self.device_mesh["dp"] + self.tp_mesh = self.device_mesh["tp"] + self.cp_mesh = self.device_mesh["cp"] + self.moe_mesh = getattr(manager, "moe_mesh", None) + + self.dp_size = manager.dp_size + self.tp_size = manager.tp_size + self.cp_size = manager.cp_size + + # Parallelize model + is_moe_model = any(["expert" in key for key in self.model_state_dict_keys]) + is_hf_model = ( + model_config.architectures[0] not in ModelRegistry.model_arch_name_to_cls + ) + if ( + not isinstance(self.model, PreTrainedModel) + and is_moe_model + and not is_hf_model + ): + assert self.tp_size == 1, ( + f"Using custom implementation {self.model.__class__.__name__} for MoE model {model_name} which doesn't support tp_size > 1. Please use expert_parallel_size > 1 for custom implementation or set force_hf=True in your config at policy->dtensor_cfg->automodel_kwargs to use the HuggingFace implementation." + ) + assert self.cp_size == 1, ( + f"Using custom implementation {self.model.__class__.__name__} for MoE model {model_name} which doesn't support cp_size > 1. Please set force_hf=True in your config at policy->dtensor_cfg->automodel_kwargs to use the HuggingFace implementation." 
+ ) + moe_parallelize_model( + model=self.model, + world_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + pp_enabled=False, + dp_axis_names=( + ("dp_replicate", "dp_shard_cp") + if "dp_replicate" in self.device_mesh.mesh_dim_names + and "dp_shard_cp" in self.device_mesh.mesh_dim_names + else ("dp_shard_cp",) + ), + cp_axis_name="cp", + tp_axis_name="tp", + ep_axis_name="ep", + ep_shard_axis_names=("ep_shard",), + ) + else: + self.model = manager.parallelize(self.model) + + # Load base model weights across all ranks using Automodel Checkpointer + # This mirrors build_model_and_optimizer's is_meta_device + load_weights path + print(self.model) + self._init_checkpoint_manager( + config_updates={ + "model_repo_id": model_name, + "dequantize_base_checkpoint": self.cfg.get( + "dequantize_base_checkpoint", False + ), + "is_peft": self.lora_enabled, + }, + ) + self.checkpoint_manager.set_model_state_dict_keys(self.model_state_dict_keys) + + # Load base HF weights unless an explicit checkpoint is provided later + # This puts shards directly into the parallelized model + self.checkpoint_manager.load_base_model( self.model, - model_state_dict=full_state_dict, - options=StateDictOptions( - full_state_dict=True, - broadcast_from_rank0=True, + model_name=model_name, + hf_cache_dir=hf_config_overrides.get("cache_dir", None), + dequantize_base_checkpoint=self.cfg.get( + "dequantize_base_checkpoint", False ), + peft_init_method=self.peft_config.lora_A_init + if self.peft_config is not None + else None, ) - # Broadcast model state dict keys to all ranks and store as instance variable - keys_to_broadcast = [model_state_dict_keys] - torch.distributed.broadcast_object_list(keys_to_broadcast, src=0) - self.model_state_dict_keys = keys_to_broadcast[0] - # Handle tied word embeddings after loading the state dict # We need to actually tie the parameters at the model level is_tied_lm_head = hasattr(self.model, "lm_head") and getattr( @@ -448,10 +467,6 @@ def __init__( if 
embed_tokens_weight is not None: self.model.lm_head.weight = embed_tokens_weight - # Manually broadcast buffers - for _, buf in self.model.named_buffers(): - torch.distributed.broadcast(to_local_if_dtensor(buf), src=0) - if self.cpu_offload: self.model = self.move_to_device(self.model, "cpu") @@ -463,7 +478,8 @@ def __init__( if init_optimizer: optimizer_cls = get_class(self.cfg["optimizer"]["name"]) self.optimizer = optimizer_cls( - self.model.parameters(), **self.cfg["optimizer"]["kwargs"] + self.model.parameters(), + **self.cfg["optimizer"]["kwargs"], ) else: self.optimizer = None @@ -505,7 +521,7 @@ def __init__( self.load_checkpoint(weights_path, optimizer_path) else: print( - "No weights path provided. Starting from scratch (default policy init)" + "No weights path provided. Loaded base HF weights via Checkpointer (default policy init)" ) def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: @@ -513,18 +529,6 @@ def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: logits.div_(self.cfg["generation"]["temperature"]) return logits - def init_collective( - self, ip: str, port: int, world_size: int, *, train_world_size: int - ) -> None: - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - from vllm.distributed.utils import StatelessProcessGroup - - pg = StatelessProcessGroup.create( - host=ip, port=port, rank=self.rank, world_size=world_size - ) - device = torch.cuda.current_device() - self.model_update_group = PyNcclCommunicator(pg, device=device) - def check_model_allow_flash_attn_args(self, model_config) -> bool: # Some models doesn't support flash_attn_kwargs # Check nemotron nas. 
@@ -591,7 +595,7 @@ def train( ## get the normalization factor for the loss local_valid_seqs = torch.sum(global_batch["sample_mask"]) - if not "token_mask" in global_batch: + if "token_mask" not in global_batch: local_valid_toks = ( local_valid_seqs * global_batch["input_ids"].shape[1] ) @@ -760,7 +764,10 @@ def train( outputs = self.model(**model_args) # Get logprobs - if not hasattr(outputs, "logits"): + if isinstance(outputs, (torch.Tensor, DTensor)): + # custom models (e.g., those coming from AutoModel) can output logits directly + logits = outputs + elif not hasattr(outputs, "logits"): logits = self.model.lm_head(outputs.last_hidden_state) else: logits = outputs.logits @@ -851,10 +858,6 @@ def train( ## NOTE: invalid samples should be multiplied ## by zero in the loss function to prevent them ## from affecting the gradient calculation - - # when FSDP reduces the gradients over the DP dim, they're automatically averaged - # but we want to sum them so we cancel out the average here - loss *= self.dp_size * self.cp_size loss.backward() if num_valid_samples > 0: @@ -863,20 +866,27 @@ def train( grad_norm: Optional[float | torch.Tensor] = None if not eval_mode: - with torch.no_grad(): - grad_norm = get_grad_norm( - self.model.parameters(), - dp_cp_group=self.dp_cp_mesh.get_group(), - tp_group=self.tp_mesh.get_group(), - dtype=torch.float32, - ) - if self.max_grad_norm is not None: - clip_grad_by_total_norm_( - self.model.parameters(), - max_grad_norm=self.max_grad_norm, - total_norm=grad_norm, - ) - grad_norm = torch.tensor([grad_norm]) + grad_norm = scale_grads_and_clip_grad_norm( + self.max_grad_norm, + [self.model], + norm_type=2.0, + pp_enabled=False, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" + if self.moe_mesh is not None + and "ep" in self.moe_mesh.mesh_dim_names + else None, + pp_axis_name=None, + foreach=True, + num_label_tokens=1, + # when FSDP reduces the gradients over the DP dim, they're automatically averaged + # but 
we want to sum them so we rescale the gradients by self.dp_size * self.cp_size + dp_group_size=self.dp_size * self.cp_size, + ) + grad_norm = torch.tensor( + grad_norm, device="cpu", dtype=torch.float32 + ) # Update parameters self.optimizer.step() @@ -1077,7 +1087,7 @@ def get_logprobs( outputs = self.model(**model_args) - logits = outputs.logits + logits = outputs.logits if hasattr(outputs, "logits") else outputs # Apply temperature scaling logits = self._apply_temperature_scaling(logits) @@ -1893,37 +1903,17 @@ def save_checkpoint( the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. """ - if checkpointing_cfg is None: - raise ValueError( - "checkpointing_cfg must be provided when saving checkpoint" - ) - - # Extract only the checkpointing configuration keys that exist - checkpoint_kwargs = { - key: value - for key, value in checkpointing_cfg.items() - if key - in { - "model_save_format", - "save_consolidated", - "is_peft", - "peft_config", - } - } - if self.lora_enabled: - checkpoint_kwargs["is_peft"] = True - checkpoint_kwargs["peft_config"] = self.peft_config - - save_checkpoint( + self.checkpoint_manager.save_checkpoint( model=self.model, weights_path=weights_path, - optimizer=self.optimizer if optimizer_path else None, - scheduler=self.scheduler if optimizer_path else None, + optimizer=self.optimizer, optimizer_path=optimizer_path, - tokenizer=self.tokenizer if tokenizer_path else None, + scheduler=self.scheduler, + tokenizer=self.tokenizer if tokenizer_path is None else None, tokenizer_path=tokenizer_path, - model_state_dict_keys=self.model_state_dict_keys, - **checkpoint_kwargs, + checkpointing_cfg=checkpointing_cfg, + lora_enabled=self.lora_enabled, + peft_config=self.peft_config, ) def load_checkpoint( @@ -1931,11 +1921,37 @@ def load_checkpoint( weights_path: str, optimizer_path: Optional[str] = None, ) -> None: - """Load a checkpoint into the model.""" - load_checkpoint( + """Load a checkpoint into the model using 
Automodel Checkpointer.""" + self.checkpoint_manager.load_checkpoint( model=self.model, weights_path=weights_path, - optimizer=self.optimizer if optimizer_path else None, - scheduler=self.scheduler if optimizer_path else None, + optimizer=self.optimizer, optimizer_path=optimizer_path, + scheduler=self.scheduler, ) + + def _init_checkpoint_manager( + self, + config_updates: Optional[dict[str, Any]] = None, + checkpoint_root: Optional[str] = None, + ) -> None: + """Initialize the AutomodelCheckpointManager for this worker. + + This creates the checkpoint manager bound to this worker's device meshes + and initializes its underlying checkpointer. + + Args: + config_updates: Dict of CheckpointingConfig fields to set during initialization. + checkpoint_root: Optional root directory for checkpoints. + """ + if self.checkpoint_manager is None: + self.checkpoint_manager = AutomodelCheckpointManager( + dp_mesh=self.dp_mesh, + tp_mesh=self.tp_mesh, + model_state_dict_keys=getattr(self, "model_state_dict_keys", None), + moe_mesh=self.moe_mesh, + ) + self.checkpoint_manager.init_checkpointer( + config_updates=config_updates, + checkpoint_root=checkpoint_root, + ) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index e9bc31d954..66767822e6 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -19,103 +19,10 @@ from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext from functools import partial -from importlib.util import find_spec from typing import Any, Iterator, Optional, TypeVar, cast import ray import torch - - -def _get_transformer_engine_file(relative_path: str) -> str: - """Return absolute path to a Transformer Engine file or raise if it cannot be found. - - The relative_path should be a POSIX-style path under the transformer_engine - package root, e.g. 
"pytorch/triton/permutation.py". - """ - spec = find_spec("transformer_engine") - if spec is None or not spec.submodule_search_locations: - raise RuntimeError( - "Transformer Engine package not found while attempting to patch " - f"'{relative_path}'. Ensure `transformer-engine` is installed and " - "available in this environment." - ) - - base_dir = next(iter(spec.submodule_search_locations)) - file_path = os.path.join(base_dir, *relative_path.split("/")) - - if not os.path.exists(file_path): - raise RuntimeError( - "Failed to locate expected Transformer Engine file to patch. " - f"Looked for '{relative_path}' at '{file_path}'. " - "This likely indicates an unexpected Transformer Engine installation " - "layout or version mismatch." - ) - - return file_path - - -def _apply_transformer_engine_patch(): - """Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. - - This locates the target file via importlib metadata instead of importing - `transformer_engine`, to avoid side effects during initialization. If the - permutation module has already been imported, it will be reloaded so that - the patched source takes effect. - """ - try: - perm_file = _get_transformer_engine_file("pytorch/triton/permutation.py") - - with open(perm_file, "r") as f: - content = f.read() - - if "get_int_dtype = triton.constexpr_function(get_int_dtype)" not in content: - print(f"Applying Triton fix to {perm_file}...") - - # 1. Replace the usage - old_usage = "idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)" - new_usage = "idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)" - - # 2. 
Insert the definition before the first @triton.jit - jit_anchor = "@triton.jit" - - new_definition = ( - "\n\n" - "get_int_dtype = core.get_int_dtype\n" - "get_int_dtype = triton.constexpr_function(get_int_dtype)\n" - ) - - new_content = None - if old_usage in content: - temp_content = content.replace(old_usage, new_usage) - - if jit_anchor in temp_content: - new_content = temp_content.replace( - jit_anchor, new_definition + jit_anchor, 1 - ) - - if new_content: - try: - with open(perm_file, "w") as f: - f.write(new_content) - print("Successfully patched transformer_engine permutation.py.") - except OSError as e: - print( - f"Could not write patch to transformer_engine (permission denied?): {e}" - ) - - # If the permutation module is already imported in this process, - # reload it so that the patched source takes effect for subsequent use. - import importlib - import sys - - perm_module_name = "transformer_engine.pytorch.triton.permutation" - if perm_module_name in sys.modules: - importlib.reload(sys.modules[perm_module_name]) - - except Exception as e: - print(f"Error checking/patching transformer_engine: {e}") - - from megatron.bridge import AutoBridge from megatron.bridge.models.model_provider import get_model from megatron.bridge.training import fault_tolerance @@ -222,6 +129,7 @@ def _apply_transformer_engine_patch(): get_runtime_env_for_policy_worker, ) from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker +from nemo_rl.models.policy.workers.patches import apply_transformer_engine_patch from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer @@ -540,7 +448,7 @@ def __init__( pre_init_communication_queue: Queue, **kwargs: Any, ): - _apply_transformer_engine_patch() + apply_transformer_engine_patch() self.is_generation_colocated = None if "generation" in config and config["generation"] is not None: diff --git a/nemo_rl/models/policy/workers/patches.py 
b/nemo_rl/models/policy/workers/patches.py new file mode 100644 index 0000000000..5a0d5b0ab8 --- /dev/null +++ b/nemo_rl/models/policy/workers/patches.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from importlib.util import find_spec + + +def _get_transformer_engine_file(relative_path: str) -> str: + """Return absolute path to a Transformer Engine file or raise if it cannot be found. + + The relative_path should be a POSIX-style path under the transformer_engine + package root, e.g. "pytorch/triton/permutation.py". + """ + spec = find_spec("transformer_engine") + if spec is None or not spec.submodule_search_locations: + raise RuntimeError( + "Transformer Engine package not found while attempting to patch " + f"'{relative_path}'. Ensure `transformer-engine` is installed and " + "available in this environment." + ) + + base_dir = next(iter(spec.submodule_search_locations)) + file_path = os.path.join(base_dir, *relative_path.split("/")) + + if not os.path.exists(file_path): + raise RuntimeError( + "Failed to locate expected Transformer Engine file to patch. " + f"Looked for '{relative_path}' at '{file_path}'. " + "This likely indicates an unexpected Transformer Engine installation " + "layout or version mismatch." + ) + + return file_path + + +def apply_transformer_engine_patch(): + """Apply patch from https://github.com/NVIDIA/TransformerEngine/pull/2286/files. 
+ + This locates the target file via importlib metadata instead of importing + `transformer_engine`, to avoid side effects during initialization. If the + permutation module has already been imported, it will be reloaded so that + the patched source takes effect. + """ + try: + perm_file = _get_transformer_engine_file("pytorch/triton/permutation.py") + + with open(perm_file, "r") as f: + content = f.read() + + if "get_int_dtype = triton.constexpr_function(get_int_dtype)" not in content: + print(f"Applying Triton fix to {perm_file}...") + + # 1. Replace the usage + old_usage = "idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)" + new_usage = "idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)" + + # 2. Insert the definition before the first @triton.jit + jit_anchor = "@triton.jit" + + new_definition = ( + "\n\n" + "get_int_dtype = core.get_int_dtype\n" + "get_int_dtype = triton.constexpr_function(get_int_dtype)\n" + ) + + new_content = None + if old_usage in content: + temp_content = content.replace(old_usage, new_usage) + + if jit_anchor in temp_content: + new_content = temp_content.replace( + jit_anchor, new_definition + jit_anchor, 1 + ) + + if new_content: + try: + with open(perm_file, "w") as f: + f.write(new_content) + print("Successfully patched transformer_engine permutation.py.") + except OSError as e: + print( + f"Could not write patch to transformer_engine (permission denied?): {e}" + ) + + # If the permutation module is already imported in this process, + # reload it so that the patched source takes effect for subsequent use. 
+ import importlib + import sys + + perm_module_name = "transformer_engine.pytorch.triton.permutation" + if perm_module_name in sys.modules: + importlib.reload(sys.modules[perm_module_name]) + + except Exception as e: + print(f"Error checking/patching transformer_engine: {e}") diff --git a/nemo_rl/utils/automodel_checkpoint.py b/nemo_rl/utils/automodel_checkpoint.py index a9f0793851..bfba23fae5 100644 --- a/nemo_rl/utils/automodel_checkpoint.py +++ b/nemo_rl/utils/automodel_checkpoint.py @@ -11,49 +11,383 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Automodel checkpoint utilities for DTensor policy workers. -"""Checkpoint management utilities for HF models.""" +This module provides a wrapper class around the nemo_automodel Checkpointer +for saving and loading model checkpoints in DTensor-based policy workers. +""" import os from typing import Any, Optional import torch +from nemo_automodel.components._peft.lora import PeftConfig from nemo_automodel.components.checkpoint._backports.filesystem import ( SerializationFormat, ) - -# Apply torch backports for compatibility with torch==2.7.1 -from nemo_automodel.components.checkpoint._torch_backports import apply_patches - -# Import from nemo-automodel from nemo_automodel.components.checkpoint.checkpointing import ( - CheckpointingConfig, - load_model, - load_optimizer, - save_model, - save_optimizer, + Checkpointer, ) +from nemo_automodel.components.checkpoint.checkpointing import ( + CheckpointingConfig as AutomodelCheckpointingConfig, +) +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from transformers import AutoTokenizer +from transformers.utils import TRANSFORMERS_CACHE -# Apply torch backports for compatibility with torch==2.7.1 -apply_patches() - +from nemo_rl.utils.checkpoint import CheckpointingConfig -def _infer_checkpoint_root(weights_path: 
str) -> str: - """Infer checkpoint root directory from weights path. - When weights_path ends with "…/weights/model", we need the parent of - the weights directory (the checkpoint root), not the weights directory itself. +class AutomodelCheckpointManager: + """Manages checkpointing for DTensor-based models using nemo_automodel's Checkpointer. - Args: - weights_path: Path to model weights (e.g., "/path/to/policy/weights/model") + This class provides a clean interface for saving and loading model checkpoints, + wrapping the nemo_automodel Checkpointer with configuration management. - Returns: - str: Checkpoint root directory (e.g., "/path/to/policy") + Attributes: + checkpointer: The underlying nemo_automodel Checkpointer instance. + checkpoint_config: The current checkpoint configuration. + model_state_dict_keys: List of model state dict keys for checkpoint validation. """ - weights_dir = os.path.dirname(weights_path) - if weights_dir.endswith("weights"): - return os.path.dirname(weights_dir) - return weights_dir + + def __init__( + self, + dp_mesh: DeviceMesh, + tp_mesh: DeviceMesh, + model_state_dict_keys: Optional[list[str]] = None, + moe_mesh: Optional[DeviceMesh] = None, + ): + """Initialize the AutomodelCheckpointManager. + + Args: + dp_mesh: The data parallel device mesh. + tp_mesh: The tensor parallel device mesh. + model_state_dict_keys: Optional list of model state dict keys. + moe_mesh: Optional MoE device mesh. 
+ """ + self.checkpointer: Optional[Checkpointer] = None + self.checkpoint_config: Optional[AutomodelCheckpointingConfig] = None + self.model_state_dict_keys = model_state_dict_keys + self.dp_mesh = dp_mesh + self.tp_mesh = tp_mesh + self.moe_mesh = moe_mesh + + def _get_dp_rank(self) -> int: + """Get the data parallel rank.""" + return torch.distributed.get_rank(self.dp_mesh.get_group()) + + def _get_tp_rank(self) -> int: + """Get the tensor parallel rank.""" + return torch.distributed.get_rank(self.tp_mesh.get_group()) + + def init_checkpointer( + self, + config_updates: Optional[dict[str, Any]] = None, + checkpoint_root: Optional[str] = None, + ) -> None: + """Initialize the Automodel Checkpointer if not already created. + + This method creates a new Checkpointer instance with the provided configuration. + If a checkpointer already exists, this method does nothing. + + Args: + config_updates: Dict of CheckpointingConfig fields to set during initialization. + checkpoint_root: Optional root directory for checkpoints. 
+ """ + if self.checkpointer is not None: + return + + if config_updates is None: + config_updates = {} + + dp_rank = self._get_dp_rank() + tp_rank = self._get_tp_rank() + pp_rank = 0 + + # Initialize a base config with sensible defaults + base_cfg = AutomodelCheckpointingConfig( + enabled=True, + checkpoint_dir=checkpoint_root or "", + model_save_format=config_updates.get("model_save_format", "safetensors"), + model_cache_dir=config_updates.get("model_cache_dir", ""), + model_repo_id=config_updates.get("model_repo_id", ""), + save_consolidated=config_updates.get("save_consolidated", False), + is_peft=config_updates.get("is_peft", False), + model_state_dict_keys=self.model_state_dict_keys, + is_async=config_updates.get("is_async", False), + dequantize_base_checkpoint=config_updates.get( + "dequantize_base_checkpoint", False + ), + ) + self.checkpoint_config = base_cfg + self.checkpointer = Checkpointer( + config=base_cfg, + dp_rank=dp_rank, + tp_rank=tp_rank, + pp_rank=pp_rank, + moe_mesh=self.moe_mesh, + ) + + def update_checkpointer_config( + self, + config_updates: Optional[dict[str, Any]] = None, + checkpoint_root: Optional[str] = None, + ) -> None: + """Update the configuration of an existing Checkpointer. + + This method updates the mutable config fields on the existing Checkpointer instance. + If no checkpointer exists, this method does nothing. + + Note: Some config changes (like model_save_format) require rebuilding the + checkpointer's internal addons list. This method handles that automatically. + + Args: + config_updates: Dict of CheckpointingConfig fields to update. + checkpoint_root: Optional root directory for checkpoints. 
+ """ + if self.checkpointer is None: + return + + if config_updates is None: + config_updates = {} + + cfg = self.checkpointer.config + if checkpoint_root is not None: + cfg.checkpoint_dir = checkpoint_root + for k, v in config_updates.items(): + if k == "model_save_format": + # Ensure enum type + v = SerializationFormat[v.upper()] if isinstance(v, str) else v + setattr(cfg, k, v) + # Ensure model_state_dict_keys is current + if self.model_state_dict_keys is not None: + cfg.model_state_dict_keys = self.model_state_dict_keys + + # Rebuild _addons list based on updated config + # This is necessary because _addons is populated during __init__ based on config + self._rebuild_checkpointer_addons() + + def _rebuild_checkpointer_addons(self) -> None: + """Rebuild the checkpointer's _addons list based on current config. + + The Checkpointer's _addons list is populated during __init__ based on config. + When config changes (e.g., model_save_format or is_peft), we need to rebuild + the addons list to match the new config. + """ + if self.checkpointer is None: + return + + from nemo_automodel.components.checkpoint.addons import ( + ConsolidatedHFAddon, + PeftAddon, + ) + + self.checkpointer._addons = [] + if self.checkpointer._should_write_hf_metadata(): + self.checkpointer._addons.append(ConsolidatedHFAddon()) + if self.checkpointer.config.is_peft: + self.checkpointer._addons.append(PeftAddon()) + + def set_model_state_dict_keys(self, keys: list[str]) -> None: + """Set the model state dict keys for checkpoint validation. + + Args: + keys: List of model state dict keys. + """ + self.model_state_dict_keys = keys + if self.checkpointer is not None: + self.checkpointer.config.model_state_dict_keys = keys + + def load_base_model( + self, + model: nn.Module, + model_name: str, + hf_cache_dir: Optional[str] = None, + dequantize_base_checkpoint: bool = False, + peft_init_method: Optional[str] = None, + ) -> None: + """Load base model weights using the Automodel Checkpointer. 
+ + This method loads the initial HuggingFace model weights into the parallelized model. + + Args: + model: The model to load weights into. + model_name: Name or path of the model. + hf_cache_dir: Optional HuggingFace cache directory. + dequantize_base_checkpoint: Whether to dequantize the base checkpoint. + + Raises: + AssertionError: If checkpointer has not been initialized. + """ + assert self.checkpointer is not None, ( + "Checkpointer must be initialized before loading base model. " + "Call init_checkpointer() first." + ) + + self.update_checkpointer_config( + config_updates={ + "model_repo_id": model_name, + "dequantize_base_checkpoint": dequantize_base_checkpoint, + }, + checkpoint_root=None, + ) + self.checkpointer.config.model_state_dict_keys = self.model_state_dict_keys + + self.checkpointer.load_base_model( + model, + device=torch.cuda.current_device(), + root_dir=hf_cache_dir or TRANSFORMERS_CACHE, + model_name=model_name, + peft_init_method=peft_init_method, + load_base_model=True, + ) + + def save_checkpoint( + self, + model: nn.Module, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + optimizer_path: Optional[str] = None, + scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + tokenizer: Optional[AutoTokenizer] = None, + tokenizer_path: Optional[str] = None, + checkpointing_cfg: Optional[CheckpointingConfig] = None, + lora_enabled: bool = False, + peft_config: Optional[PeftConfig] = None, + ) -> None: + """Save a checkpoint of the model. + + The optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + + Args: + model: The model to save. + weights_path: Path to save model weights. + optimizer: Optional optimizer to save. + optimizer_path: Optional path to save optimizer state. + scheduler: Optional learning rate scheduler. + tokenizer: Optional tokenizer to save with the checkpoint. + tokenizer_path: Optional path to save tokenizer separately. 
+ checkpointing_cfg: Checkpointing configuration. + lora_enabled: Whether LoRA is enabled. + peft_config: Optional PEFT configuration. + """ + print(f"Saving checkpoint to {weights_path}") + assert self.checkpointer is not None, ( + "Checkpointer must be initialized before saving checkpoint. " + "Call init_checkpointer() first." + ) + if checkpointing_cfg is None: + raise ValueError( + "checkpointing_cfg must be provided when saving checkpoint" + ) + + # Extract only the checkpointing configuration keys that exist + checkpoint_kwargs = { + key: value + for key, value in checkpointing_cfg.items() + if key + in { + "model_save_format", + "save_consolidated", + "is_peft", + "peft_config", + "model_cache_dir", + "model_repo_id", + "is_async", + "dequantize_base_checkpoint", + } + } + if lora_enabled: + checkpoint_kwargs["is_peft"] = True + checkpoint_kwargs["peft_config"] = peft_config + + checkpoint_root = _infer_checkpoint_root(weights_path) + + # Update checkpointer configuration + self.update_checkpointer_config( + config_updates=checkpoint_kwargs, checkpoint_root=checkpoint_root + ) + + self.checkpointer.save_model( + model=model, + weights_path=weights_path, + peft_config=checkpoint_kwargs.get("peft_config"), + tokenizer=tokenizer if tokenizer_path is None else None, + ) + + if optimizer_path and optimizer is not None: + self.checkpointer.save_optimizer( + optimizer=optimizer, + model=model, + weights_path=optimizer_path, + scheduler=scheduler, + ) + + if tokenizer_path and tokenizer is not None: + print(f"Saving tokenizer (or processor) to {tokenizer_path}") + tokenizer.save_pretrained(tokenizer_path) + + def load_checkpoint( + self, + model: nn.Module, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + optimizer_path: Optional[str] = None, + scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, + ) -> None: + """Load a checkpoint into the model using Automodel Checkpointer. + + Args: + model: The model to load weights into. 
+ weights_path: Path to the checkpoint weights. + optimizer: Optional optimizer to load state into. + optimizer_path: Optional path to optimizer checkpoint. + scheduler: Optional learning rate scheduler. + """ + print(f"Loading weights from {weights_path}") + assert self.checkpointer is not None, ( + "Checkpointer must be initialized before loading checkpoint. " + "Call init_checkpointer() first." + ) + + model_save_format, is_peft = detect_checkpoint_format(weights_path) + + weights_dir = os.path.dirname(weights_path) + checkpoint_root = ( + os.path.dirname(weights_dir) + if weights_dir.endswith("weights") + else weights_dir + ) + + # Update checkpointer configuration + self.update_checkpointer_config( + config_updates={ + "model_save_format": model_save_format, + "is_peft": is_peft, + "dequantize_base_checkpoint": False, # the saved checkpoint is already dequantized + }, + checkpoint_root=checkpoint_root, + ) + + model_dir = ( + weights_path + if weights_path.endswith("/model") + else os.path.join(weights_path, "model") + ) + + self.checkpointer.load_model( + model=model, + model_path=model_dir, + ) + + if optimizer_path and optimizer is not None: + self.checkpointer.load_optimizer( + optimizer=optimizer, + model=model, + weights_path=optimizer_path, + scheduler=scheduler, + ) def detect_checkpoint_format(weights_path: str) -> tuple[str, bool]: @@ -91,150 +425,19 @@ def detect_checkpoint_format(weights_path: str) -> tuple[str, bool]: return model_save_format, is_peft -def save_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[Any] = None, - optimizer_path: Optional[str] = None, - tokenizer: Optional[Any] = None, - tokenizer_path: Optional[str] = None, - model_save_format: str = "safetensors", - is_peft: bool = False, - peft_config: Optional[Any] = None, - save_consolidated: bool = False, - model_state_dict_keys: Optional[list[str]] = None, -) -> None: - """Save a checkpoint of the 
model and optionally optimizer state. - - Args: - model: The PyTorch model to save - weights_path: Path to save model weights - optimizer: Optional optimizer to save - scheduler: Optional scheduler to save - optimizer_path: Path to save optimizer state (required if optimizer provided) - tokenizer: Optional tokenizer to save - tokenizer_path: Path to save tokenizer state (required if tokenizer provided) - model_save_format: Format for saving model ("torch_save" or "safetensors") - is_peft: Whether the model uses PEFT - peft_config: PEFT configuration if is_peft is True - save_consolidated: Whether to save consolidated checkpoints (for HF compatibility) - model_state_dict_keys: Copy of the model state dict keys before any parallelization. - If None, will be extracted from the model's current state dict. - """ - # Create checkpoint config - - # Extract model state dict keys if not provided - if model_state_dict_keys is None: - model_state_dict_keys = list(model.state_dict().keys()) - - valid_formats = {"safetensors", "torch_save"} - if model_save_format not in valid_formats: - raise ValueError( - f"Unsupported model_save_format='{model_save_format}'. " - f"Expected one of {sorted(valid_formats)}." 
- ) - - # Ensure target directories exist - os.makedirs(weights_path, exist_ok=True) - if optimizer_path: - os.makedirs(optimizer_path, exist_ok=True) - if tokenizer_path: - os.makedirs(tokenizer_path, exist_ok=True) - - checkpoint_config = CheckpointingConfig( - enabled=True, - checkpoint_dir=_infer_checkpoint_root(weights_path), - model_save_format=model_save_format, - model_cache_dir="", - model_repo_id="", - save_consolidated=save_consolidated, - is_peft=is_peft, - model_state_dict_keys=model_state_dict_keys, - ) - - # Save model using nemo-automodel API - save_model( - model=model, - weights_path=weights_path, - checkpoint_config=checkpoint_config, - peft_config=peft_config, - tokenizer=tokenizer if tokenizer_path is None else None, - ) - - # Save optimizer if provided - if optimizer is not None: - if optimizer_path is None: - raise ValueError( - "optimizer_path must be provided when saving optimizer state" - ) - save_optimizer( - optimizer=optimizer, - model=model, - weights_path=optimizer_path, - scheduler=scheduler, - ) - - # Save tokenizer separately if tokenizer_path provided - if tokenizer is not None and tokenizer_path is not None: - print(f"Saving tokenizer (or processor) to {tokenizer_path}") - tokenizer.save_pretrained(tokenizer_path) - +def _infer_checkpoint_root(weights_path: str) -> str: + """Infer checkpoint root directory from weights path. -def load_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[Any] = None, - optimizer_path: Optional[str] = None, -) -> None: - """Load a model weights and optionally optimizer state. + When weights_path ends with "…/weights/model", we need the parent of + the weights directory (the checkpoint root), not the weights directory itself. 
Args: - model: The PyTorch model whose weights to update - weights_path: Path to load model weights from - optimizer: Optional optimizer to load state into - scheduler: Optional scheduler to load state into - optimizer_path: Path to load optimizer state from (required if optimizer provided) - """ - print(f"Loading weights from {weights_path}") - - model_save_format, is_peft = detect_checkpoint_format(weights_path) - - try: - format_enum = SerializationFormat[model_save_format.upper()] - - # append /model to the weights_path if it doesn't exist - # TODO: remove this once nemo-automodel is updated - if not weights_path.endswith("/model"): - weights_path = os.path.join(weights_path, "model") - - # Load model using nemo-automodel API - load_model( - model=model, - model_path=weights_path, - model_save_format=format_enum, - is_peft=is_peft, - ) - except FileNotFoundError as e: - msg = ( - f"Failed to load model from '{weights_path}': {e}\n" - "Note: DTensorPolicyWorkerV2 expects:\n" - " - Model shards under '/weights/model'\n" - " - Optimizer states under '/optimizer/optim'\n" - "Please verify your checkpoint layout." 
- ) - raise FileNotFoundError(msg) from e + weights_path: Path to model weights (e.g., "/path/to/policy/weights/model") - if optimizer is not None: - if optimizer_path is None: - raise ValueError( - "optimizer_path must be provided when loading optimizer state" - ) - print(f"Loading optimizer from {optimizer_path}") - load_optimizer( - optimizer=optimizer, - model=model, - weights_path=optimizer_path, - scheduler=scheduler, - ) + Returns: + str: Checkpoint root directory (e.g., "/path/to/policy") + """ + weights_dir = os.path.dirname(weights_path) + if weights_dir.endswith("weights"): + return os.path.dirname(weights_dir) + return weights_dir diff --git a/nemo_rl/utils/venvs.py b/nemo_rl/utils/venvs.py index c5511473ea..667a45a9f1 100644 --- a/nemo_rl/utils/venvs.py +++ b/nemo_rl/utils/venvs.py @@ -14,6 +14,7 @@ import logging import os import shlex +import shutil import subprocess import time from functools import lru_cache @@ -71,8 +72,6 @@ def create_local_venv( # Force rebuild if requested if force_rebuild and os.path.exists(venv_path): logger.info(f"Force rebuilding venv at {venv_path}") - import shutil - shutil.rmtree(venv_path) logger.info(f"Creating new venv at {venv_path}") diff --git a/pyproject.toml b/pyproject.toml index 177feccbd4..b35ffa49f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,10 +58,13 @@ automodel = [ # Flash-attn version should be selected to satisfy both TE + vLLM requirements (xformers in particular) # https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108 # https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76 - "vllm==0.11.2", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved + "vllm==0.11.2", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved "flash-attn==2.8.1", "mamba-ssm", "causal-conv1d", + "nv-grouped-gemm", + 
"transformer-engine[pytorch]==2.8.0", + "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480", ] vllm = [ "cuda-python", @@ -69,7 +72,7 @@ vllm = [ # deep_ep also needs libibverbs-dev # sudo apt-get update # sudo apt-get install libibverbs-dev - "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@e3908bf5bd0cc6265bcb225d15cd8c996d4759ef", + "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480", "vllm==0.11.2", "num2words>=0.5.14", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved @@ -216,13 +219,11 @@ default-groups = ["dev", "build"] link-mode = "copy" # The TE override is needed because automodel/mbridge we are on is still on 2.5.0 # The opencv-python-headless override is needed because automodel pins it to 4.10.0.84, whereas vllm>=0.11.0 needs >= 4.11.0 -# The transformers override is needed since automodel is still on <=4.55.4 # The timm override is needed because current automodel pins to 1.0.16. 
This can be removed once we move ToT automodel # The nvidia-modelopt override is needed because mcore is still on 0.33 override-dependencies = [ "transformer-engine[pytorch]==2.8.0", "opencv-python-headless>=4.11.0", - "transformers==4.57.1", "timm<=1.0.22", "nvidia-modelopt[torch]>=0.39.0", ] @@ -265,7 +266,7 @@ requires-dist = ["torch", "packaging", "ninja", "causal-conv1d"] [[tool.uv.dependency-metadata]] name = "deep_ep" # This version has to match the version in the commit/rev/tag used -version = "v1.1.0+e3908bf" +version = "v1.2.1+bfded34" requires-dist = ["torch", "packaging", "ninja"] [[tool.uv.dependency-metadata]] diff --git a/pyrefly.toml b/pyrefly.toml index 95f8943e42..74f0f29ed9 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -109,11 +109,13 @@ project-includes = [ "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", + "nemo_rl/models/policy/workers/__init__.py", + "nemo_rl/models/policy/workers/patches.py", "nemo_rl/utils/__init__.py", + "nemo_rl/utils/automodel_checkpoint.py", "nemo_rl/utils/checkpoint.py", "nemo_rl/utils/config.py", "nemo_rl/utils/native_checkpoint.py", - "nemo_rl/utils/automodel_checkpoint.py", "nemo_rl/utils/nsys.py", "nemo_rl/utils/nvml.py", "nemo_rl/utils/packed_tensor.py", diff --git a/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.sh b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.sh new file mode 100755 index 0000000000..24b5977f7e --- /dev/null +++ b/tests/test_suites/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=30 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run 
examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=ruit_personal_debug \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 7.0' \ + 'data["train/loss"]["50"] < 0.4' \ + 'data["train/grad_norm"]["50"] < 2.5' +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index c9dddde4f0..7d830fbfd5 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -87,6 +87,9 @@ tests/test_suites/llm/sft-qwen2.5-math7b-2n8g-megatron.sh # Issue with details: https://github.com/NVIDIA-NeMo/RL/issues/1571 # tests/test_suites/llm/sft-nemotron-super-49b-8n8g-fsdp2tp4cp8-tulu-v3.sh +# gpt-oss 20b DeepEP test +tests/test_suites/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.sh + ####### # DPO # ####### diff --git a/tests/unit/models/dtensor/test_lora.py b/tests/unit/models/dtensor/test_lora.py index eac60eb6fe..b4c92f1120 100644 --- a/tests/unit/models/dtensor/test_lora.py +++ b/tests/unit/models/dtensor/test_lora.py @@ -32,10 +32,6 @@ apply_lora_to_linear_modules, ) -from nemo_rl.models.policy.workers.dtensor_policy_worker_v2 import ( - _patched_init_lora_weights, -) - class SimpleLoraMock(nn.Module): """Simple mock LoRA module for testing initialization.""" @@ -46,69 +42,6 @@ def __init__(self, in_features=128, out_features=256, lora_dim=8): self.lora_B = nn.Linear(lora_dim, out_features, bias=False) 
-@pytest.mark.parametrize("init_method", ["xavier"]) -def test_lora_init_differs_from_upstream_buggy_version(init_method): - """ - Test that our patched LoRA initialization differs from the buggy upstream version. - - Remove this test once Automodel is bumped to commit 2d20e33a19d5e53a271b1403b507475e68ad14dc or later. - - Issue: https://github.com/NVIDIA-NeMo/RL/issues/1586 - """ - torch.manual_seed(42) - - # Create two identical LoRA modules - lora_buggy = LinearLoRA(nn.Linear(16, 16)) - lora_patched = LinearLoRA(nn.Linear(16, 16)) - - # Copy initial weights to ensure identical starting point - lora_patched.lora_A.weight.data.copy_(lora_buggy.lora_A.weight.data) - lora_patched.lora_B.weight.data.copy_(lora_buggy.lora_B.weight.data) - - # Apply buggy upstream initialization - torch.manual_seed(42) - lora_buggy.init_lora_weights(init_method) - - # Apply our patched initialization - torch.manual_seed(42) - _patched_init_lora_weights(lora_patched, init_method) - - # For xavier method, they should differ (that's the bug) - - # Assert that weights differ due to the upstream bug - are_equal_A = torch.allclose( - lora_buggy.lora_A.weight.data, - lora_patched.lora_A.weight.data, - atol=1e-6, - rtol=1e-6, - ) - - assert not are_equal_A, ( - "LoRA A weights should differ for xavier initialization. " - "If this assertion fails, the upstream bug has been fixed in Automodel. " - "You can:\n" - "1. Remove the patch in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py\n" - "2. Remove the patching call\n" - "3. Close issue: https://github.com/NVIDIA-NeMo/RL/issues/1586\n" - "4. 
Delete this test" - ) - - # LoRA B should always be zero-initialized (both implementations do this correctly) - are_equal_B = torch.allclose( - lora_buggy.lora_B.weight.data, - lora_patched.lora_B.weight.data, - atol=0, - rtol=0, - ) - assert are_equal_B, "LoRA B weights should both be zero" - assert torch.all(lora_buggy.lora_B.weight.data == 0), ( - "LoRA B should be zero-initialized" - ) - assert torch.all(lora_patched.lora_B.weight.data == 0), ( - "LoRA B should be zero-initialized" - ) - - def test_lora_init_statistical_properties(): """ Additional test to verify the statistical properties of the patched initialization. @@ -119,7 +52,7 @@ def test_lora_init_statistical_properties(): lora = SimpleLoraMock(in_features=512, out_features=1024, lora_dim=32) # Test xavier initialization - _patched_init_lora_weights(lora, "xavier") + LinearLoRA.init_lora_weights(lora, "xavier") # Xavier normal should have mean ≈ 0 and specific std mean_A = lora.lora_A.weight.data.mean().item() @@ -137,7 +70,7 @@ def test_lora_init_statistical_properties(): # Test kaiming initialization lora2 = SimpleLoraMock(in_features=512, out_features=1024, lora_dim=32) - _patched_init_lora_weights(lora2, "kaiming") + LinearLoRA.init_lora_weights(lora2, "kaiming") mean_A2 = lora2.lora_A.weight.data.mean().item() assert abs(mean_A2) < 0.1, f"Kaiming should have mean ≈ 0, got {mean_A2}" diff --git a/tests/unit/models/policy/test_automodel_types.py b/tests/unit/models/policy/test_automodel_types.py new file mode 100644 index 0000000000..676305274a --- /dev/null +++ b/tests/unit/models/policy/test_automodel_types.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import get_type_hints + +import pytest + +# Check if nemo_automodel is available for tests that need it +try: + from nemo_automodel.components.moe.utils import BackendConfig # noqa: F401 + + NEMO_AUTOMODEL_AVAILABLE = True +except ImportError: + NEMO_AUTOMODEL_AVAILABLE = False + +from nemo_rl.models.policy import AutomodelBackendConfig + + +def get_typeddict_keys(typed_dict_class): + """Get all keys from a TypedDict class.""" + return set(get_type_hints(typed_dict_class).keys()) + + +@pytest.mark.automodel +class TestAutoModelBackendConfig: + def test_backend_config_keys_are_defined(self): + keys = get_typeddict_keys(AutomodelBackendConfig) + assert len(keys) > 0, "AutoModelBackendConfig should have keys defined" + assert "_target_" in keys, "_target_ should be a key in AutoModelBackendConfig" + + @pytest.mark.skipif( + not NEMO_AUTOMODEL_AVAILABLE, reason="nemo_automodel not available" + ) + def test_instantiate_backend_config_from_nemo_automodel(self): + keys = get_typeddict_keys(AutomodelBackendConfig) + backend_keys = {k for k in keys if k != "_target_"} + + config_dict: AutomodelBackendConfig = { + "_target_": "nemo_automodel.components.moe.utils.BackendConfig", + "attn": "te", + "linear": "te", + "rms_norm": "te", + "enable_deepep": True, + "fake_balanced_gate": False, + "enable_hf_state_dict_adapter": True, + "enable_fsdp_optimizations": False, + "gate_precision": "float64", + } + + # Remove _target_ as it's a Hydra convention, not a BackendConfig param + backend_kwargs = {k: v for k, v in config_dict.items() if k != "_target_"} + + # 
Instantiate the actual BackendConfig + backend = BackendConfig(**backend_kwargs) + + # Verify each key from TypedDict is accessible on the backend + for key in backend_keys: + assert hasattr(backend, key), ( + f"BackendConfig missing attribute '{key}' defined in TypedDict" + ) diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 99ec9bcfff..4ddd320bfe 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -231,24 +231,8 @@ def calculate_token_logprobs(model_name: str, data: BatchedDataDict): return token_logprobs -@pytest.fixture(scope="module") -def two_gpu_virtual_cluster(): - cluster_name = "test" - print(f"Creating virtual cluster '{cluster_name}'...") - cluster = RayVirtualCluster( - name=cluster_name, - bundle_ct_per_node_list=[2], # Use tp bundles, one per GPU - use_gpus=True, - num_gpus_per_node=2, # Using tp GPUs - max_colocated_worker_groups=1, # Only one worker group - ) - yield cluster - print("Shutting down virtual cluster...") - cluster.shutdown() - - -@pytest.fixture -def base_setup(request, two_gpu_virtual_cluster): +def _base_setup_impl(request, cluster): + """Implementation for base setup - can be used with any cluster.""" params = request.param if hasattr(request, "param") else None assert params is not None, "params is not set" @@ -273,7 +257,7 @@ def base_setup(request, two_gpu_virtual_cluster): tokenizer = get_tokenizer(config["tokenizer"]) print(f"Creating {mode} Policy with {specified_config}...") policy = Policy( - cluster=two_gpu_virtual_cluster, + cluster=cluster, config=config, tokenizer=tokenizer, init_reference_model=False, @@ -298,319 +282,6 @@ def base_setup(request, two_gpu_virtual_cluster): policy.shutdown() -@pytest.fixture -def policy_setup(request, two_gpu_virtual_cluster, tiny_llama_model_path): - """Setup and teardown for policy tests - creates a virtual cluster and policy.""" - params = request.param 
if hasattr(request, "param") else {} - use_v2 = params.get("dtensor_v2", False) - enable_loras = params.get("enable_loras", False) - - config = create_test_config( - tiny_llama_model_path, dtensor_v2=use_v2, enable_loras=enable_loras - ) - tokenizer = get_tokenizer(config["tokenizer"]) - config["generation"] = configure_generation_config(config["generation"], tokenizer) - - print("Creating Policy...") - policy = Policy(cluster=two_gpu_virtual_cluster, config=config, tokenizer=tokenizer) - - yield policy - - print("Shutting down policy...") - policy.shutdown() - - -@pytest.fixture( - params=[ - # model_fixture_name tp cp sp cpu act - ("tiny_llama_model_path", 1, 1, False, False, False), - ("tiny_llama_model_path", 1, 1, True, False, False), - ("tiny_llama_model_path", 1, 1, False, True, False), - ("tiny_llama_model_path", 1, 1, False, False, True), - ("tiny_llama_model_path", 1, 2, False, False, False), - ("tiny_qwen2_model_path", 1, 1, True, True, False), - ("tiny_qwen2_model_path", 1, 1, True, False, True), - ("tiny_qwen2_model_path", 1, 1, False, True, True), - ("tiny_qwen2_model_path", 1, 1, True, True, True), - ("tiny_qwen2_model_path", 1, 2, False, False, False), - ("tiny_qwen3_model_path", 1, 1, True, True, False), - ("tiny_qwen3_model_path", 1, 1, True, False, True), - ("tiny_qwen3_model_path", 1, 1, False, True, True), - ("tiny_qwen3_model_path", 1, 1, True, True, True), - ("tiny_qwen3_model_path", 1, 2, False, False, False), - ( - "tiny_gemma3_model_path", - 1, - 1, - True, - True, - False, - ), # gemma3 doesn't support spda - ("tiny_gemma3_model_path", 1, 1, True, False, True), - ("tiny_gemma3_model_path", 1, 1, False, True, True), - ("tiny_gemma3_model_path", 1, 1, True, True, True), - # CP doesn't support gemma3 due to spda input has attent_mask != None. 
- # Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881 - # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False), - # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True), - # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True), - ("tiny_nemotron5_h_model_path", 1, 1, False, False, False), - ("tiny_nemotron5_h_model_path", 1, 1, False, True, True), - # nemotron5_h doesn't support cp - ] -) -def training_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for training tests.""" - request.param = { - "mode": "train", - "enable_loras": False, - "lora_config": None, - "model_fixture_name": request.param[0], - "specified_config": { - "tp": request.param[1], - "cp": request.param[2], - "sp": request.param[3], - "cpu_offload": request.param[4], - "activation_checkpointing": request.param[5], - }, - } - yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) - - -@pytest.fixture( - params=[ - # TP=2, CP=1 - ("tiny_qwen2_model_path", 2, 1, False, True, False), - ("tiny_qwen2_model_path", 2, 1, False, False, False), - ("tiny_llama_model_path", 2, 1, False, False, False), - ("tiny_llama_model_path", 2, 1, False, True, False), - ("tiny_llama_model_path", 2, 1, False, True, True), - ("tiny_qwen3_model_path", 2, 1, False, True, False), - ("tiny_qwen3_model_path", 2, 1, False, False, False), - ("tiny_gemma3_model_path", 2, 1, False, True, False), - ("tiny_gemma3_model_path", 2, 1, False, False, False), - # TP=1, CP=2 - ("tiny_qwen2_model_path", 1, 2, False, True, False), - ("tiny_qwen2_model_path", 1, 2, False, False, False), - ("tiny_llama_model_path", 1, 2, False, False, False), - ("tiny_llama_model_path", 1, 2, False, True, False), - ("tiny_llama_model_path", 1, 2, False, True, True), - ("tiny_qwen3_model_path", 1, 2, False, True, False), - ("tiny_qwen3_model_path", 1, 2, False, False, False), - ] -) -def logprob_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for logprob 
tests.""" - request.param = { - "mode": "logprob", - "enable_loras": False, - "lora_config": None, - "model_fixture_name": request.param[0], - "specified_config": { - "tp": request.param[1], - "cp": request.param[2], - "sp": request.param[3], - "cpu_offload": request.param[4], - "activation_checkpointing": request.param[5], - }, - } - yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) - - -@pytest.fixture( - params=[ - # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton - ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), - ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), - ( - "tiny_qwen2_model_path", - ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], - [], - False, - 32, - 16, - 0.0, - "post", - "uniform", - True, - ), - ( - "tiny_qwen2_model_path", - [], - ["q_proj", "k_proj"], - False, - 32, - 16, - 0.0, - "post", - "uniform", - True, - ), - ] -) -def training_with_lora_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for training with lora tests.""" - request.param = { - "mode": "train", - "enable_loras": True, - "model_fixture_name": request.param[0], - "specified_config": {}, - "lora_config": { - "target_modules": request.param[1], - "exclude_modules": request.param[2], - "match_all_linear": request.param[3], - "dim": request.param[4], - "alpha": request.param[5], - "dropout": request.param[6], - "dropout_position": request.param[7], - "lora_A_init": request.param[8], - "use_triton": request.param[9], - }, - } - yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) - - -@pytest.fixture( - params=[ - # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton - ("tiny_llama_model_path", [], [], True, 16, 32, 0.0, "post", "xavier", True), - ("tiny_qwen2_model_path", [], [], True, 32, 32, 
0.0, "pre", "xavier", True), - ( - "tiny_qwen2_model_path", - ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], - [], - False, - 32, - 16, - 0.0, - "post", - "uniform", - True, - ), - ( - "tiny_qwen2_model_path", - [], - ["q_proj", "k_proj"], - False, - 32, - 16, - 0.0, - "post", - "uniform", - True, - ), - ] -) -def logprob_with_lora_setup(request, two_gpu_virtual_cluster): - """Setup and teardown specifically for logprob with lora tests.""" - request.param = { - "mode": "logprob", - "enable_loras": True, - "model_fixture_name": request.param[0], - "specified_config": {}, - "lora_config": { - "target_modules": request.param[1], - "exclude_modules": request.param[2], - "match_all_linear": request.param[3], - "dim": request.param[4], - "alpha": request.param[5], - "dropout": request.param[6], - "dropout_position": request.param[7], - "lora_A_init": request.param[8], - "use_triton": request.param[9], - }, - } - yield from base_setup.__wrapped__(request, two_gpu_virtual_cluster) - - -@pytest.mark.hf_gated -@pytest.mark.timeout(360) -# @pytest.mark.parametrize("policy_setup", [True, False], indirect=True) -@pytest.mark.parametrize( - "policy_setup", - [ - {"dtensor_v2": True, "enable_loras": False}, - {"dtensor_v2": True, "enable_loras": True}, - {"dtensor_v2": False, "enable_loras": False}, - ], - indirect=True, -) -def test_lm_policy_init(policy_setup): - policy = policy_setup - - # Verify we have two workers, one per GPU - assert len(policy.worker_group.workers) == 2, "Should have 2 workers, one per GPU" - - # Check workers are alive - worker_alive = ray.get([w.is_alive.remote() for w in policy.worker_group.workers]) - assert all(worker_alive), f"Not all workers are alive: {worker_alive}" - - # Get GPU info from both workers to verify GPU usage - print("\nGetting GPU information from workers...") - gpu_infos = ray.get([w.get_gpu_info.remote() for w in policy.worker_group.workers]) - print("\nGPU Information:") - for i, info in enumerate(gpu_infos): - 
print(f"\nWorker {i} GPU Info:") - pprint.pprint(info) - - # Check 1: Verify workers have different ranks - gpu_ranks = [info["rank"] for info in gpu_infos] - assert len(set(gpu_ranks)) == 2, f"Expected 2 different ranks, got {gpu_ranks}" - assert set(gpu_ranks) == {0, 1}, f"Expected ranks 0 and 1, got {gpu_ranks}" - - # Check 2: Verify workers have different local_ranks - local_ranks = [info["local_rank"] for info in gpu_infos] - assert len(set(local_ranks)) == 2, ( - f"Expected 2 different local_ranks, got {local_ranks}" - ) - assert set(local_ranks) == {0, 1}, ( - f"Expected local_ranks 0 and 1, got {local_ranks}" - ) - - # Check 3: Verify workers have different CUDA_VISIBLE_DEVICES - cuda_visible_devices = [ - info["env_vars"].get("CUDA_VISIBLE_DEVICES") for info in gpu_infos - ] - assert len(set(cuda_visible_devices)) == 2, ( - f"Expected different CUDA_VISIBLE_DEVICES, got {cuda_visible_devices}" - ) - - # Check 4: Verify all workers report correct world_size - for info in gpu_infos: - assert info["world_size"] == 2, ( - f"Expected world_size=2, got {info['world_size']}" - ) - assert info["env_vars"]["WORLD_SIZE"] == "2", ( - f"Expected WORLD_SIZE=2, got {info['env_vars']['WORLD_SIZE']}" - ) - - # Check 5: Verify GPU memory is allocated on both GPUs - for info in gpu_infos: - assert info["memory_allocated_mb"] > 10, ( - f"Not enough memory allocated on GPU for rank {info['rank']}: {info['memory_allocated_mb']:.2f} MB" - ) - - # Check 6: Verify model parameters are on CUDA devices for both workers - for info in gpu_infos: - param_sample = list(info["parameter_sample"].values())[0] - assert "cuda" in param_sample["device"], ( - f"Parameter not on CUDA device: {param_sample['device']}" - ) - - # Check 8: Verify same model parameters are being tracked across workers - param_names = [list(info["parameter_sample"].keys())[0] for info in gpu_infos] - assert len(set(param_names)) == 1, ( - f"Workers are not tracking the same parameter: {param_names}" - ) - - # Check 
9: Both workers should see their device as cuda:0 (correct distributed behavior) - for info in gpu_infos: - param_device = list(info["parameter_sample"].values())[0]["device"] - assert param_device == "cuda:0", ( - f"Expected parameter device to be cuda:0, got {param_device}" - ) - - def _test_dtensor_worker_training(policy, data, loss_fn): def verify_loss_tensor(loss_tensor): assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN" @@ -666,21 +337,6 @@ def verify_loss_tensor(loss_tensor): ) -@pytest.mark.hf_gated -@pytest.mark.timeout(360) -@pytest.mark.parametrize("use_v2", [True, False]) -def test_dtensor_worker_training(use_v2, training_setup): - policy, data, loss_fn = training_setup - _test_dtensor_worker_training(policy, data, loss_fn) - - -@pytest.mark.hf_gated -@pytest.mark.timeout(360) -def test_dtensor_worker_training_with_lora(training_with_lora_setup): - policy, data, loss_fn = training_with_lora_setup - _test_dtensor_worker_training(policy, data, loss_fn) - - def _test_dtensor_worker_logprob(policy, data, logprobs): # Verify resources were created properly assert policy is not None, "Policy was not created properly" assert data is not None, "Test data was not created properly" @@ -697,280 +353,794 @@ def _test_dtensor_worker_logprob(policy, data, logprobs): @pytest.mark.hf_gated -@pytest.mark.timeout(360) -@pytest.mark.parametrize("use_v2", [True, False]) -def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_setup): - policy, data, logprobs = logprob_setup - _test_dtensor_worker_logprob(policy, data, logprobs) +class TestSingleGPUCluster: + """Tests that run on a single GPU cluster.""" + + @pytest.fixture(scope="class") + def single_gpu_cluster(self): + """Class-scoped single GPU virtual cluster fixture.""" + cluster_name = "test_single_gpu" + print(f"Creating single GPU virtual cluster '{cluster_name}'...") + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[1], # Single GPU bundle + 
use_gpus=True, + num_gpus_per_node=1, # Using 1 GPU + max_colocated_worker_groups=1, # Only one worker group + ) + yield cluster + print("Shutting down single GPU virtual cluster...") + cluster.shutdown() + + @pytest.mark.timeout(360) + @pytest.mark.parametrize("use_v2", [True, False]) + def test_dtensor_single_gpu_training( + self, use_v2, single_gpu_cluster, tiny_llama_model_path + ): + """Test DTensor training with a single GPU cluster (no parallelism).""" + config = create_test_config( + tiny_llama_model_path, + tp=1, + cp=1, + sp=False, + cpu_offload=False, + activation_checkpointing=False, + dtensor_v2=use_v2, + ) + tokenizer = get_tokenizer(config["tokenizer"]) + config["generation"] = configure_generation_config( + config["generation"], tokenizer + ) + print("Creating Policy with single GPU cluster...") + policy = Policy( + cluster=single_gpu_cluster, + config=config, + tokenizer=tokenizer, + init_reference_model=False, + ) -@pytest.mark.hf_gated -@pytest.mark.timeout(360) -def test_dtensor_worker_logprob_with_lora(logprob_with_lora_setup): - policy, data, logprobs = logprob_with_lora_setup - _test_dtensor_worker_logprob(policy, data, logprobs) + try: + # Verify we have one worker + assert len(policy.worker_group.workers) == 1, ( + "Should have 1 worker for single GPU" + ) + # Check worker is alive + worker_alive = ray.get( + [w.is_alive.remote() for w in policy.worker_group.workers] + ) + assert all(worker_alive), f"Worker is not alive: {worker_alive}" -@pytest.mark.hf_gated -@pytest.mark.parametrize("use_v2", [True, False]) -def test_dtensor_tp_and_tied_model_with_custom_parallel_plan( - use_v2, two_gpu_virtual_cluster, tiny_llama_tied_model_path -): - """Test that DTensor with a tp > 1 and a tied model with a custom parallel plan works.""" - from torch.distributed.tensor.parallel import ColwiseParallel - from torch.distributed.tensor.placement_types import Replicate + # Get GPU info to verify setup + gpu_infos = ray.get( + [w.get_gpu_info.remote() for w 
in policy.worker_group.workers] + ) + assert len(gpu_infos) == 1, "Should have 1 GPU info" + assert gpu_infos[0]["world_size"] == 1, ( + "World size should be 1 for single GPU" + ) + assert gpu_infos[0]["rank"] == 0, "Rank should be 0 for single GPU" + + # Create test batch + data = create_test_batch(mode="train") + loss_fn = SimpleLoss() + + # Test training + policy.prepare_for_training() + + losses = [] + for step in range(2): + results = policy.train(data, loss_fn) + assert "loss" in results, "Training results should contain 'loss'" + loss_tensor = results["loss"] + assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN" + assert not torch.isinf(loss_tensor).any(), "Loss should not be Inf" + losses.append(loss_tensor[-1].item()) + print(f"Step {step} - Training loss: {results['loss']}") + + policy.finish_training() + + # Verify loss changed (model was updated) + assert losses[0] > losses[-1], ( + "Loss should decrease over training iterations" + ) - custom_parallel_plan = { - "lm_head": ColwiseParallel(output_layouts=Replicate()), - "model.embed_tokens": ColwiseParallel(output_layouts=Replicate()), - } - config = create_test_config( - model_name=tiny_llama_tied_model_path, - tp=2, - cp=1, - sp=False, - cpu_offload=False, - activation_checkpointing=False, - custom_parallel_plan=custom_parallel_plan, - dtensor_v2=use_v2, - ) - tokenizer = get_tokenizer(config["tokenizer"]) - - policy = Policy( - tokenizer=tokenizer, - config=config, - init_optimizer=False, - init_reference_model=False, - cluster=two_gpu_virtual_cluster, - ) + finally: + policy.shutdown() - # Verify that the model is parallelized as expected - state_dict = ray.get(policy.worker_group.workers[0].return_state_dict.remote()) - total_shape = state_dict["lm_head.weight"].shape - sharded_shape = state_dict["lm_head.weight"].to_local().shape - assert total_shape[0] == sharded_shape[0], ( - "lm_head.weight should have the same number of rows" - ) - assert total_shape[1] == sharded_shape[1] * 2, 
( - "lm_head.weight should be sharded across 2 GPUs" - ) + @pytest.mark.timeout(360) + @pytest.mark.parametrize("use_v2", [True, False]) + def test_dtensor_single_gpu_logprob( + self, use_v2, single_gpu_cluster, tiny_llama_model_path + ): + """Test DTensor logprob computation with a single GPU cluster (no parallelism).""" + config = create_test_config( + tiny_llama_model_path, + tp=1, + cp=1, + sp=False, + cpu_offload=False, + activation_checkpointing=False, + dtensor_v2=use_v2, + ) + tokenizer = get_tokenizer(config["tokenizer"]) + config["generation"] = configure_generation_config( + config["generation"], tokenizer + ) + + print("Creating Policy with single GPU cluster for logprob...") + policy = Policy( + cluster=single_gpu_cluster, + config=config, + tokenizer=tokenizer, + init_reference_model=False, + ) + + try: + # Verify we have one worker + assert len(policy.worker_group.workers) == 1, ( + "Should have 1 worker for single GPU" + ) + + # Create test batch and compute reference logprobs + data = create_test_batch(mode="logprob") + expected_logprobs = calculate_token_logprobs(tiny_llama_model_path, data) + + # Test logprob computation + policy.prepare_for_lp_inference() + policy_logprobs = policy.get_logprobs(data)["logprobs"] + + max_diff = torch.max(torch.abs(policy_logprobs - expected_logprobs)) + print(f"Max logprob diff: {max_diff}") + assert torch.allclose(policy_logprobs, expected_logprobs), ( + f"Logprobs should match reference. 
Max diff: {max_diff}" + ) - # Clean up - policy.shutdown() + finally: + policy.shutdown() @pytest.mark.hf_gated -@pytest.mark.timeout(180) -def test_dtensor_loss_independent_of_microbatch_size_two_gpus( - two_gpu_virtual_cluster, tiny_llama_model_path -): - """Tests that changing microbatch size while keeping global batch size constant does not affect loss values in DTensor.""" - # Create test batch with global batch size of 8 - global_batch_size = 8 - seq_len = 128 - vocab_size = 32000 +class TestTwoGPUCluster: + """Tests that run on a two GPU cluster.""" + + @pytest.fixture(scope="class") + def two_gpu_cluster(self): + """Class-scoped two GPU virtual cluster fixture.""" + cluster_name = "test_two_gpu" + print(f"Creating virtual cluster '{cluster_name}'...") + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[2], # Use tp bundles, one per GPU + use_gpus=True, + num_gpus_per_node=2, # Using tp GPUs + max_colocated_worker_groups=1, # Only one worker group + ) + yield cluster + print("Shutting down virtual cluster...") + cluster.shutdown() + + @pytest.fixture + def policy_setup(self, request, two_gpu_cluster, tiny_llama_model_path): + """Setup and teardown for policy tests - creates a virtual cluster and policy.""" + params = request.param if hasattr(request, "param") else {} + use_v2 = params.get("dtensor_v2", False) + enable_loras = params.get("enable_loras", False) + + config = create_test_config( + tiny_llama_model_path, dtensor_v2=use_v2, enable_loras=enable_loras + ) + tokenizer = get_tokenizer(config["tokenizer"]) + config["generation"] = configure_generation_config( + config["generation"], tokenizer + ) - # Create test input_ids and attention_mask - input_ids = torch.randint(0, vocab_size, (global_batch_size, seq_len)) - attention_mask = torch.ones(global_batch_size, seq_len) - input_lengths = attention_mask.sum(dim=1).to(torch.int32) + print("Creating Policy...") + policy = Policy(cluster=two_gpu_cluster, config=config, 
tokenizer=tokenizer) - # Create data dictionary - data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - "attention_mask": attention_mask, - "token_mask": torch.triu( - torch.ones(global_batch_size, seq_len), diagonal=1 - ), # give different examples different numbers of valid tokens - "sample_mask": torch.ones((global_batch_size,)), - "labels": torch.randint(0, vocab_size, (global_batch_size, seq_len)), - "num_valid_tokens_in_batch": torch.tensor( - [seq_len] * global_batch_size, dtype=torch.float32 + yield policy + + print("Shutting down policy...") + policy.shutdown() + + @pytest.fixture( + params=[ + # model_fixture_name tp cp sp cpu act + ("tiny_llama_model_path", 1, 1, False, False, False), + ("tiny_llama_model_path", 1, 1, True, False, False), + ("tiny_llama_model_path", 1, 1, False, True, False), + ("tiny_llama_model_path", 1, 1, False, False, True), + ("tiny_llama_model_path", 1, 2, False, False, False), + ("tiny_qwen2_model_path", 1, 1, True, True, False), + ("tiny_qwen2_model_path", 1, 1, True, False, True), + ("tiny_qwen2_model_path", 1, 1, False, True, True), + ("tiny_qwen2_model_path", 1, 1, True, True, True), + ("tiny_qwen2_model_path", 1, 2, False, False, False), + ("tiny_qwen3_model_path", 1, 1, True, True, False), + ("tiny_qwen3_model_path", 1, 1, True, False, True), + ("tiny_qwen3_model_path", 1, 1, False, True, True), + ("tiny_qwen3_model_path", 1, 1, True, True, True), + ("tiny_qwen3_model_path", 1, 2, False, False, False), + ( + "tiny_gemma3_model_path", + 1, + 1, + True, + True, + False, + ), # gemma3 doesn't support spda + ("tiny_gemma3_model_path", 1, 1, True, False, True), + ("tiny_gemma3_model_path", 1, 1, False, True, True), + ("tiny_gemma3_model_path", 1, 1, True, True, True), + # CP doesn't support gemma3 due to spda input has attent_mask != None. 
+ # Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881 + # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False), + # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True), + # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True), + ("tiny_nemotron5_h_model_path", 1, 1, False, False, False), + ("tiny_nemotron5_h_model_path", 1, 1, False, True, True), + # nemotron5_h doesn't support cp + ] + ) + def training_setup(self, request, two_gpu_cluster): + """Setup and teardown specifically for training tests.""" + request.param = { + "mode": "train", + "enable_loras": False, + "lora_config": None, + "model_fixture_name": request.param[0], + "specified_config": { + "tp": request.param[1], + "cp": request.param[2], + "sp": request.param[3], + "cpu_offload": request.param[4], + "activation_checkpointing": request.param[5], + }, + } + yield from _base_setup_impl(request, two_gpu_cluster) + + @pytest.fixture( + params=[ + # TP=2, CP=1 + ("tiny_qwen2_model_path", 2, 1, False, True, False), + ("tiny_qwen2_model_path", 2, 1, False, False, False), + ("tiny_llama_model_path", 2, 1, False, False, False), + ("tiny_llama_model_path", 2, 1, False, True, False), + ("tiny_llama_model_path", 2, 1, False, True, True), + ("tiny_qwen3_model_path", 2, 1, False, True, False), + ("tiny_qwen3_model_path", 2, 1, False, False, False), + ("tiny_gemma3_model_path", 2, 1, False, True, False), + ("tiny_gemma3_model_path", 2, 1, False, False, False), + # TP=1, CP=2 + ("tiny_qwen2_model_path", 1, 2, False, True, False), + ("tiny_qwen2_model_path", 1, 2, False, False, False), + ("tiny_llama_model_path", 1, 2, False, False, False), + ("tiny_llama_model_path", 1, 2, False, True, False), + ("tiny_llama_model_path", 1, 2, False, True, True), + ("tiny_qwen3_model_path", 1, 2, False, True, False), + ("tiny_qwen3_model_path", 1, 2, False, False, False), + ] + ) + def logprob_setup(self, request, two_gpu_cluster): + """Setup and teardown specifically for logprob tests.""" + 
request.param = { + "mode": "logprob", + "enable_loras": False, + "lora_config": None, + "model_fixture_name": request.param[0], + "specified_config": { + "tp": request.param[1], + "cp": request.param[2], + "sp": request.param[3], + "cpu_offload": request.param[4], + "activation_checkpointing": request.param[5], + }, + } + yield from _base_setup_impl(request, two_gpu_cluster) + + @pytest.fixture( + params=[ + # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton + ( + "tiny_llama_model_path", + [], + [], + True, + 16, + 32, + 0.0, + "post", + "xavier", + True, + ), + ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), + ( + "tiny_qwen2_model_path", + ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], + [], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ( + "tiny_qwen2_model_path", + [], + ["q_proj", "k_proj"], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ] + ) + def training_with_lora_setup(self, request, two_gpu_cluster): + """Setup and teardown specifically for training with lora tests.""" + request.param = { + "mode": "train", + "enable_loras": True, + "model_fixture_name": request.param[0], + "specified_config": {}, + "lora_config": { + "target_modules": request.param[1], + "exclude_modules": request.param[2], + "match_all_linear": request.param[3], + "dim": request.param[4], + "alpha": request.param[5], + "dropout": request.param[6], + "dropout_position": request.param[7], + "lora_A_init": request.param[8], + "use_triton": request.param[9], + }, + } + yield from _base_setup_impl(request, two_gpu_cluster) + + @pytest.fixture( + params=[ + # model_name, target_modules, exclude_modules, match_all_linear, dim, alpha, dropout, dropout_position, lora_A_init, use_triton + ( + "tiny_llama_model_path", + [], + [], + True, + 16, + 32, + 0.0, + "post", + "xavier", + True, ), - "advantages": torch.randn(global_batch_size, 
seq_len), - "prev_logprobs": torch.randn(global_batch_size, seq_len), - "reference_policy_logprobs": torch.randn(global_batch_size, seq_len), - "generation_logprobs": torch.randn(global_batch_size, seq_len), + ("tiny_qwen2_model_path", [], [], True, 32, 32, 0.0, "pre", "xavier", True), + ( + "tiny_qwen2_model_path", + ["q_proj", "k_proj", "*gate_proj*", "*up_proj*", "*down_proj*"], + [], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ( + "tiny_qwen2_model_path", + [], + ["q_proj", "k_proj"], + False, + 32, + 16, + 0.0, + "post", + "uniform", + True, + ), + ] + ) + def logprob_with_lora_setup(self, request, two_gpu_cluster): + """Setup and teardown specifically for logprob with lora tests.""" + request.param = { + "mode": "logprob", + "enable_loras": True, + "model_fixture_name": request.param[0], + "specified_config": {}, + "lora_config": { + "target_modules": request.param[1], + "exclude_modules": request.param[2], + "match_all_linear": request.param[3], + "dim": request.param[4], + "alpha": request.param[5], + "dropout": request.param[6], + "dropout_position": request.param[7], + "lora_A_init": request.param[8], + "use_triton": request.param[9], + }, } + yield from _base_setup_impl(request, two_gpu_cluster) + + @pytest.mark.timeout(360) + @pytest.mark.parametrize( + "policy_setup", + [ + {"dtensor_v2": True, "enable_loras": False}, + {"dtensor_v2": True, "enable_loras": True}, + {"dtensor_v2": False, "enable_loras": False}, + ], + indirect=True, ) + def test_lm_policy_init(self, policy_setup): + policy = policy_setup - # Test with mbs=1, 2 microbatches per GPU - config = create_test_config(tiny_llama_model_path) - tokenizer = get_tokenizer(config["tokenizer"]) + # Verify we have two workers, one per GPU + assert len(policy.worker_group.workers) == 2, ( + "Should have 2 workers, one per GPU" + ) - print("Creating training Policy with mbs=1...") - policy_mbs1 = Policy( - cluster=two_gpu_virtual_cluster, - config=config, - 
init_reference_model=False, - tokenizer=tokenizer, - ) + # Check workers are alive + worker_alive = ray.get( + [w.is_alive.remote() for w in policy.worker_group.workers] + ) + assert all(worker_alive), f"Not all workers are alive: {worker_alive}" - # Test NLLLoss and ClippedPGLossFn with mbs=1 - nll_loss_fn = NLLLoss() - pg_loss_fn = ClippedPGLossFn( - { - "ratio_clip_min": 0.2, - "ratio_clip_max": 0.2, - "ratio_clip_c": None, - "reference_policy_kl_penalty": 0.1, - "reference_policy_kl_type": "k3", - "kl_input_clamp_value": 20.0, - "kl_output_clamp_value": 10.0, - "disable_ppo_ratio": False, - "use_on_policy_kl_approximation": False, - "use_importance_sampling_correction": False, - "truncated_importance_sampling_ratio": None, - "sequence_level_importance_ratios": False, - "token_level_loss": True, - "force_on_policy_ratio": False, + # Get GPU info from both workers to verify GPU usage + print("\nGetting GPU information from workers...") + gpu_infos = ray.get( + [w.get_gpu_info.remote() for w in policy.worker_group.workers] + ) + print("\nGPU Information:") + for i, info in enumerate(gpu_infos): + print(f"\nWorker {i} GPU Info:") + pprint.pprint(info) + + # Check 1: Verify workers have different ranks + gpu_ranks = [info["rank"] for info in gpu_infos] + assert len(set(gpu_ranks)) == 2, f"Expected 2 different ranks, got {gpu_ranks}" + assert set(gpu_ranks) == {0, 1}, f"Expected ranks 0 and 1, got {gpu_ranks}" + + # Check 2: Verify workers have different local_ranks + local_ranks = [info["local_rank"] for info in gpu_infos] + assert len(set(local_ranks)) == 2, ( + f"Expected 2 different local_ranks, got {local_ranks}" + ) + assert set(local_ranks) == {0, 1}, ( + f"Expected local_ranks 0 and 1, got {local_ranks}" + ) + + # Check 3: Verify workers have different CUDA_VISIBLE_DEVICES + cuda_visible_devices = [ + info["env_vars"].get("CUDA_VISIBLE_DEVICES") for info in gpu_infos + ] + assert len(set(cuda_visible_devices)) == 2, ( + f"Expected different 
CUDA_VISIBLE_DEVICES, got {cuda_visible_devices}" + ) + + # Check 4: Verify all workers report correct world_size + for info in gpu_infos: + assert info["world_size"] == 2, ( + f"Expected world_size=2, got {info['world_size']}" + ) + assert info["env_vars"]["WORLD_SIZE"] == "2", ( + f"Expected WORLD_SIZE=2, got {info['env_vars']['WORLD_SIZE']}" + ) + + # Check 5: Verify GPU memory is allocated on both GPUs + for info in gpu_infos: + assert info["memory_allocated_mb"] > 10, ( + f"Not enough memory allocated on GPU for rank {info['rank']}: {info['memory_allocated_mb']:.2f} MB" + ) + + # Check 6: Verify model parameters are on CUDA devices for both workers + for info in gpu_infos: + param_sample = list(info["parameter_sample"].values())[0] + assert "cuda" in param_sample["device"], ( + f"Parameter not on CUDA device: {param_sample['device']}" + ) + + # Check 8: Verify same model parameters are being tracked across workers + param_names = [list(info["parameter_sample"].keys())[0] for info in gpu_infos] + assert len(set(param_names)) == 1, ( + f"Workers are not tracking the same parameter: {param_names}" + ) + + # Check 9: Both workers should see their device as cuda:0 (correct distributed behavior) + for info in gpu_infos: + param_device = list(info["parameter_sample"].values())[0]["device"] + assert param_device == "cuda:0", ( + f"Expected parameter device to be cuda:0, got {param_device}" + ) + + @pytest.mark.timeout(360) + @pytest.mark.parametrize("use_v2", [True, False]) + def test_dtensor_worker_training(self, use_v2, training_setup): + policy, data, loss_fn = training_setup + _test_dtensor_worker_training(policy, data, loss_fn) + + @pytest.mark.timeout(360) + def test_dtensor_worker_training_with_lora(self, training_with_lora_setup): + policy, data, loss_fn = training_with_lora_setup + _test_dtensor_worker_training(policy, data, loss_fn) + + @pytest.mark.timeout(360) + @pytest.mark.parametrize("use_v2", [True, False]) + def 
test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded( + self, use_v2, logprob_setup + ): + policy, data, logprobs = logprob_setup + _test_dtensor_worker_logprob(policy, data, logprobs) + + @pytest.mark.timeout(360) + def test_dtensor_worker_logprob_with_lora(self, logprob_with_lora_setup): + policy, data, logprobs = logprob_with_lora_setup + _test_dtensor_worker_logprob(policy, data, logprobs) + + @pytest.mark.parametrize("use_v2", [True, False]) + def test_dtensor_tp_and_tied_model_with_custom_parallel_plan( + self, use_v2, two_gpu_cluster, tiny_llama_tied_model_path + ): + """Test that DTensor with a tp > 1 and a tied model with a custom parallel plan works.""" + from torch.distributed.tensor.parallel import ColwiseParallel + from torch.distributed.tensor.placement_types import Replicate + + custom_parallel_plan = { + "lm_head": ColwiseParallel(output_layouts=Replicate()), + "model.embed_tokens": ColwiseParallel(output_layouts=Replicate()), } - ) + config = create_test_config( + model_name=tiny_llama_tied_model_path, + tp=2, + cp=1, + sp=False, + cpu_offload=False, + activation_checkpointing=False, + custom_parallel_plan=custom_parallel_plan, + dtensor_v2=use_v2, + ) + tokenizer = get_tokenizer(config["tokenizer"]) - policy_mbs1.prepare_for_training() - mbs1_nll_results = policy_mbs1.train(data, nll_loss_fn) - mbs1_nll_loss = mbs1_nll_results["loss"] + policy = Policy( + tokenizer=tokenizer, + config=config, + init_optimizer=False, + init_reference_model=False, + cluster=two_gpu_cluster, + ) - mbs1_pg_results = policy_mbs1.train(data, pg_loss_fn) - mbs1_pg_loss = mbs1_pg_results["loss"] + # Verify that the model is parallelized as expected + state_dict = ray.get(policy.worker_group.workers[0].return_state_dict.remote()) + total_shape = state_dict["lm_head.weight"].shape + sharded_shape = state_dict["lm_head.weight"].to_local().shape + assert total_shape[0] == sharded_shape[0], ( + "lm_head.weight should have the same number of rows" + ) + assert 
total_shape[1] == sharded_shape[1] * 2, ( + "lm_head.weight should be sharded across 2 GPUs" + ) - policy_mbs1.worker_group.shutdown() + # Clean up + policy.shutdown() - # Test with mbs=2, 1 microbatch per GPU - config = create_test_config(tiny_llama_model_path) - config["train_micro_batch_size"] = 2 - config["generation"] = configure_generation_config(config["generation"], tokenizer) + @pytest.mark.timeout(180) + def test_dtensor_loss_independent_of_microbatch_size_two_gpus( + self, two_gpu_cluster, tiny_llama_model_path + ): + """Tests that changing microbatch size while keeping global batch size constant does not affect loss values in DTensor.""" + # Create test batch with global batch size of 8 + global_batch_size = 8 + seq_len = 128 + vocab_size = 32000 + + # Create test input_ids and attention_mask + input_ids = torch.randint(0, vocab_size, (global_batch_size, seq_len)) + attention_mask = torch.ones(global_batch_size, seq_len) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + + # Create data dictionary + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + "token_mask": torch.triu( + torch.ones(global_batch_size, seq_len), diagonal=1 + ), # give different examples different numbers of valid tokens + "sample_mask": torch.ones((global_batch_size,)), + "labels": torch.randint(0, vocab_size, (global_batch_size, seq_len)), + "num_valid_tokens_in_batch": torch.tensor( + [seq_len] * global_batch_size, dtype=torch.float32 + ), + "advantages": torch.randn(global_batch_size, seq_len), + "prev_logprobs": torch.randn(global_batch_size, seq_len), + "reference_policy_logprobs": torch.randn(global_batch_size, seq_len), + "generation_logprobs": torch.randn(global_batch_size, seq_len), + } + ) - print("Creating training Policy with mbs=2...") - policy_mbs2 = Policy( - cluster=two_gpu_virtual_cluster, - config=config, - init_reference_model=False, - tokenizer=tokenizer, - ) + # Test with 
mbs=1, 2 microbatches per GPU + config = create_test_config(tiny_llama_model_path) + tokenizer = get_tokenizer(config["tokenizer"]) - # Test NLLLoss and ClippedPGLossFn with mbs=2 - policy_mbs2.prepare_for_training() - mbs2_nll_results = policy_mbs2.train(data, nll_loss_fn) - mbs2_nll_loss = mbs2_nll_results["loss"] + print("Creating training Policy with mbs=1...") + policy_mbs1 = Policy( + cluster=two_gpu_cluster, + config=config, + init_reference_model=False, + tokenizer=tokenizer, + ) - mbs2_pg_results = policy_mbs2.train(data, pg_loss_fn) - mbs2_pg_loss = mbs2_pg_results["loss"] + # Test NLLLoss and ClippedPGLossFn with mbs=1 + nll_loss_fn = NLLLoss() + pg_loss_fn = ClippedPGLossFn( + { + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, + "ratio_clip_c": None, + "reference_policy_kl_penalty": 0.1, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, + "truncated_importance_sampling_ratio": None, + "sequence_level_importance_ratios": False, + "token_level_loss": True, + "force_on_policy_ratio": False, + } + ) - # Verify both loss functions are independent of microbatch size - torch.testing.assert_close(mbs1_nll_loss, mbs2_nll_loss, rtol=1e-5, atol=1e-5) - torch.testing.assert_close(mbs1_pg_loss, mbs2_pg_loss, rtol=1e-5, atol=1e-5) + policy_mbs1.prepare_for_training() + mbs1_nll_results = policy_mbs1.train(data, nll_loss_fn) + mbs1_nll_loss = mbs1_nll_results["loss"] - policy_mbs2.worker_group.shutdown() + mbs1_pg_results = policy_mbs1.train(data, pg_loss_fn) + mbs1_pg_loss = mbs1_pg_results["loss"] + policy_mbs1.worker_group.shutdown() -@pytest.mark.hf_gated -@pytest.mark.timeout(300) -@pytest.mark.parametrize("use_v2", [True, False]) -def test_dtensor_v1_policy_flops_range_check( - tiny_llama_model_path, two_gpu_virtual_cluster, use_v2 -): - """Test that the returned FLOPS is within a reasonable 
range using dtensor backend. + # Test with mbs=2, 1 microbatch per GPU + config = create_test_config(tiny_llama_model_path) + config["train_micro_batch_size"] = 2 + config["generation"] = configure_generation_config( + config["generation"], tokenizer + ) - Performs 2 warmup iterations and measures FLOPS for the next 3 iterations. - """ - batch_size = 8 - seq_len = 128 - vocab_size = 32000 + print("Creating training Policy with mbs=2...") + policy_mbs2 = Policy( + cluster=two_gpu_cluster, + config=config, + init_reference_model=False, + tokenizer=tokenizer, + ) - # Create dtensor v1 config with default settings - config = create_test_config(tiny_llama_model_path, dtensor_v2=use_v2) + # Test NLLLoss and ClippedPGLossFn with mbs=2 + policy_mbs2.prepare_for_training() + mbs2_nll_results = policy_mbs2.train(data, nll_loss_fn) + mbs2_nll_loss = mbs2_nll_results["loss"] - # Update config for FLOPS testing with larger batch and sequence length - config["train_global_batch_size"] = batch_size - config["train_micro_batch_size"] = ( - batch_size # Use full batch size for single microbatch - ) + mbs2_pg_results = policy_mbs2.train(data, pg_loss_fn) + mbs2_pg_loss = mbs2_pg_results["loss"] - tokenizer = get_tokenizer(config["tokenizer"]) - config["generation"] = configure_generation_config(config["generation"], tokenizer) + # Verify both loss functions are independent of microbatch size + torch.testing.assert_close(mbs1_nll_loss, mbs2_nll_loss, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(mbs1_pg_loss, mbs2_pg_loss, rtol=1e-5, atol=1e-5) - policy = Policy( - cluster=two_gpu_virtual_cluster, - config=config, - tokenizer=tokenizer, - init_reference_model=False, - ) + policy_mbs2.worker_group.shutdown() - # Create test data - torch.manual_seed(42) - input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) - attention_mask = torch.ones(batch_size, seq_len) - input_lengths = attention_mask.sum(dim=1).to(torch.int32) + @pytest.mark.timeout(300) + 
@pytest.mark.parametrize("use_v2", [True, False]) + def test_dtensor_v1_policy_flops_range_check( + self, tiny_llama_model_path, two_gpu_cluster, use_v2 + ): + """Test that the returned FLOPS is within a reasonable range using dtensor backend. - data = BatchedDataDict( - { - "input_ids": input_ids, - "input_lengths": input_lengths, - "attention_mask": attention_mask, - "labels": torch.randint(0, vocab_size, (batch_size, seq_len)), - "sample_mask": torch.ones(batch_size), - } - ) + Performs 2 warmup iterations and measures FLOPS for the next 3 iterations. + """ + batch_size = 8 + seq_len = 128 + vocab_size = 32000 - # Create loss function - loss_fn = SimpleLoss() + # Create dtensor v1 config with default settings + config = create_test_config(tiny_llama_model_path, dtensor_v2=use_v2) - try: - # Prepare for training - policy.prepare_for_training() - - # Perform 2 warmup iterations - print("Performing warmup iterations...") - for warmup_step in range(2): - results = policy.train(data, loss_fn) - - # Measure FLOPS on the third iteration - print("Measuring FLOPS on 3 iterations...") - time_begin = time.time() - for train_step in range(3): - results = policy.train(data, loss_fn) - runtime_sec = time.time() - time_begin - - # Check if FLOPS tracking is available - if policy.flops_tracker is not None: - assert "total_flops" in results, ( - "Training results should contain 'total_flops'" - ) - total_flops = results["total_flops"] + # Update config for FLOPS testing with larger batch and sequence length + config["train_global_batch_size"] = batch_size + config["train_micro_batch_size"] = ( + batch_size # Use full batch size for single microbatch + ) - assert isinstance(total_flops, (int, float)), ( - "total_flops should be numeric" - ) - assert total_flops > 0, "total_flops should be positive" + tokenizer = get_tokenizer(config["tokenizer"]) + config["generation"] = configure_generation_config( + config["generation"], tokenizer + ) - total_tflops = total_flops / 1e12 / 3 - 
print(f"Total FLOPS: {total_flops:.2e} ({total_tflops:.4f} TFLOPS)") + policy = Policy( + cluster=two_gpu_cluster, + config=config, + tokenizer=tokenizer, + init_reference_model=False, + ) - flop_count_total = total_flops * runtime_sec - assert 1e9 < flop_count_total < 5e10, ( - "Total FLOPS should be within 1e9 and 5e10" - ) + # Create test data + torch.manual_seed(42) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + attention_mask = torch.ones(batch_size, seq_len) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + "labels": torch.randint(0, vocab_size, (batch_size, seq_len)), + "sample_mask": torch.ones(batch_size), + } + ) - if "theoretical_tflops" in results: - theoretical_tflops = results["theoretical_tflops"] - assert isinstance(theoretical_tflops, (int, float)), ( - "theoretical_tflops should be numeric" + # Create loss function + loss_fn = SimpleLoss() + + try: + # Prepare for training + policy.prepare_for_training() + + # Perform 2 warmup iterations + print("Performing warmup iterations...") + for warmup_step in range(2): + results = policy.train(data, loss_fn) + + # Measure FLOPS on the third iteration + print("Measuring FLOPS on 3 iterations...") + time_begin = time.time() + for train_step in range(3): + results = policy.train(data, loss_fn) + runtime_sec = time.time() - time_begin + + # Check if FLOPS tracking is available + if policy.flops_tracker is not None: + assert "total_flops" in results, ( + "Training results should contain 'total_flops'" ) - assert theoretical_tflops > 0, "theoretical_tflops should be positive" + total_flops = results["total_flops"] - utilization = total_tflops / theoretical_tflops - print(f"Theoretical TFLOPS: {theoretical_tflops:.2f}") - print(f"Model utilization: {utilization * 100:.2f}%") + assert isinstance(total_flops, (int, float)), ( + "total_flops should be 
numeric" + ) + assert total_flops > 0, "total_flops should be positive" + + total_tflops = total_flops / 1e12 / 3 + print(f"Total FLOPS: {total_flops:.2e} ({total_tflops:.4f} TFLOPS)") - assert utilization <= 1.0, ( - f"Model utilization {utilization * 100:.2f}% should not exceed 100%" + flop_count_total = total_flops * runtime_sec + assert 1e9 < flop_count_total < 5e10, ( + "Total FLOPS should be within 1e9 and 5e10" ) - else: - print("FLOPS tracker not available, skipping FLOPS range check") - pytest.skip("FLOPS tracker not supported for this model configuration") - finally: - policy.shutdown() + if "theoretical_tflops" in results: + theoretical_tflops = results["theoretical_tflops"] + assert isinstance(theoretical_tflops, (int, float)), ( + "theoretical_tflops should be numeric" + ) + assert theoretical_tflops > 0, ( + "theoretical_tflops should be positive" + ) + + utilization = total_tflops / theoretical_tflops + print(f"Theoretical TFLOPS: {theoretical_tflops:.2f}") + print(f"Model utilization: {utilization * 100:.2f}%") + + assert utilization <= 1.0, ( + f"Model utilization {utilization * 100:.2f}% should not exceed 100%" + ) + else: + print("FLOPS tracker not available, skipping FLOPS range check") + pytest.skip("FLOPS tracker not supported for this model configuration") + + finally: + policy.shutdown() diff --git a/tests/unit/models/policy/test_dtensor_worker_v2.py b/tests/unit/models/policy/test_dtensor_worker_v2.py index daaf2ea5d2..4e9f33f99e 100644 --- a/tests/unit/models/policy/test_dtensor_worker_v2.py +++ b/tests/unit/models/policy/test_dtensor_worker_v2.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os +import tempfile + import pytest import ray +import torch +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster -from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy import AutomodelKwargs, PolicyConfig +from nemo_rl.models.policy.lm_policy import Policy +from tests.unit.test_utils import SimpleLoss def create_test_config( @@ -28,8 +36,14 @@ def create_test_config( activation_checkpointing: bool = False, custom_parallel_plan: str | None = None, dtensor_v2: bool = False, + precision: str = "float32", + expert_parallel_size: int = 1, + use_hf_tp_plan: bool = False, + sequence_packing_enabled: bool = False, + automodel_kwargs: AutomodelKwargs | None = None, + checkpointing: dict | None = None, ) -> PolicyConfig: - return { + config = { "model_name": model_name, "tokenizer": {"name": model_name}, "generation_batch_size": 1, # Small batch size for testing @@ -37,7 +51,7 @@ def create_test_config( "train_micro_batch_size": 1, "learning_rate": 5e-6, "logprob_batch_size": 1, - "precision": "float32", + "precision": precision, "offload_optimizer_for_logprob": False, "generation": { "backend": "hf", @@ -64,6 +78,8 @@ def create_test_config( "tensor_parallel_size": tp, "context_parallel_size": cp, "custom_parallel_plan": custom_parallel_plan, + "expert_parallel_size": expert_parallel_size, + "use_hf_tp_plan": use_hf_tp_plan, }, "dynamic_batching": { "enabled": True, @@ -72,7 +88,8 @@ def create_test_config( "sequence_length_round": 4, }, "sequence_packing": { - "enabled": False, + "enabled": sequence_packing_enabled, + "train_mb_tokens": 128, }, "optimizer": { "name": "torch.optim.AdamW", @@ -93,6 +110,41 @@ def create_test_config( }, "max_grad_norm": 1.0, } + if automodel_kwargs is not None: + config["dtensor_cfg"]["automodel_kwargs"] = automodel_kwargs + if checkpointing is not None: + config["checkpointing"] = 
checkpointing + return config + + +def create_test_batch( + batch_size: int = 8, + seq_len: int = 128, + vocab_size: int = 32000, + mode: str = "train", +) -> BatchedDataDict: + """Create a test batch for training or logprob computation.""" + torch.manual_seed(66) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + attention_mask = torch.ones(batch_size, seq_len) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + **( + { + "labels": torch.randint(0, vocab_size, (batch_size, seq_len)), + "sample_mask": torch.ones(batch_size).cuda(), + } + if mode == "train" + else {} + ), + } + ) + data = data.to("cpu") + return data @pytest.fixture(scope="module") @@ -111,10 +163,6 @@ def two_gpu_virtual_cluster(): cluster.shutdown() -from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.models.policy.lm_policy import Policy - - def compare_model_configs(config_v1: dict, config_v2: dict) -> list[str]: """ Compare two model configurations and return a list of discrepancies. @@ -178,7 +226,11 @@ def test_dtensor_worker_v1_v2_model_config_equivalence( cpu_offload, activation_checkpointing, ): - """Test that dtensor worker v1 and v2 produce equivalent model configurations.""" + """Test that dtensor worker v1 and v2 produce equivalent model configurations. + + This test verifies that DTensorPolicyWorkerV2 produces the same model config + as the v1 worker, ensuring backward compatibility. 
+ """ # Get the actual model path from the fixture name model_name = request.getfixturevalue(model_fixture_name) # Create v1 configuration @@ -242,3 +294,165 @@ def test_dtensor_worker_v1_v2_model_config_equivalence( assert not discrepancies, ( f"Model configurations differ between v1 and v2 approaches for {model_name}" ) + + +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +def test_dtensor_v2_checkpoint_save_and_load( + two_gpu_virtual_cluster, + tiny_llama_model_path, +): + with tempfile.TemporaryDirectory() as tmpdir: + checkpointing_config = { + "enabled": True, + "checkpoint_dir": tmpdir, + "metric_name": None, # Save most recent checkpoints + "higher_is_better": False, + "keep_top_k": 2, + "save_period": 30, + "checkpoint_must_save_by": None, + } + + config = create_test_config( + model_name=tiny_llama_model_path, + tp=2, + cp=1, + dtensor_v2=True, + checkpointing=checkpointing_config, + ) + + policy = Policy( + tokenizer=get_tokenizer(config["tokenizer"]), + config=config, + init_optimizer=True, + init_reference_model=False, + cluster=two_gpu_virtual_cluster, + name_prefix="lm_policy_checkpoint", + ) + + try: + weights_path = os.path.join(tmpdir, "weights") + optimizer_path = os.path.join(tmpdir, "optimizer") + + # Save checkpoint + policy.save_checkpoint( + weights_path=weights_path, + optimizer_path=optimizer_path, + checkpointing_cfg=checkpointing_config, + ) + + # Verify checkpoint files were created + assert os.path.exists(weights_path), "Weights path should exist after save" + + # Load checkpoint into a new policy + config2 = create_test_config( + model_name=tiny_llama_model_path, + tp=2, + cp=1, + dtensor_v2=True, + checkpointing=checkpointing_config, + ) + + # Shutdown original policy first to free GPU memory + policy.shutdown() + policy = None + + policy2 = Policy( + tokenizer=get_tokenizer(config2["tokenizer"]), + config=config2, + init_optimizer=True, + init_reference_model=False, + cluster=two_gpu_virtual_cluster, + 
name_prefix="lm_policy_checkpoint_loaded", + weights_path=weights_path, + optimizer_path=optimizer_path, + ) + + # Verify policy was loaded successfully + assert len(policy2.worker_group.workers) == 2 + worker_alive = ray.get( + [w.is_alive.remote() for w in policy2.worker_group.workers] + ) + assert all(worker_alive) + + policy2.shutdown() + finally: + if policy is not None: + policy.shutdown() + + +@pytest.mark.hf_gated +@pytest.mark.timeout(360) +@pytest.mark.parametrize("precision", ["bfloat16", "float16"]) +def test_dtensor_v2_mixed_precision_training_and_logprobs( + two_gpu_virtual_cluster, + tiny_llama_model_path, + precision, +): + config = create_test_config( + model_name=tiny_llama_model_path, + tp=2, + cp=1, + dtensor_v2=True, + precision=precision, + ) + + policy = Policy( + tokenizer=get_tokenizer(config["tokenizer"]), + config=config, + init_optimizer=True, + init_reference_model=False, + cluster=two_gpu_virtual_cluster, + name_prefix=f"lm_policy_{precision}_mixed", + ) + + try: + # --- Test Training --- + train_data = create_test_batch(mode="train") + loss_fn = SimpleLoss() + + policy.prepare_for_training() + results = policy.train(train_data, loss_fn) + + # Verify training completed successfully + assert "loss" in results + loss_tensor = results["loss"] + assert not torch.isnan(loss_tensor).any(), ( + f"Loss should not be NaN with precision={precision}" + ) + assert not torch.isinf(loss_tensor).any(), ( + f"Loss should not be Inf with precision={precision}" + ) + # Loss is returned in float32 (reduced in float32 for numerical stability) + assert loss_tensor.dtype == torch.float32, ( + f"Loss should be float32, got {loss_tensor.dtype}" + ) + + policy.finish_training() + + # --- Test Logprobs --- + logprob_data = create_test_batch(mode="logprob") + + policy.prepare_for_lp_inference() + logprobs = policy.get_logprobs(logprob_data) + + # Verify logprobs were computed successfully + assert "logprobs" in logprobs + logprobs_tensor = logprobs["logprobs"] + 
assert logprobs_tensor.shape[0] == logprob_data.size + assert not torch.isnan(logprobs_tensor).any(), ( + f"Logprobs should not be NaN with precision={precision}" + ) + assert not torch.isinf(logprobs_tensor).any(), ( + f"Logprobs should not be Inf with precision={precision}" + ) + # Logprobs are returned in float32 for numerical stability + assert logprobs_tensor.dtype == torch.float32, ( + f"Logprobs should be float32 for numerical stability, got {logprobs_tensor.dtype}" + ) + + # Verify the configured precision by checking worker info + worker_info = ray.get(policy.worker_group.workers[0].get_gpu_info.remote()) + assert worker_info is not None, "Should get worker info" + finally: + policy.shutdown() diff --git a/tests/unit/models/policy/test_patches.py b/tests/unit/models/policy/test_patches.py new file mode 100644 index 0000000000..e8cacbcd4a --- /dev/null +++ b/tests/unit/models/policy/test_patches.py @@ -0,0 +1,447 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys +import tempfile +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from nemo_rl.models.policy.workers.patches import ( + _get_transformer_engine_file, + apply_transformer_engine_patch, +) + + +class TestGetTransformerEngineFile: + """Tests for _get_transformer_engine_file function.""" + + def test_package_not_found(self): + """Test that RuntimeError is raised when transformer_engine is not installed.""" + with patch( + "nemo_rl.models.policy.workers.patches.find_spec", return_value=None + ): + with pytest.raises(RuntimeError) as exc_info: + _get_transformer_engine_file("pytorch/triton/permutation.py") + + assert "Transformer Engine package not found" in str(exc_info.value) + assert "pytorch/triton/permutation.py" in str(exc_info.value) + + def test_package_no_submodule_locations(self): + """Test that RuntimeError is raised when package has no submodule_search_locations.""" + mock_spec = MagicMock() + mock_spec.submodule_search_locations = None + + with patch( + "nemo_rl.models.policy.workers.patches.find_spec", return_value=mock_spec + ): + with pytest.raises(RuntimeError) as exc_info: + _get_transformer_engine_file("pytorch/triton/permutation.py") + + assert "Transformer Engine package not found" in str(exc_info.value) + + def test_package_empty_submodule_locations(self): + """Test that RuntimeError is raised when submodule_search_locations is empty.""" + mock_spec = MagicMock() + mock_spec.submodule_search_locations = [] + + with patch( + "nemo_rl.models.policy.workers.patches.find_spec", return_value=mock_spec + ): + with pytest.raises(RuntimeError) as exc_info: + _get_transformer_engine_file("pytorch/triton/permutation.py") + + assert "Transformer Engine package not found" in str(exc_info.value) + + def test_file_not_found(self): + """Test that RuntimeError is raised when the target file doesn't exist.""" + mock_spec = MagicMock() + mock_spec.submodule_search_locations = ["/fake/path/to/transformer_engine"] + + 
with ( + patch( + "nemo_rl.models.policy.workers.patches.find_spec", + return_value=mock_spec, + ), + patch("os.path.exists", return_value=False), + ): + with pytest.raises(RuntimeError) as exc_info: + _get_transformer_engine_file("pytorch/triton/permutation.py") + + assert "Failed to locate expected Transformer Engine file" in str( + exc_info.value + ) + assert "pytorch/triton/permutation.py" in str(exc_info.value) + + def test_successful_file_lookup(self): + """Test successful file path resolution.""" + mock_spec = MagicMock() + mock_spec.submodule_search_locations = ["/fake/path/to/transformer_engine"] + expected_path = os.path.join( + "/fake/path/to/transformer_engine", "pytorch", "triton", "permutation.py" + ) + + with ( + patch( + "nemo_rl.models.policy.workers.patches.find_spec", + return_value=mock_spec, + ), + patch("os.path.exists", return_value=True), + ): + result = _get_transformer_engine_file("pytorch/triton/permutation.py") + + assert result == expected_path + + def test_path_construction_with_multiple_segments(self): + """Test that paths with multiple segments are correctly constructed.""" + mock_spec = MagicMock() + mock_spec.submodule_search_locations = ["/base/dir"] + + with ( + patch( + "nemo_rl.models.policy.workers.patches.find_spec", + return_value=mock_spec, + ), + patch("os.path.exists", return_value=True), + ): + result = _get_transformer_engine_file("a/b/c/d.py") + + expected = os.path.join("/base/dir", "a", "b", "c", "d.py") + assert result == expected + + +class TestApplyTransformerEnginePatch: + """Tests for apply_transformer_engine_patch function.""" + + UNPATCHED_CONTENT = """ +import triton +from triton import language as core + +@triton.jit +def some_kernel(x): + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + return x +""" + + ALREADY_PATCHED_CONTENT = """ +import triton +from triton import language as core + +get_int_dtype = core.get_int_dtype +get_int_dtype = 
triton.constexpr_function(get_int_dtype) + +@triton.jit +def some_kernel(x): + idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + return x +""" + + def test_patch_not_applied_when_already_patched(self, capsys): + """Test that patch is not applied when file is already patched.""" + with ( + patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value="/fake/path/permutation.py", + ), + patch( + "builtins.open", + mock_open(read_data=self.ALREADY_PATCHED_CONTENT), + ) as mock_file, + ): + apply_transformer_engine_patch() + + # Verify file was only opened for reading (not writing) + mock_file.assert_called_once_with("/fake/path/permutation.py", "r") + # No print about applying fix since already patched + captured = capsys.readouterr() + assert "Applying Triton fix" not in captured.out + + def test_patch_applied_when_needed(self, capsys): + """Test that patch is correctly applied when file needs patching.""" + mock_file_handle = MagicMock() + mock_file_handle.read.return_value = self.UNPATCHED_CONTENT + mock_file_handle.__enter__ = MagicMock(return_value=mock_file_handle) + mock_file_handle.__exit__ = MagicMock(return_value=False) + + written_content = [] + + def mock_write(content): + written_content.append(content) + + mock_file_handle.write = mock_write + + call_count = [0] + + def mock_open_func(path, mode="r"): + call_count[0] += 1 + if mode == "r": + mock_file_handle.read.return_value = self.UNPATCHED_CONTENT + return mock_file_handle + + with ( + patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value="/fake/path/permutation.py", + ), + patch("builtins.open", mock_open_func), + ): + apply_transformer_engine_patch() + + captured = capsys.readouterr() + assert "Applying Triton fix to /fake/path/permutation.py" in captured.out + assert "Successfully patched" in captured.out + + # Verify the content was modified + assert len(written_content) > 0 + new_content = 
written_content[0] + assert ( + "get_int_dtype = triton.constexpr_function(get_int_dtype)" + in new_content + ) + assert ( + "idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)" + in new_content + ) + + def test_patch_handles_permission_error(self, capsys): + """Test that permission errors when writing are handled gracefully.""" + with ( + patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value="/fake/path/permutation.py", + ), + patch("builtins.open") as mock_file, + ): + # First call (read) succeeds + read_mock = MagicMock() + read_mock.__enter__ = MagicMock(return_value=read_mock) + read_mock.__exit__ = MagicMock(return_value=False) + read_mock.read.return_value = self.UNPATCHED_CONTENT + + # Second call (write) fails with permission error + write_mock = MagicMock() + write_mock.__enter__ = MagicMock( + side_effect=PermissionError("Permission denied") + ) + + mock_file.side_effect = [read_mock, write_mock] + + apply_transformer_engine_patch() + + captured = capsys.readouterr() + # Should not crash, but print error message + assert "Applying Triton fix" in captured.out + + def test_patch_handles_file_lookup_error(self, capsys): + """Test that errors from _get_transformer_engine_file are handled.""" + with patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + side_effect=RuntimeError("Transformer Engine package not found"), + ): + # Should not raise, just print error + apply_transformer_engine_patch() + + captured = capsys.readouterr() + assert "Error checking/patching transformer_engine" in captured.out + + def test_module_reload_when_already_imported(self): + """Test that the module is reloaded if already imported.""" + module_name = "transformer_engine.pytorch.triton.permutation" + + # Create a fake module to put in sys.modules + fake_module = MagicMock() + + with ( + patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + 
return_value="/fake/path/permutation.py", + ), + patch("builtins.open", mock_open(read_data=self.ALREADY_PATCHED_CONTENT)), + patch.dict(sys.modules, {module_name: fake_module}), + patch("importlib.reload") as mock_reload, + ): + apply_transformer_engine_patch() + + mock_reload.assert_called_once_with(fake_module) + + def test_no_reload_when_module_not_imported(self): + """Test that no reload happens if module isn't imported.""" + module_name = "transformer_engine.pytorch.triton.permutation" + + # Ensure module is NOT in sys.modules + modules_without_te = {k: v for k, v in sys.modules.items() if k != module_name} + + with ( + patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value="/fake/path/permutation.py", + ), + patch("builtins.open", mock_open(read_data=self.ALREADY_PATCHED_CONTENT)), + patch.dict(sys.modules, modules_without_te, clear=True), + patch("importlib.reload") as mock_reload, + ): + apply_transformer_engine_patch() + + mock_reload.assert_not_called() + + def test_patch_does_nothing_when_old_usage_not_found(self, capsys): + """Test that patch does nothing when old_usage pattern is not in file.""" + content_without_old_usage = """ +import triton +from triton import language as core + +@triton.jit +def some_kernel(x): + # Different usage pattern + idtype = some_other_function() + return x +""" + with ( + patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value="/fake/path/permutation.py", + ), + patch( + "builtins.open", mock_open(read_data=content_without_old_usage) + ) as mock_file, + ): + apply_transformer_engine_patch() + + # Verify file was only opened for reading (not writing) + mock_file.assert_called_once_with("/fake/path/permutation.py", "r") + captured = capsys.readouterr() + # Should print applying message but not success message since pattern not found + assert "Applying Triton fix" in captured.out + assert "Successfully patched" not in captured.out + + def 
test_patch_does_nothing_when_jit_anchor_not_found(self, capsys): + """Test that patch does nothing when @triton.jit anchor is not found.""" + content_without_jit = """ +import triton +from triton import language as core + +def some_kernel(x): + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + return x +""" + with ( + patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value="/fake/path/permutation.py", + ), + patch( + "builtins.open", mock_open(read_data=content_without_jit) + ) as mock_file, + ): + apply_transformer_engine_patch() + + # Verify file was only opened for reading (not writing) + mock_file.assert_called_once_with("/fake/path/permutation.py", "r") + captured = capsys.readouterr() + assert "Applying Triton fix" in captured.out + assert "Successfully patched" not in captured.out + + +class TestPatchIntegration: + """Integration-style tests for the patch module.""" + + def test_patch_with_real_temp_file(self, capsys): + """Test patching with a real temporary file to verify file operations.""" + unpatched_content = """import triton +from triton import language as core + +@triton.jit +def permutation_kernel(x): + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + return x +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False + ) as tmp_file: + tmp_file.write(unpatched_content) + tmp_path = tmp_file.name + + try: + with patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value=tmp_path, + ): + apply_transformer_engine_patch() + + # Read the patched file + with open(tmp_path, "r") as f: + patched_content = f.read() + + # Verify the patch was applied + assert ( + "get_int_dtype = triton.constexpr_function(get_int_dtype)" + in patched_content + ) + assert ( + "idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)" + in patched_content + ) + # Verify old pattern is gone + assert ( + 
"idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)" + not in patched_content + ) + + captured = capsys.readouterr() + assert "Successfully patched" in captured.out + finally: + os.unlink(tmp_path) + + def test_patch_idempotent(self, capsys): + """Test that applying patch twice doesn't change already patched content.""" + unpatched_content = """import triton +from triton import language as core + +@triton.jit +def permutation_kernel(x): + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + return x +""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False + ) as tmp_file: + tmp_file.write(unpatched_content) + tmp_path = tmp_file.name + + try: + with patch( + "nemo_rl.models.policy.workers.patches._get_transformer_engine_file", + return_value=tmp_path, + ): + # Apply patch first time + apply_transformer_engine_patch() + + with open(tmp_path, "r") as f: + first_patched = f.read() + + # Apply patch second time + apply_transformer_engine_patch() + + with open(tmp_path, "r") as f: + second_patched = f.read() + + # Content should be identical + assert first_patched == second_patched + + captured = capsys.readouterr() + # First application should succeed, second should skip + assert captured.out.count("Successfully patched") == 1 + finally: + os.unlink(tmp_path) diff --git a/tests/unit/utils/test_automodel_checkpoint.py b/tests/unit/utils/test_automodel_checkpoint.py index 64f859c494..9d4c3afe86 100644 --- a/tests/unit/utils/test_automodel_checkpoint.py +++ b/tests/unit/utils/test_automodel_checkpoint.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Unit tests for automodel checkpoint utilities.""" import os from tempfile import TemporaryDirectory @@ -20,7 +21,6 @@ import torch # Skip entire module if nemo_automodel is not available -pytest_plugins = [] try: import nemo_automodel # noqa: F401 except ImportError: @@ -32,9 +32,9 @@ ) from nemo_rl.utils.automodel_checkpoint import ( + AutomodelCheckpointManager, + _infer_checkpoint_root, detect_checkpoint_format, - load_checkpoint, - save_checkpoint, ) @@ -89,14 +89,66 @@ def mock_lora_config(): ) +def _cleanup_dcp_planner_cache(): + """Clean up DCP SavePlanner class-level caches. + + The SavePlanner class maintains class-level caches for plan caching optimization. + When tests run with different checkpoint formats (safetensors vs torch_save), + these caches contain stale data that causes errors. This function clears all + planner caches to ensure test isolation. + """ + try: + from torch.distributed.checkpoint.planner import SavePlanner + + # Clear all class-level planner caches + if hasattr(SavePlanner, "_cached_save_plan"): + SavePlanner._cached_save_plan.clear() + if hasattr(SavePlanner, "_cached_all_plans"): + SavePlanner._cached_all_plans.clear() + if hasattr(SavePlanner, "_cached_global_plan"): + SavePlanner._cached_global_plan.clear() + if hasattr(SavePlanner, "_cached_metadata"): + SavePlanner._cached_metadata.clear() + if hasattr(SavePlanner, "_cached_final_save_plan"): + SavePlanner._cached_final_save_plan.clear() + except Exception: + pass + + +def _cleanup_device_mesh_cache(): + """Clean up device mesh cache.""" + import gc + + # Clear device mesh cache + if hasattr(torch.distributed, "device_mesh") and hasattr( + torch.distributed.device_mesh, "_mesh_resources" + ): + try: + torch.distributed.device_mesh._mesh_resources.mesh_stack.clear() + torch.distributed.device_mesh._mesh_resources.child_to_root_mapping.clear() + except Exception: + pass + + gc.collect() + + @pytest.fixture def init_distributed(): - """Initialize a single-process distributed 
environment for testing.""" + """Initialize a single-process distributed environment for testing. + + Each test gets proper cleanup of DCP planner caches to ensure test isolation. + The planner caches are class-level and shared across all DefaultSavePlanner + instances, so they must be cleared between tests that use different + checkpoint formats. + """ + # Clean up any stale planner caches from previous tests + _cleanup_dcp_planner_cache() + _cleanup_device_mesh_cache() # Only initialize if not already initialized if not torch.distributed.is_initialized(): os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" # Free port + os.environ["MASTER_PORT"] = "29500" os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" @@ -105,251 +157,635 @@ def init_distributed(): yield - # Cleanup - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() + # Cleanup planner caches after test (critical for test isolation) + _cleanup_dcp_planner_cache() + _cleanup_device_mesh_cache() + + +@pytest.fixture +def mock_experiment(): + """Create a real model, optimizer, and scheduler for integration testing.""" + model = TestModel() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + return model, optimizer, scheduler + + +def check_dict_equality(dict1, dict2): + """Recursively check equality of two dictionaries""" + for k in dict1.keys(): + if isinstance(dict1[k], dict): + check_dict_equality(dict1[k], dict2[k]) + elif isinstance(dict1[k], torch.Tensor): + assert torch.allclose(dict1[k], dict2[k]) + else: + assert dict1[k] == dict2[k] @pytest.mark.automodel class TestDetectCheckpointFormat: - """Test the detect_checkpoint_format function.""" + """Tests for detect_checkpoint_format function.""" - def test_directory_with_safetensors(self): - """Test detection for directories containing safetensors files.""" - with TemporaryDirectory() as tmp_dir: - # 
Create directory with safetensors files - os.makedirs(os.path.join(tmp_dir, "weights", "model")) - weights_path = os.path.join(tmp_dir, "weights", "model") - - # Create safetensors shard files - with open( - os.path.join( - weights_path, "shard-00001-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - with open( - os.path.join( - weights_path, "shard-00002-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - - format_type, is_peft = detect_checkpoint_format(weights_path) - assert format_type == "safetensors" - assert is_peft == False - - def test_directory_with_dcp_format(self): - """Test detection for directories with DCP (Distributed Checkpoint) format.""" - with TemporaryDirectory() as tmp_dir: - # Create directory structure like: step_3/policy/optimizer/optim - optim_path = os.path.join(tmp_dir, "step_3", "policy", "optimizer", "optim") - os.makedirs(optim_path) - - # Create DCP files (.distcp + .metadata) - with open(os.path.join(optim_path, "__0_0.distcp"), "w") as f: - f.write("dummy dcp content") - with open(os.path.join(optim_path, "__1_0.distcp"), "w") as f: - f.write("dummy dcp content") - with open(os.path.join(optim_path, ".metadata"), "w") as f: - f.write("dummy metadata") - - format_type, is_peft = detect_checkpoint_format(optim_path) - assert format_type == "torch_save" # DCP uses torch_save format - assert is_peft == False - - def test_directory_with_torch_files(self): - """Test detection for directories containing torch save files.""" - with TemporaryDirectory() as tmp_dir: - model_path = os.path.join(tmp_dir, "model") - os.makedirs(model_path) + def test_detect_safetensors_format(self, tmp_path): + """Test detection of safetensors format.""" + # Create a checkpoint directory with safetensors files + model_dir = tmp_path / "weights" / "model" + model_dir.mkdir(parents=True) + (model_dir / "model.safetensors").touch() - # Create torch save files - with open(os.path.join(model_path, 
"pytorch_model.bin"), "w") as f: - f.write("dummy content") + model_save_format, is_peft = detect_checkpoint_format(str(tmp_path / "weights")) - format_type, is_peft = detect_checkpoint_format(model_path) - assert format_type == "torch_save" - assert is_peft == False + assert model_save_format == "safetensors" + assert is_peft is False - def test_peft_detection_in_filenames(self): - """Test PEFT detection from filenames within directories.""" - with TemporaryDirectory() as tmp_dir: - model_path = os.path.join(tmp_dir, "regular_model") - os.makedirs(model_path) - - # Create file with adapter pattern in name - with open(os.path.join(model_path, "adapter_model.safetensors"), "w") as f: - f.write("dummy content") - - format_type, is_peft = detect_checkpoint_format(model_path) - assert format_type == "safetensors" - assert is_peft == True # Should detect adapter in filename - - def test_default_fallback(self): - """Test default behavior for non-existent directories.""" - # Non-existent directory should default to safetensors, no PEFT - format_type, is_peft = detect_checkpoint_format("/non/existent/directory") - assert format_type == "safetensors" - assert is_peft == False - - def test_expected_structure(self): - """Test with the expected folder structure from the user.""" - with TemporaryDirectory() as tmp_dir: - # Create the expected structure: step_3/policy/weights/model - weights_path = os.path.join(tmp_dir, "step_3", "policy", "weights", "model") - os.makedirs(weights_path) - - # Create safetensors shard files as in the example - with open( - os.path.join( - weights_path, "shard-00001-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - with open( - os.path.join( - weights_path, "shard-00002-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - - format_type, is_peft = detect_checkpoint_format(weights_path) - assert format_type == "safetensors" - assert is_peft == False - - """Test the save_checkpoint 
function.""" - - @pytest.mark.automodel - @patch("nemo_rl.utils.automodel_checkpoint.save_model") - @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") - def test_save_model_only(self, mock_save_optimizer, mock_save_model, mock_model): + def test_detect_torch_save_format_distcp(self, tmp_path): + """Test detection of torch_save format with .distcp files.""" + # Create a checkpoint directory with .distcp files + model_dir = tmp_path / "weights" / "model" + model_dir.mkdir(parents=True) + (model_dir / "checkpoint.distcp").touch() + + model_save_format, is_peft = detect_checkpoint_format(str(tmp_path / "weights")) + + assert model_save_format == "torch_save" + assert is_peft is False + + def test_detect_torch_save_format_bin(self, tmp_path): + """Test detection of torch_save format with .bin files.""" + # Create a checkpoint directory with .bin files + model_dir = tmp_path / "weights" / "model" + model_dir.mkdir(parents=True) + (model_dir / "pytorch_model.bin").touch() + + model_save_format, is_peft = detect_checkpoint_format(str(tmp_path / "weights")) + + assert model_save_format == "torch_save" + assert is_peft is False + + def test_detect_torch_save_format_pt(self, tmp_path): + """Test detection of torch_save format with .pt files.""" + # Create a checkpoint directory with .pt files + model_dir = tmp_path / "weights" / "model" + model_dir.mkdir(parents=True) + (model_dir / "model.pt").touch() + + model_save_format, is_peft = detect_checkpoint_format(str(tmp_path / "weights")) + + assert model_save_format == "torch_save" + assert is_peft is False + + def test_detect_peft_adapter(self, tmp_path): + """Test detection of PEFT adapter files.""" + # Create a checkpoint directory with adapter files + model_dir = tmp_path / "weights" / "model" + model_dir.mkdir(parents=True) + (model_dir / "adapter_model.safetensors").touch() + + model_save_format, is_peft = detect_checkpoint_format(str(tmp_path / "weights")) + + assert model_save_format == "safetensors" + assert 
is_peft is True + + def test_detect_peft_adapter_config(self, tmp_path): + """Test detection of PEFT adapter config files.""" + # Create a checkpoint directory with adapter config + model_dir = tmp_path / "weights" / "model" + model_dir.mkdir(parents=True) + (model_dir / "adapter_config.json").touch() + (model_dir / "model.safetensors").touch() + + model_save_format, is_peft = detect_checkpoint_format(str(tmp_path / "weights")) + + assert model_save_format == "safetensors" + assert is_peft is True + + def test_detect_empty_directory(self, tmp_path): + """Test detection with empty directory.""" + # Create an empty checkpoint directory + model_dir = tmp_path / "weights" + model_dir.mkdir(parents=True) + + model_save_format, is_peft = detect_checkpoint_format(str(model_dir)) + + # Default to safetensors when no files found + assert model_save_format == "safetensors" + assert is_peft is False + + def test_detect_nonexistent_directory(self, tmp_path): + """Test detection with non-existent directory.""" + nonexistent_path = str(tmp_path / "nonexistent") + + model_save_format, is_peft = detect_checkpoint_format(nonexistent_path) + + # Default to safetensors when directory doesn't exist + assert model_save_format == "safetensors" + assert is_peft is False + + def test_detect_nested_safetensors(self, tmp_path): + """Test detection of safetensors in nested directories.""" + # Create nested structure + nested_dir = tmp_path / "weights" / "model" / "rank_0" + nested_dir.mkdir(parents=True) + (nested_dir / "model-00001-of-00002.safetensors").touch() + + model_save_format, is_peft = detect_checkpoint_format(str(tmp_path / "weights")) + + assert model_save_format == "safetensors" + assert is_peft is False + + +@pytest.mark.automodel +class TestInferCheckpointRoot: + """Tests for _infer_checkpoint_root function.""" + + def test_infer_root_from_weights_model_path(self): + """Test inferring root from path ending with weights/model.""" + weights_path = 
"/path/to/checkpoint/weights/model" + + result = _infer_checkpoint_root(weights_path) + + assert result == "/path/to/checkpoint" + + def test_infer_root_from_weights_path(self): + """Test inferring root from path ending with weights.""" + weights_path = "/path/to/checkpoint/weights" + + result = _infer_checkpoint_root(weights_path) + + assert result == "/path/to/checkpoint" + + def test_infer_root_from_other_path(self): + """Test inferring root from path not ending with weights.""" + weights_path = "/path/to/checkpoint/custom_dir" + + result = _infer_checkpoint_root(weights_path) + + # Should return parent directory + assert result == "/path/to/checkpoint" + + def test_infer_root_with_trailing_slash(self): + """Test inferring root with trailing slash in path.""" + weights_path = "/path/to/checkpoint/weights/" + + result = _infer_checkpoint_root(weights_path) + + # dirname of "/path/to/checkpoint/weights/" is "/path/to/checkpoint/weights" + # which ends with "weights", so parent is returned + assert result == "/path/to/checkpoint" + + def test_infer_root_relative_path(self): + """Test inferring root from relative path.""" + weights_path = "checkpoint/weights/model" + + result = _infer_checkpoint_root(weights_path) + + assert result == "checkpoint" + + def test_infer_root_single_level(self): + """Test inferring root from single level path.""" + weights_path = "weights/model" + + result = _infer_checkpoint_root(weights_path) + + assert result == "" + + +@pytest.mark.automodel +class TestAutomodelCheckpointManager: + """Tests for AutomodelCheckpointManager class. + + Note: Full integration tests require distributed environment setup. + These tests focus on the helper methods and configuration. 
+ """ + + @pytest.fixture + def mock_meshes(self): + """Create mock device meshes for testing.""" + mock_dp_mesh = MagicMock() + mock_dp_mesh.get_group.return_value = MagicMock() + + mock_tp_mesh = MagicMock() + mock_tp_mesh.get_group.return_value = MagicMock() + + return mock_dp_mesh, mock_tp_mesh + + @patch("torch.distributed.get_rank") + def test_manager_initialization(self, mock_get_rank, mock_meshes): + """Test AutomodelCheckpointManager initialization.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + model_state_dict_keys=["layer1.weight", "layer2.weight"], + ) + + assert manager.checkpointer is None + assert manager.checkpoint_config is None + assert manager.model_state_dict_keys == ["layer1.weight", "layer2.weight"] + assert manager.dp_mesh is mock_dp_mesh + assert manager.tp_mesh is mock_tp_mesh + + @patch("torch.distributed.get_rank") + def test_set_model_state_dict_keys(self, mock_get_rank, mock_meshes): + """Test setting model state dict keys.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + new_keys = ["new_layer.weight", "new_layer.bias"] + manager.set_model_state_dict_keys(new_keys) + + assert manager.model_state_dict_keys == new_keys + + @patch("torch.distributed.get_rank") + def test_save_checkpoint_without_checkpointer_raises( + self, mock_get_rank, mock_meshes + ): + """Test that save_checkpoint raises error without initialized checkpointer.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + 
dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + mock_model = MagicMock() + + with pytest.raises(AssertionError, match="Checkpointer must be initialized"): + manager.save_checkpoint( + model=mock_model, + weights_path="/path/to/weights", + checkpointing_cfg={"enabled": True}, + ) + + @patch("torch.distributed.get_rank") + def test_save_checkpoint_without_config_raises(self, mock_get_rank, mock_meshes): + """Test that save_checkpoint raises error without checkpointing config.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + # Set up a mock checkpointer so we get past the first assertion + manager.checkpointer = MagicMock() + + mock_model = MagicMock() + + with pytest.raises(ValueError, match="checkpointing_cfg must be provided"): + manager.save_checkpoint( + model=mock_model, + weights_path="/path/to/weights", + checkpointing_cfg=None, + ) + + @patch("torch.distributed.get_rank") + def test_load_checkpoint_without_checkpointer_raises( + self, mock_get_rank, mock_meshes + ): + """Test that load_checkpoint raises error without initialized checkpointer.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + mock_model = MagicMock() + + with pytest.raises(AssertionError, match="Checkpointer must be initialized"): + manager.load_checkpoint( + model=mock_model, + weights_path="/path/to/weights", + ) + + @patch("torch.distributed.get_rank") + def test_load_base_model_without_checkpointer_raises( + self, mock_get_rank, mock_meshes + ): + """Test that load_base_model raises error without initialized checkpointer.""" + from nemo_rl.utils.automodel_checkpoint import 
AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + mock_model = MagicMock() + + with pytest.raises(AssertionError, match="Checkpointer must be initialized"): + manager.load_base_model( + model=mock_model, + model_name="test-model", + ) + + @patch("torch.distributed.get_rank") + def test_init_checkpointer_creates_checkpointer(self, mock_get_rank, mock_meshes): + """Test that init_checkpointer creates a new checkpointer.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + assert manager.checkpointer is None + + # Mock the Checkpointer class + with patch( + "nemo_rl.utils.automodel_checkpoint.Checkpointer" + ) as mock_checkpointer_cls: + mock_checkpointer = MagicMock() + mock_checkpointer_cls.return_value = mock_checkpointer + + manager.init_checkpointer( + config_updates={"model_repo_id": "test-model"}, + checkpoint_root="/path/to/checkpoints", + ) + + assert manager.checkpointer is mock_checkpointer + mock_checkpointer_cls.assert_called_once() + + @patch("torch.distributed.get_rank") + def test_init_checkpointer_does_nothing_if_exists(self, mock_get_rank, mock_meshes): + """Test that init_checkpointer does nothing if checkpointer already exists.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + # Create a mock existing checkpointer + existing_checkpointer = MagicMock() + manager.checkpointer = existing_checkpointer + + # Try to init again + with patch( + "nemo_rl.utils.automodel_checkpoint.Checkpointer" + ) as 
mock_checkpointer_cls: + manager.init_checkpointer( + config_updates={"model_repo_id": "test-model"}, + ) + + # Should not have created a new checkpointer + mock_checkpointer_cls.assert_not_called() + assert manager.checkpointer is existing_checkpointer + + @patch("torch.distributed.get_rank") + def test_update_checkpointer_config_updates_config( + self, mock_get_rank, mock_meshes + ): + """Test that update_checkpointer_config updates the config.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + # Create a mock checkpointer with config + mock_config = MagicMock() + mock_checkpointer = MagicMock() + mock_checkpointer.config = mock_config + manager.checkpointer = mock_checkpointer + + # Update the config + manager.update_checkpointer_config( + config_updates={"is_peft": True}, + checkpoint_root="/new/path", + ) + + # Verify config was updated + assert mock_config.checkpoint_dir == "/new/path" + assert mock_config.is_peft is True + + @patch("torch.distributed.get_rank") + def test_update_checkpointer_config_does_nothing_if_no_checkpointer( + self, mock_get_rank, mock_meshes + ): + """Test that update_checkpointer_config does nothing without checkpointer.""" + from nemo_rl.utils.automodel_checkpoint import AutomodelCheckpointManager + + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + + # Should not raise any error + manager.update_checkpointer_config( + config_updates={"is_peft": True}, + checkpoint_root="/new/path", + ) + + assert manager.checkpointer is None + + +@pytest.mark.automodel +class TestSaveCheckpointFunctional: + """Functional tests for save_checkpoint method with mocked internals.""" + + @pytest.fixture + def mock_meshes(self): + 
"""Create mock device meshes for testing.""" + mock_dp_mesh = MagicMock() + mock_dp_mesh.get_group.return_value = MagicMock() + + mock_tp_mesh = MagicMock() + mock_tp_mesh.get_group.return_value = MagicMock() + + return mock_dp_mesh, mock_tp_mesh + + @patch("torch.distributed.get_rank") + @patch("nemo_rl.utils.automodel_checkpoint.Checkpointer") + def test_save_model_only( + self, mock_checkpointer_cls, mock_get_rank, mock_meshes, mock_model + ): """Test saving model weights only.""" + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + mock_checkpointer = MagicMock() + mock_checkpointer_cls.return_value = mock_checkpointer + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + manager.init_checkpointer() + with TemporaryDirectory() as tmp_dir: weights_path = os.path.join(tmp_dir, "weights") - os.makedirs(os.path.dirname(weights_path), exist_ok=True) # Save checkpoint - save_checkpoint( + manager.save_checkpoint( model=mock_model, weights_path=weights_path, - model_save_format="safetensors", - is_peft=False, + checkpointing_cfg={ + "enabled": True, + "model_save_format": "safetensors", + "is_peft": False, + }, ) - # Verify save_model was called correctly - mock_save_model.assert_called_once() - call_args = mock_save_model.call_args - assert call_args[1]["model"] is mock_model - assert call_args[1]["weights_path"] == weights_path - assert ( - call_args[1]["checkpoint_config"].model_save_format.value - == "safetensors" - ) - assert call_args[1]["checkpoint_config"].is_peft == False + # Verify save_model was called + mock_checkpointer.save_model.assert_called_once() - # Verify optimizer saving was not called - mock_save_optimizer.assert_not_called() + # Verify save_optimizer was not called + mock_checkpointer.save_optimizer.assert_not_called() - @pytest.mark.automodel - @patch("nemo_rl.utils.automodel_checkpoint.save_model") - @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") + 
@patch("torch.distributed.get_rank") + @patch("nemo_rl.utils.automodel_checkpoint.Checkpointer") def test_save_with_optimizer( - self, mock_save_optimizer, mock_save_model, mock_model, mock_optimizer + self, + mock_checkpointer_cls, + mock_get_rank, + mock_meshes, + mock_model, + mock_optimizer, ): """Test saving model and optimizer weights.""" + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + mock_checkpointer = MagicMock() + mock_checkpointer_cls.return_value = mock_checkpointer + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + manager.init_checkpointer() + with TemporaryDirectory() as tmp_dir: weights_path = os.path.join(tmp_dir, "model", "weights") optimizer_path = os.path.join(tmp_dir, "optimizer", "optim") - os.makedirs(os.path.dirname(weights_path)) - os.makedirs(os.path.dirname(optimizer_path)) # Save checkpoint with optimizer - save_checkpoint( + manager.save_checkpoint( model=mock_model, weights_path=weights_path, optimizer=mock_optimizer, optimizer_path=optimizer_path, - model_save_format="torch_save", - is_peft=True, + checkpointing_cfg={ + "enabled": True, + "model_save_format": "torch_save", + "is_peft": True, + }, ) # Verify both model and optimizer saving were called - mock_save_model.assert_called_once() - mock_save_optimizer.assert_called_once() - - # Check optimizer call args - opt_call_args = mock_save_optimizer.call_args - assert opt_call_args[1]["optimizer"] is mock_optimizer - assert opt_call_args[1]["model"] is mock_model - assert opt_call_args[1]["weights_path"] == optimizer_path - - @pytest.mark.automodel - @patch("nemo_rl.utils.automodel_checkpoint.save_model") - def test_save_with_tokenizer(self, mock_save_model, mock_model): + mock_checkpointer.save_model.assert_called_once() + mock_checkpointer.save_optimizer.assert_called_once() + + @patch("torch.distributed.get_rank") + @patch("nemo_rl.utils.automodel_checkpoint.Checkpointer") + def test_save_with_tokenizer( 
+ self, mock_checkpointer_cls, mock_get_rank, mock_meshes, mock_model + ): """Test saving with tokenizer.""" + mock_get_rank.return_value = 0 + mock_dp_mesh, mock_tp_mesh = mock_meshes + + mock_checkpointer = MagicMock() + mock_checkpointer_cls.return_value = mock_checkpointer + + manager = AutomodelCheckpointManager( + dp_mesh=mock_dp_mesh, + tp_mesh=mock_tp_mesh, + ) + manager.init_checkpointer() + with TemporaryDirectory() as tmp_dir: weights_path = os.path.join(tmp_dir, "model", "weights") tokenizer_path = os.path.join(tmp_dir, "tokenizer") - os.makedirs(os.path.dirname(weights_path)) os.makedirs(tokenizer_path) # Create mock tokenizer mock_tokenizer = MagicMock() # Save checkpoint with tokenizer - save_checkpoint( + manager.save_checkpoint( model=mock_model, weights_path=weights_path, tokenizer=mock_tokenizer, tokenizer_path=tokenizer_path, + checkpointing_cfg={"enabled": True}, ) # Verify tokenizer.save_pretrained was called mock_tokenizer.save_pretrained.assert_called_once_with(tokenizer_path) -@pytest.fixture -def mock_experiment(): - """Create a real model, optimizer, and scheduler for integration testing.""" - model = TestModel() - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) - return model, optimizer, scheduler - - -def check_dict_equality(dict1, dict2): - """Recursively check equality of two dictionaries""" - for k in dict1.keys(): - if isinstance(dict1[k], dict): - check_dict_equality(dict1[k], dict2[k]) - elif isinstance(dict1[k], torch.Tensor): - assert torch.allclose(dict1[k], dict2[k]) - else: - assert dict1[k] == dict2[k] - - @pytest.mark.automodel class TestSaveLoadIntegration: """Integration tests that actually save and load checkpoints.""" - def test_save_and_load_model_only_safetensors(self, mock_experiment): + def test_save_and_load_model_only_safetensors( + self, init_distributed, mock_experiment + ): """Test saving and loading model weights only 
with safetensors format.""" test_model, _, _ = mock_experiment - original_state_dict = test_model.state_dict() + original_state_dict = {k: v.clone() for k, v in test_model.state_dict().items()} with TemporaryDirectory() as tmp_dir: weights_path = os.path.join(tmp_dir, "test_model") + # Create device mesh and checkpoint manager + mesh = torch.distributed.device_mesh.init_device_mesh( + "cpu", (1,), mesh_dim_names=("dp",) + ) + checkpoint_manager = AutomodelCheckpointManager( + dp_mesh=mesh, + tp_mesh=mesh, + ) + checkpoint_manager.init_checkpointer( + config_updates={"model_save_format": "safetensors"} + ) + # Save checkpoint - save_checkpoint( + checkpoint_manager.save_checkpoint( model=test_model, weights_path=weights_path, - model_save_format="safetensors", + checkpointing_cfg={ + "enabled": True, + "model_save_format": "safetensors", + }, ) # Verify files are created @@ -364,24 +800,43 @@ def test_save_and_load_model_only_safetensors(self, mock_experiment): param.data.fill_(999.0) # Load the checkpoint - load_checkpoint(model=new_model, weights_path=weights_path) + checkpoint_manager.load_checkpoint( + model=new_model, weights_path=weights_path + ) # Verify the weights match the original check_dict_equality(new_model.state_dict(), original_state_dict) - def test_save_and_load_model_only_torch_save(self, mock_experiment): + def test_save_and_load_model_only_torch_save( + self, init_distributed, mock_experiment + ): """Test saving and loading model weights only with torch_save format.""" test_model, _, _ = mock_experiment - original_state_dict = test_model.state_dict() + original_state_dict = {k: v.clone() for k, v in test_model.state_dict().items()} with TemporaryDirectory() as tmp_dir: weights_path = os.path.join(tmp_dir, "test_model") + # Create device mesh and checkpoint manager + mesh = torch.distributed.device_mesh.init_device_mesh( + "cpu", (1,), mesh_dim_names=("dp",) + ) + checkpoint_manager = AutomodelCheckpointManager( + dp_mesh=mesh, + tp_mesh=mesh, + 
) + checkpoint_manager.init_checkpointer( + config_updates={"model_save_format": "torch_save"} + ) + # Save checkpoint - save_checkpoint( + checkpoint_manager.save_checkpoint( model=test_model, weights_path=weights_path, - model_save_format="torch_save", + checkpointing_cfg={ + "enabled": True, + "model_save_format": "torch_save", + }, ) # Verify files are created @@ -396,12 +851,14 @@ def test_save_and_load_model_only_torch_save(self, mock_experiment): param.data.fill_(999.0) # Load the checkpoint - load_checkpoint(model=new_model, weights_path=weights_path) + checkpoint_manager.load_checkpoint( + model=new_model, weights_path=weights_path + ) # Verify the weights match the original check_dict_equality(new_model.state_dict(), original_state_dict) - def test_save_and_load_model_and_optimizer(self, mock_experiment): + def test_save_and_load_model_and_optimizer(self, init_distributed, mock_experiment): """Test saving and loading both model and optimizer.""" test_model, optimizer, scheduler = mock_experiment @@ -415,7 +872,9 @@ def test_save_and_load_model_and_optimizer(self, mock_experiment): optimizer.step() scheduler.step() - original_model_state = test_model.state_dict() + original_model_state = { + k: v.clone() for k, v in test_model.state_dict().items() + } original_optimizer_state = optimizer.state_dict() original_scheduler_state = scheduler.state_dict() @@ -425,13 +884,26 @@ def test_save_and_load_model_and_optimizer(self, mock_experiment): os.makedirs(os.path.dirname(model_path), exist_ok=True) os.makedirs(os.path.dirname(optimizer_path), exist_ok=True) + # Create device mesh and checkpoint manager + mesh = torch.distributed.device_mesh.init_device_mesh( + "cpu", (1,), mesh_dim_names=("dp",) + ) + checkpoint_manager = AutomodelCheckpointManager( + dp_mesh=mesh, + tp_mesh=mesh, + ) + checkpoint_manager.init_checkpointer( + config_updates={"model_save_format": "safetensors"} + ) + # Save checkpoint - save_checkpoint( + checkpoint_manager.save_checkpoint( 
model=test_model, weights_path=model_path, optimizer=optimizer, scheduler=scheduler, optimizer_path=optimizer_path, + checkpointing_cfg={"enabled": True}, ) # Verify files are created @@ -450,7 +922,7 @@ def test_save_and_load_model_and_optimizer(self, mock_experiment): param.data.fill_(999.0) # Load the checkpoint - load_checkpoint( + checkpoint_manager.load_checkpoint( model=new_model, weights_path=model_path, optimizer=new_optimizer, @@ -464,7 +936,7 @@ def test_save_and_load_model_and_optimizer(self, mock_experiment): assert new_scheduler.state_dict() == original_scheduler_state def test_save_and_load_model_with_lora( - self, mock_experiment, mock_lora_config, init_distributed + self, init_distributed, mock_experiment, mock_lora_config ): """Test saving and loading both model and optimizer with LORA.""" test_model, _, _ = mock_experiment @@ -508,11 +980,28 @@ def test_save_and_load_model_with_lora( with TemporaryDirectory() as tmp_dir: weights_path = os.path.join(tmp_dir, "test_model") - save_checkpoint( + + # Create device mesh and checkpoint manager with PEFT enabled + mesh = torch.distributed.device_mesh.init_device_mesh( + "cpu", (1,), mesh_dim_names=("dp",) + ) + checkpoint_manager = AutomodelCheckpointManager( + dp_mesh=mesh, + tp_mesh=mesh, + ) + checkpoint_manager.init_checkpointer( + config_updates={"model_save_format": "safetensors", "is_peft": True} + ) + + checkpoint_manager.save_checkpoint( model=test_model, weights_path=weights_path, - model_save_format="safetensors", - is_peft=True, + checkpointing_cfg={ + "enabled": True, + "model_save_format": "safetensors", + "is_peft": True, + }, + lora_enabled=True, peft_config=lora_config, ) @@ -528,8 +1017,10 @@ def test_save_and_load_model_with_lora( for param in new_model.parameters(): param.data.fill_(999.0) - # Load the checkpoint for peft need distributed(refer to nemo_automodel/components/checkpoint/stateful_wrappers.py:load_state_dict) - load_checkpoint(model=new_model, weights_path=weights_path) + 
# Load the checkpoint + checkpoint_manager.load_checkpoint( + model=new_model, weights_path=weights_path + ) # peft only save lora weights, so we need to filter out the non-lora weights lora_params_original = { k: v for k, v in lora_state_dict.items() if "lora" in k diff --git a/uv.lock b/uv.lock index 5675ee364d..58c9ce5cc7 100644 --- a/uv.lock +++ b/uv.lock @@ -35,7 +35,6 @@ overrides = [ { name = "opencv-python-headless", specifier = ">=4.11.0" }, { name = "timm", specifier = "<=1.0.22" }, { name = "transformer-engine", extras = ["pytorch"], specifier = "==2.8.0" }, - { name = "transformers", specifier = "==4.57.1" }, ] [[manifest.dependency-metadata]] @@ -45,7 +44,7 @@ requires-dist = ["torch", "packaging", "ninja"] [[manifest.dependency-metadata]] name = "deep-ep" -version = "1.1.0+e3908bf" +version = "1.2.1+bfded34" requires-dist = ["torch", "packaging", "ninja"] [[manifest.dependency-metadata]] @@ -432,6 +431,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/22/91616fe707a5c5510de2cac9b046a30defe7007ba8a0c04f9c08f27df312/audioop_lts-0.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:b492c3b040153e68b9fdaff5913305aaaba5bb433d8a7f73d5cf6a64ed3cc1dd", size = 25206, upload-time = "2025-08-05T16:43:16.444Z" }, ] +[[package]] +name = "audioread" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "standard-aifc", marker = "python_full_version >= '3.13'" }, + { name = "standard-sunau", marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/4a/874ecf9b472f998130c2b5e145dcdb9f6131e84786111489103b66772143/audioread-3.1.0.tar.gz", hash = "sha256:1c4ab2f2972764c896a8ac61ac53e261c8d29f0c6ccd652f84e18f08a4cab190", size = 20082, upload-time = "2025-10-26T19:44:13.484Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/16/fbe8e1e185a45042f7cd3a282def5bb8d95bb69ab9e9ef6a5368aa17e426/audioread-3.1.0-py3-none-any.whl", hash = 
"sha256:b30d1df6c5d3de5dcef0fb0e256f6ea17bdcf5f979408df0297d8a408e2971b4", size = 23143, upload-time = "2025-10-26T19:44:12.016Z" }, +] + [[package]] name = "av" version = "15.0.0" @@ -549,19 +561,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/40/be3858ffed004e47e48a2cefecdbf9b950d41098b780f9dc3aa609a88351/bitarray-3.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:2a3d1b05ffdd3e95687942ae7b13c63689f85d3f15c39b33329e3cb9ce6c015f", size = 147015, upload-time = "2025-11-02T21:40:35.064Z" }, ] -[[package]] -name = "bitsandbytes" -version = "0.45.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, - { name = "torch", version = "2.9.0+cu129", source = { registry = "https://download.pytorch.org/whl/cu129" }, marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/b7/cb5ce4d1a382cf53c19ef06c5fc29e85f5e129b4da6527dd207d90a5b8ad/bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a5453f30cc6aab6ccaac364e6bf51a7808d3da5f71763dffeb6d9694c59136e4", size = 76059261, upload-time = "2025-04-07T13:32:52.573Z" }, - { url = "https://files.pythonhosted.org/packages/a6/4c/77b535e025ce780d2ada8271c1e481fb7337c1df2588a52fe1c9bd87d2e8/bitsandbytes-0.45.5-py3-none-win_amd64.whl", hash = "sha256:ed1c61b91d989d6a33fd05737d6edbf5086d8ebc89235ee632c7a19144085da2", size = 75430204, upload-time = "2025-04-07T13:32:57.553Z" }, -] - [[package]] name = "bitstring" version = "4.3.1" @@ -1279,12 +1278,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/57/ecc9ae29fa5b2d90107cd1d9bf8ed19aacb74b2264d986ae9d44fe9bdf87/debugpy-1.8.16-py2.py3-none-any.whl", hash = "sha256:19c9521962475b87da6f673514f7fd610328757ec993bf7ec0d8c96f9a325f9e", size = 5287700, upload-time = "2025-08-06T18:00:42.333Z" }, 
] +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + [[package]] name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -1293,8 +1301,8 @@ wheels = [ [[package]] name = "deep-ep" -version = "1.1.0+e3908bf" -source = { git = "https://github.com/deepseek-ai/DeepEP.git?rev=e3908bf5bd0cc6265bcb225d15cd8c996d4759ef#e3908bf5bd0cc6265bcb225d15cd8c996d4759ef" } +version = "1.2.1+bfded34" +source = { git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480#bfded34800dfec415b71503f8205181de90b2480" } dependencies = [ { name = "ninja" }, { name = "packaging" }, @@ -1361,6 +1369,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/ae/afb1487556e2dc827a17097aac8158a25b433a345386f0e249f6d2694ccb/devtools-0.12.2-py3-none-any.whl", hash = 
"sha256:c366e3de1df4cdd635f1ad8cbcd3af01a384d7abda71900e68d43b04eb6aaca7", size = 19411, upload-time = "2023-09-03T16:56:59.049Z" }, ] +[[package]] +name = "diffusers" +version = "0.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "httpx" }, + { name = "huggingface-hub" }, + { name = "importlib-metadata" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "regex" }, + { name = "requests" }, + { name = "safetensors" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/45/ccb2e2180ddf475a0f931dac6a50346310e4c464ce3cccb8a65d1fc1e16d/diffusers-0.36.0.tar.gz", hash = "sha256:a9cde8721b415bde6a678f2d02abb85396487e1b0e0d2b4abb462d14a9825ab0", size = 3795088, upload-time = "2025-12-08T10:14:34.255Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/50/281f92cb1f83854dbd79b6e958b3bc5018607e2542971d41604ba7a14b2f/diffusers-0.36.0-py3-none-any.whl", hash = "sha256:525d42abc74bfc3b2db594999961295c054b48ef40a11724dacf50e6abd1af98", size = 4597884, upload-time = "2025-12-08T10:14:31.979Z" }, +] + [[package]] name = "dill" version = "0.3.8" @@ -2144,6 +2172,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "h2" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, +] + [[package]] name = "hatchling" version = "1.27.0" @@ -2174,6 +2215,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/d3/0aaf279f4f3dea58e99401b92c31c0f752924ba0e6c7d7bb07b1dbd7f35e/hf_xet-1.1.8-cp37-abi3-win_amd64.whl", hash = "sha256:4171f31d87b13da4af1ed86c98cf763292e4720c088b4957cf9d564f92904ca9", size = 2801689, upload-time = "2025-08-18T22:01:04.81Z" }, ] +[[package]] +name = "hpack" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -2224,6 +2274,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[package.optional-dependencies] +http2 = [ + { name = "h2" }, +] + [[package]] name = "huggingface-hub" version = "0.34.4" @@ -2257,6 +2312,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547, upload-time = "2023-02-23T18:33:40.801Z" }, ] +[[package]] +name = "hyperframe" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, +] + [[package]] name = "identify" version = "2.6.13" @@ -2275,6 +2339,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "imageio-ffmpeg" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/bd/c3343c721f2a1b0c9fc71c1aebf1966a3b7f08c2eea8ed5437a2865611d6/imageio_ffmpeg-0.6.0.tar.gz", hash = "sha256:e2556bed8e005564a9f925bb7afa4002d82770d6b08825078b7697ab88ba1755", size = 25210, upload-time = "2025-01-16T21:34:32.747Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/58/87ef68ac83f4c7690961bce288fd8e382bc5f1513860fc7f90a9c1c1c6bf/imageio_ffmpeg-0.6.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.whl", hash = 
"sha256:9d2baaf867088508d4a3458e61eeb30e945c4ad8016025545f66c4b5aaef0a61", size = 24932969, upload-time = "2025-01-16T21:34:20.464Z" }, + { url = "https://files.pythonhosted.org/packages/40/5c/f3d8a657d362cc93b81aab8feda487317da5b5d31c0e1fdfd5e986e55d17/imageio_ffmpeg-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b1ae3173414b5fc5f538a726c4e48ea97edc0d2cdc11f103afee655c463fa742", size = 21113891, upload-time = "2025-01-16T21:34:00.277Z" }, + { url = "https://files.pythonhosted.org/packages/33/e7/1925bfbc563c39c1d2e82501d8372734a5c725e53ac3b31b4c2d081e895b/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1d47bebd83d2c5fc770720d211855f208af8a596c82d17730aa51e815cdee6dc", size = 25632706, upload-time = "2025-01-16T21:33:53.475Z" }, + { url = "https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c7e46fcec401dd990405049d2e2f475e2b397779df2519b544b8aab515195282", size = 29498237, upload-time = "2025-01-16T21:34:13.726Z" }, + { url = "https://files.pythonhosted.org/packages/a0/13/59da54728351883c3c1d9fca1710ab8eee82c7beba585df8f25ca925f08f/imageio_ffmpeg-0.6.0-py3-none-win32.whl", hash = "sha256:196faa79366b4a82f95c0f4053191d2013f4714a715780f0ad2a68ff37483cc2", size = 19652251, upload-time = "2025-01-16T21:34:06.812Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c6/fa760e12a2483469e2bf5058c5faff664acf66cadb4df2ad6205b016a73d/imageio_ffmpeg-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02fa47c83703c37df6bfe4896aab339013f62bf02c5ebf2dce6da56af04ffc0a", size = 31246824, upload-time = "2025-01-16T21:34:28.6Z" }, +] + [[package]] name = "imagesize" version = "1.4.1" @@ -2536,6 +2614,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/60/dfbbf40e3a371388c0e03ff65b01319b7d4023e883df6d7261125772ffdc/latex2sympy2_extended-1.10.2-py3-none-any.whl", hash = 
"sha256:f910442c5b02a466c1046f47d05cc5285181068b882399281f30102715337fb7", size = 207855, upload-time = "2025-07-02T15:26:04.88Z" }, ] +[[package]] +name = "lazy-loader" +version = "0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/6b/c875b30a1ba490860c93da4cabf479e03f584eba06fe5963f6f6644653d8/lazy_loader-0.4.tar.gz", hash = "sha256:47c75182589b91a4e1a85a136c074285a5ad4d9f39c63e0d7fb76391c4574cd1", size = 15431, upload-time = "2024-04-05T13:03:12.261Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/60/d497a310bde3f01cb805196ac61b7ad6dc5dcf8dce66634dc34364b20b4f/lazy_loader-0.4-py3-none-any.whl", hash = "sha256:342aa8e14d543a154047afb4ba8ef17f5563baad3fc610d7b15b213b0f119efc", size = 12097, upload-time = "2024-04-05T13:03:10.514Z" }, +] + +[[package]] +name = "librosa" +version = "0.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "audioread" }, + { name = "decorator" }, + { name = "joblib" }, + { name = "lazy-loader" }, + { name = "msgpack" }, + { name = "numba" }, + { name = "numpy" }, + { name = "pooch" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "soundfile" }, + { name = "soxr" }, + { name = "standard-aifc", marker = "python_full_version >= '3.13'" }, + { name = "standard-sunau", marker = "python_full_version >= '3.13'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/36/360b5aafa0238e29758729e9486c6ed92a6f37fa403b7875e06c115cdf4a/librosa-0.11.0.tar.gz", hash = "sha256:f5ed951ca189b375bbe2e33b2abd7e040ceeee302b9bbaeeffdfddb8d0ace908", size = 327001, upload-time = "2025-03-11T15:09:54.884Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/ba/c63c5786dfee4c3417094c4b00966e61e4a63efecee22cb7b4c0387dda83/librosa-0.11.0-py3-none-any.whl", hash = 
"sha256:0b6415c4fd68bff4c29288abe67c6d80b587e0e1e2cfb0aad23e4559504a7fa1", size = 260749, upload-time = "2025-03-11T15:09:52.982Z" }, +] + [[package]] name = "liger-kernel" version = "0.6.2" @@ -3400,10 +3516,14 @@ wheels = [ name = "nemo-automodel" source = { editable = "3rdparty/Automodel-workspace/Automodel" } dependencies = [ - { name = "bitsandbytes", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "datasets" }, + { name = "diffusers" }, + { name = "ftfy" }, + { name = "imageio-ffmpeg" }, { name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "megatron-fsdp" }, + { name = "mlflow" }, + { name = "opencv-python-headless" }, { name = "pybind11" }, { name = "pyyaml" }, { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, @@ -3415,22 +3535,39 @@ dependencies = [ ] [package.optional-dependencies] -fa = [ - { name = "flash-attn" }, +all = [ + { name = "backoff" }, + { name = "flash-linear-attention" }, + { name = "mistral-common", extra = ["opencv"] }, + { name = "numba" }, + { name = "numpy" }, + { name = "perceptron" }, + { name = "pillow" }, + { name = "qwen-omni-utils" }, + { name = "qwen-vl-utils", extra = ["decord"], marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "sentencepiece" }, + { name = "timm" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, ] -moe = [ +cuda = [ + { name = "flash-attn" }, { name = "transformer-engine", extra = ["pytorch"] }, ] +extra = [ + { name = "flash-linear-attention" }, + { name = "perceptron" }, + { name = "sentencepiece" }, +] vlm = [ { name = "backoff" }, { name = "mistral-common", extra = ["opencv"] }, { name = "numba" }, { name = "numpy" }, { name = "pillow" }, - { name = "qwen-vl-utils", extra = ["decord"] }, + { name = "qwen-omni-utils" }, + { name = "qwen-vl-utils", extra = 
["decord"], marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "timm" }, - { name = "torchcodec" }, - { name = "transformers" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, ] [package.dev-dependencies] @@ -3464,30 +3601,39 @@ test = [ [package.metadata] requires-dist = [ { name = "backoff", marker = "extra == 'vlm'" }, - { name = "bitsandbytes", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = "==0.45.5" }, { name = "datasets", specifier = ">=4.0.0" }, - { name = "flash-attn", marker = "extra == 'fa'", specifier = "<=2.8.3" }, + { name = "diffusers" }, + { name = "flash-attn", marker = "extra == 'cuda'", specifier = "<=2.8.3" }, + { name = "flash-linear-attention", marker = "extra == 'extra'" }, + { name = "ftfy" }, + { name = "imageio-ffmpeg" }, { name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = ">=0.5.9" }, { name = "megatron-fsdp" }, { name = "mistral-common", extras = ["opencv"], marker = "extra == 'vlm'" }, + { name = "mlflow" }, + { name = "nemo-automodel", extras = ["extra"], marker = "extra == 'all'", editable = "3rdparty/Automodel-workspace/Automodel" }, + { name = "nemo-automodel", extras = ["vlm"], marker = "extra == 'all'", editable = "3rdparty/Automodel-workspace/Automodel" }, { name = "numba", marker = "extra == 'vlm'" }, { name = "numpy", marker = "extra == 'vlm'" }, + { name = "opencv-python-headless", specifier = "==4.10.0.84" }, + { name = "perceptron", marker = "extra == 'extra'" }, { name = "pillow", marker = "extra == 'vlm'" }, { name = "pybind11" }, { name = "pyyaml" }, - { name = "qwen-vl-utils", extras = ["decord"], marker = "extra == 'vlm'" }, - { name = "timm", marker = "extra == 'vlm'", specifier = "==1.0.16" }, + { name = "qwen-omni-utils", marker = "extra == 'vlm'" }, + { name = "qwen-vl-utils", extras = ["decord"], marker = "platform_machine == 'x86_64' and 
sys_platform != 'darwin' and extra == 'vlm'" }, + { name = "sentencepiece", marker = "extra == 'extra'" }, + { name = "timm", marker = "extra == 'vlm'", specifier = "<=1.0.22" }, { name = "torch", marker = "sys_platform != 'darwin'", specifier = "<=2.9.0", index = "https://download.pytorch.org/whl/cu129" }, { name = "torch", marker = "sys_platform == 'darwin'", specifier = "<=2.9.0", index = "https://pypi.org/simple" }, - { name = "torchao", specifier = "==0.14.1" }, - { name = "torchcodec", marker = "extra == 'vlm'" }, + { name = "torchao" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vlm'" }, { name = "torchdata" }, - { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'moe'", specifier = "==2.8.0" }, - { name = "transformers", specifier = "<=4.57.1" }, - { name = "transformers", marker = "extra == 'vlm'", specifier = "<=4.57.1" }, + { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'cuda'", specifier = "==2.8.0" }, + { name = "transformers", specifier = "<=4.57.3" }, { name = "wandb" }, ] -provides-extras = ["vlm", "fa", "moe"] +provides-extras = ["cuda", "extra", "vlm", "all"] [package.metadata.requires-dev] build = [ @@ -3606,9 +3752,12 @@ dependencies = [ [package.optional-dependencies] automodel = [ { name = "causal-conv1d" }, + { name = "deep-ep" }, { name = "flash-attn" }, { name = "mamba-ssm" }, { name = "nemo-automodel" }, + { name = "nv-grouped-gemm" }, + { name = "transformer-engine", extra = ["pytorch"] }, { name = "vllm" }, ] mcore = [ @@ -3681,7 +3830,8 @@ requires-dist = [ { name = "cuda-python", marker = "extra == 'vllm'" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "debugpy" }, - { name = "deep-ep", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=e3908bf5bd0cc6265bcb225d15cd8c996d4759ef" }, + { name = "deep-ep", marker = "extra == 'automodel'", git = 
"https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, + { name = "deep-ep", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, { name = "deep-gemm", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepGEMM.git?rev=7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" }, { name = "flash-attn", marker = "extra == 'automodel'", specifier = "==2.8.1" }, { name = "flash-attn", marker = "extra == 'mcore'", specifier = "==2.8.1" }, @@ -3700,6 +3850,7 @@ requires-dist = [ { name = "num2words", specifier = ">=0.5.14" }, { name = "num2words", marker = "extra == 'vllm'", specifier = ">=0.5.14" }, { name = "numpy" }, + { name = "nv-grouped-gemm", marker = "extra == 'automodel'", git = "https://github.com/fanshiqing/grouped_gemm?tag=v1.1.4.post7" }, { name = "nvidia-ml-py" }, { name = "nvidia-nvshmem-cu12", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "nvtx" }, @@ -3720,6 +3871,7 @@ requires-dist = [ { name = "torchdata" }, { name = "torchvision", marker = "sys_platform != 'darwin'", specifier = ">=0.22.0", index = "https://download.pytorch.org/whl/cu129" }, { name = "torchvision", marker = "sys_platform == 'darwin'", specifier = ">=0.22.0", index = "https://pypi.org/simple" }, + { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'automodel'", specifier = "==2.8.0" }, { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'mcore'", specifier = "==2.8.0" }, { name = "transformers", specifier = "==4.57.1" }, { name = "triton", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')", index = "https://download.pytorch.org/whl/cu129" }, @@ -4532,6 +4684,24 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/49/fe/a2da1627aa9cb6310b6034598363bd26ac301c4a99d21f415b1b2855891e/peft-0.17.1-py3-none-any.whl", hash = "sha256:3d129d64def3d74779c32a080d2567e5f7b674e77d546e3585138216d903f99e", size = 504896, upload-time = "2025-08-21T09:25:18.974Z" }, ] +[[package]] +name = "perceptron" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama" }, + { name = "httpx", extra = ["http2"] }, + { name = "numpy" }, + { name = "pillow" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/60/85db2243d8b550823603d8f9c5845b0dd0f01074e9aabf0b2af0c4f52565/perceptron-0.1.4.tar.gz", hash = "sha256:62fd190efb74925e2cc33c0cd38761e19959be3bdb7b24fbf9e3386d6961f690", size = 78116, upload-time = "2025-11-12T20:00:28.024Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/17/b7cb1a10ebb0a9a4c9fbcd96a28b43d44e08a90f620bab07e644a658d2f1/perceptron-0.1.4-py3-none-any.whl", hash = "sha256:f490a6df6c15167e91e1a528601cae98ce99a30991cf792f9ef83ebc15d335c4", size = 57421, upload-time = "2025-11-12T20:00:26.395Z" }, +] + [[package]] name = "pillow" version = "11.3.0" @@ -4638,6 +4808,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "pooch" +version = "1.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "platformdirs" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/77/b3d3e00c696c16cf99af81ef7b1f5fe73bd2a307abca41bd7605429fe6e5/pooch-1.8.2.tar.gz", hash = "sha256:76561f0de68a01da4df6af38e9955c4c9d1a5c90da73f7e40276a5728ec83d10", size = 
59353, upload-time = "2024-06-06T16:53:46.224Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574, upload-time = "2024-06-06T16:53:44.343Z" }, +] + [[package]] name = "pre-commit" version = "4.3.0" @@ -5439,6 +5623,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/32/3836ed85947b06f1d67c07ce16c00b0cf8c053ab0b249d234f9f81ff95ff/pyzmq-27.0.1-cp314-cp314t-win_arm64.whl", hash = "sha256:0fc24bf45e4a454e55ef99d7f5c8b8712539200ce98533af25a5bfa954b6b390", size = 575098, upload-time = "2025-08-03T05:04:27.974Z" }, ] +[[package]] +name = "qwen-omni-utils" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "av" }, + { name = "librosa" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b8/b1/cc58b03b5eadddc0812cef884d013ed6cc66b09f9b0f5b45123f89dcd056/qwen_omni_utils-0.0.8.tar.gz", hash = "sha256:b5808673e1455f4115cb784a62cdc8e8616576221a01fc738610b0f9268cb33c", size = 8145, upload-time = "2025-06-12T11:02:05.411Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/b1/dcdd69246a3c3c3bd6f6ced58e2307b3afbd894c4412c29fd49dd897e562/qwen_omni_utils-0.0.8-py3-none-any.whl", hash = "sha256:c42bcc633fbfd84d565ff0de9d45fae68a6b57a9b7b97a4b77eda71a0d3ee73a", size = 9218, upload-time = "2025-06-12T11:02:03.981Z" }, +] + [[package]] name = "qwen-vl-utils" version = "0.0.11" @@ -5456,7 +5656,7 @@ wheels = [ [package.optional-dependencies] decord = [ - { name = "decord" }, + { name = "decord", marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, ] [[package]] @@ -6196,6 +6396,27 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/e7/9c/0e6afc12c269578be5c0c1c9f4b49a8d32770a080260c333ac04cc1c832d/soupsieve-2.7-py3-none-any.whl", hash = "sha256:6e60cc5c1ffaf1cebcc12e8188320b72071e922c2e897f737cadce79ad5d30c4", size = 36677, upload-time = "2025-04-20T18:50:07.196Z" }, ] +[[package]] +name = "soxr" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/7e/f4b461944662ad75036df65277d6130f9411002bfb79e9df7dff40a31db9/soxr-1.0.0.tar.gz", hash = "sha256:e07ee6c1d659bc6957034f4800c60cb8b98de798823e34d2a2bba1caa85a4509", size = 171415, upload-time = "2025-09-07T13:22:21.317Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/c7/f92b81f1a151c13afb114f57799b86da9330bec844ea5a0d3fe6a8732678/soxr-1.0.0-cp312-abi3-macosx_10_14_x86_64.whl", hash = "sha256:abecf4e39017f3fadb5e051637c272ae5778d838e5c3926a35db36a53e3a607f", size = 205508, upload-time = "2025-09-07T13:22:01.252Z" }, + { url = "https://files.pythonhosted.org/packages/ff/1d/c945fea9d83ea1f2be9d116b3674dbaef26ed090374a77c394b31e3b083b/soxr-1.0.0-cp312-abi3-macosx_11_0_arm64.whl", hash = "sha256:e973d487ee46aa8023ca00a139db6e09af053a37a032fe22f9ff0cc2e19c94b4", size = 163568, upload-time = "2025-09-07T13:22:03.558Z" }, + { url = "https://files.pythonhosted.org/packages/b5/80/10640970998a1d2199bef6c4d92205f36968cddaf3e4d0e9fe35ddd405bd/soxr-1.0.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e8ce273cca101aff3d8c387db5a5a41001ba76ef1837883438d3c652507a9ccc", size = 204707, upload-time = "2025-09-07T13:22:05.125Z" }, + { url = "https://files.pythonhosted.org/packages/b1/87/2726603c13c2126cb8ded9e57381b7377f4f0df6ba4408e1af5ddbfdc3dd/soxr-1.0.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8f2a69686f2856d37823bbb7b78c3d44904f311fe70ba49b893af11d6b6047b", size = 238032, upload-time = "2025-09-07T13:22:06.428Z" }, + 
{ url = "https://files.pythonhosted.org/packages/ce/04/530252227f4d0721a5524a936336485dfb429bb206a66baf8e470384f4a2/soxr-1.0.0-cp312-abi3-win_amd64.whl", hash = "sha256:2a3b77b115ae7c478eecdbd060ed4f61beda542dfb70639177ac263aceda42a2", size = 172070, upload-time = "2025-09-07T13:22:07.62Z" }, + { url = "https://files.pythonhosted.org/packages/99/77/d3b3c25b4f1b1aa4a73f669355edcaee7a52179d0c50407697200a0e55b9/soxr-1.0.0-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:392a5c70c04eb939c9c176bd6f654dec9a0eaa9ba33d8f1024ed63cf68cdba0a", size = 209509, upload-time = "2025-09-07T13:22:08.773Z" }, + { url = "https://files.pythonhosted.org/packages/8a/ee/3ca73e18781bb2aff92b809f1c17c356dfb9a1870652004bd432e79afbfa/soxr-1.0.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:fdc41a1027ba46777186f26a8fba7893be913383414135577522da2fcc684490", size = 167690, upload-time = "2025-09-07T13:22:10.259Z" }, + { url = "https://files.pythonhosted.org/packages/bd/f0/eea8b5f587a2531657dc5081d2543a5a845f271a3bea1c0fdee5cebde021/soxr-1.0.0-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:449acd1dfaf10f0ce6dfd75c7e2ef984890df94008765a6742dafb42061c1a24", size = 209541, upload-time = "2025-09-07T13:22:11.739Z" }, + { url = "https://files.pythonhosted.org/packages/64/59/2430a48c705565eb09e78346950b586f253a11bd5313426ced3ecd9b0feb/soxr-1.0.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:38b35c99e408b8f440c9376a5e1dd48014857cd977c117bdaa4304865ae0edd0", size = 243025, upload-time = "2025-09-07T13:22:12.877Z" }, + { url = "https://files.pythonhosted.org/packages/3c/1b/f84a2570a74094e921bbad5450b2a22a85d58585916e131d9b98029c3e69/soxr-1.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:a39b519acca2364aa726b24a6fd55acf29e4c8909102e0b858c23013c38328e5", size = 184850, upload-time = "2025-09-07T13:22:14.068Z" }, +] + [[package]] name = "sphinx" version = "8.2.3" @@ -6383,6 +6604,40 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/a9/5c/bfd6bd0bf979426d405cc6e71eceb8701b148b16c21d2dc3c261efc61c7b/sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca", size = 44415, upload-time = "2024-12-10T12:05:27.824Z" }, ] +[[package]] +name = "standard-aifc" +version = "3.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, + { name = "standard-chunk", marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/53/6050dc3dde1671eb3db592c13b55a8005e5040131f7509cef0215212cb84/standard_aifc-3.13.0.tar.gz", hash = "sha256:64e249c7cb4b3daf2fdba4e95721f811bde8bdfc43ad9f936589b7bb2fae2e43", size = 15240, upload-time = "2024-10-30T16:01:31.772Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/52/5fbb203394cc852334d1575cc020f6bcec768d2265355984dfd361968f36/standard_aifc-3.13.0-py3-none-any.whl", hash = "sha256:f7ae09cc57de1224a0dd8e3eb8f73830be7c3d0bc485de4c1f82b4a7f645ac66", size = 10492, upload-time = "2024-10-30T16:01:07.071Z" }, +] + +[[package]] +name = "standard-chunk" +version = "3.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/06/ce1bb165c1f111c7d23a1ad17204d67224baa69725bb6857a264db61beaf/standard_chunk-3.13.0.tar.gz", hash = "sha256:4ac345d37d7e686d2755e01836b8d98eda0d1a3ee90375e597ae43aaf064d654", size = 4672, upload-time = "2024-10-30T16:18:28.326Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/90/a5c1084d87767d787a6caba615aa50dc587229646308d9420c960cb5e4c0/standard_chunk-3.13.0-py3-none-any.whl", hash = "sha256:17880a26c285189c644bd5bd8f8ed2bdb795d216e3293e6dbe55bbd848e2982c", size = 4944, upload-time = "2024-10-30T16:18:26.694Z" }, +] + +[[package]] +name = "standard-sunau" +version = "3.13.0" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [ + { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/e3/ce8d38cb2d70e05ffeddc28bb09bad77cfef979eb0a299c9117f7ed4e6a9/standard_sunau-3.13.0.tar.gz", hash = "sha256:b319a1ac95a09a2378a8442f403c66f4fd4b36616d6df6ae82b8e536ee790908", size = 9368, upload-time = "2024-10-30T16:01:41.626Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/ae/e3707f6c1bc6f7aa0df600ba8075bfb8a19252140cd595335be60e25f9ee/standard_sunau-3.13.0-py3-none-any.whl", hash = "sha256:53af624a9529c41062f4c2fd33837f297f3baa196b0cfceffea6555654602622", size = 7364, upload-time = "2024-10-30T16:01:28.003Z" }, +] + [[package]] name = "starlette" version = "0.50.0" @@ -6815,9 +7070,7 @@ name = "torchcodec" version = "0.6.0" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/b3/11326a0e7a3c803a95975cfce4ac88fa4ea1a0d432bb876081046c5a5554/torchcodec-0.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fba260145a239b5afe13336e3a5bc1b089c9c31a073e9a7c2026d4cbd853fdd9", size = 3482584, upload-time = "2025-08-07T08:51:32.535Z" }, { url = "https://files.pythonhosted.org/packages/a7/d1/3f90561df013f6a015ef19de22726b64073fee405f53d3c4b8255ab05a67/torchcodec-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:fdef91a17fb1f1a159ce23710324a9a4e6d6a885275de73700f94a9ad562c6b2", size = 1370954, upload-time = "2025-08-07T08:51:15.021Z" }, - { url = "https://files.pythonhosted.org/packages/87/d0/0b5dd42652e4527d578e1d6239dbb907bf83e502115e517b83a55d8b7f8b/torchcodec-0.6.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:de20cab5df7fa7cdd74ec1dc0d508324685573f86de6789f0ebb860b7ea20b33", size = 3446017, upload-time = "2025-08-07T08:51:34.484Z" }, { url = "https://files.pythonhosted.org/packages/97/62/a938334e39101d4304619b90847d8aef7d1c607c6bcf33638f72931ae990/torchcodec-0.6.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = 
"sha256:46dab701a2d809e975a8b07d7ee47ed34f1d903511e374c74cfc1de6a5ab0e3f", size = 1374794, upload-time = "2025-08-07T08:51:17.355Z" }, ]