examples/configs/dpo.yaml (2 changes: 1 addition & 1 deletion)
@@ -134,7 +134,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
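The same flag flips from false to true in every Megatron config in this PR. For context, here is a minimal sketch, not code from this PR, of how keys under policy.megatron_cfg.distributed_data_parallel_config typically map onto Megatron-Core's DDP config; the import path and field names are assumptions based on the megatron.core.distributed API:

# Minimal sketch (assumption: field names match megatron.core.distributed in
# the Megatron-Core version this repo pins; not code from this PR).
from megatron.core.distributed import DistributedDataParallelConfig

ddp_config = DistributedDataParallelConfig(
    grad_reduce_in_fp32=False,   # reduce grads in the training dtype
    overlap_grad_reduce=True,    # overlap grad all-reduce with backward
    overlap_param_gather=True,   # the flag this PR enables: all-gather
                                 # params via a module forward pre-hook
    average_in_collective=True,  # average inside the collective op
)
# "data_parallel_sharding_strategy" in the YAML is consumed by the custom-FSDP
# path and should have no effect unless use_custom_fsdp is set (assumption).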
examples/configs/grpo_math_1B_megatron.yaml (2 changes: 1 addition & 1 deletion)
@@ -115,7 +115,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       use_custom_fsdp: false
       data_parallel_sharding_strategy: "optim_grads_params"
@@ -91,7 +91,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
@@ -91,7 +91,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
@@ -79,7 +79,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
examples/configs/sft.yaml (2 changes: 1 addition & 1 deletion)
@@ -109,7 +109,7 @@ policy:
     distributed_data_parallel_config:
       grad_reduce_in_fp32: false
       overlap_grad_reduce: true
-      overlap_param_gather: false
+      overlap_param_gather: true
       average_in_collective: true
       data_parallel_sharding_strategy: "optim_grads_params"
 
nemo_rl/models/policy/megatron_policy_worker.py (32 changes: 23 additions & 9 deletions)
@@ -421,15 +421,6 @@ def __init__(
                 pretrained_path, "iter_0000000/run_config.yaml"
             )
 
-        assert not (
-            self.cfg["megatron_cfg"]["distributed_data_parallel_config"][
-                "overlap_param_gather"
-            ]
-            and self.cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"]
-        ), (
-            "Using overlap param gather together with distributed optimizer has known convergence issues. Please disable overlap param gather."
-        )
-
         self.tokenizer = tokenizer
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -633,6 +624,13 @@ def __init__(
         self._held_gather_buffer = None
         self.megatron_to_hf_converter = MegatronToHFConverter(hf_model_name, self.model)
 
+        self.should_disable_forward_pre_hook = (
+            self.cfg["megatron_cfg"]["optimizer"]["use_distributed_optimizer"]
+            and self.cfg["megatron_cfg"]["distributed_data_parallel_config"][
+                "overlap_param_gather"
+            ]
+        )
+
     def configure_worker(self, num_gpus: int, bundle_indices: Optional[tuple] = None):
         USE_EXPANDABLE_SEGMENTS = False  # Disabling this right now as it seems to cause vLLM refit issues with Ampere
         if USE_EXPANDABLE_SEGMENTS:
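Why a forward pre-hook matters here: with the distributed optimizer plus overlap_param_gather, Megatron gathers parameter shards just before each module's forward runs, so swapping weight tensors while a gather is in flight could observe stale shards. The standalone PyTorch toy below, which is not the Megatron implementation, shows the hook mechanism being toggled:

# Standalone toy (assumption: illustrates the pre-hook mechanism only; the
# real Megatron hook launches parameter all-gathers rather than printing).
import torch

linear = torch.nn.Linear(4, 4)

def pre_hook(module, args):
    # In Megatron, this is where a bucket's parameter all-gather would be
    # kicked off or awaited before the module consumes its weights.
    print(f"pre-forward: would gather params of {type(module).__name__}")

handle = linear.register_forward_pre_hook(pre_hook)
linear(torch.randn(2, 4))  # hook fires before forward runs
handle.remove()            # analogous to disable_forward_pre_hook()
linear(torch.randn(2, 4))  # no hook: safe to swap weights underneath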
@@ -650,6 +648,14 @@ def get_gpu_info(self):
         """Return information about the GPU being used by this worker."""
         return get_gpu_info(self.model)
 
+    def enable_forward_pre_hook(self):
+        assert isinstance(self.model, DistributedDataParallel)
+        self.model.enable_forward_pre_hook()
+
+    def disable_forward_pre_hook(self, param_sync=True):
+        assert isinstance(self.model, DistributedDataParallel)
+        self.model.disable_forward_pre_hook(param_sync=param_sync)
+
     def train(
         self,
         data: BatchedDataDict,
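These two methods are used as a bracket around code that swaps weights, which is exactly what use_reference_model does below. A hypothetical helper, not part of this PR, that captures the pattern:

# Hypothetical wrapper (not in this PR) around the worker methods above.
from contextlib import contextmanager

@contextmanager
def param_gather_hook_disabled(worker):
    if worker.should_disable_forward_pre_hook:
        # param_sync=True (the default) flushes in-flight all-gathers first
        worker.disable_forward_pre_hook()
    try:
        yield
    finally:
        if worker.should_disable_forward_pre_hook:
            worker.enable_forward_pre_hook()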
@@ -989,6 +995,10 @@ def use_reference_model(self):
         On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references
         On exit: Restores original references and re-flips cuda/cpu
         """
+        ## disable overlap param gather when swapping weights
+        if self.should_disable_forward_pre_hook:
+            self.disable_forward_pre_hook()
+
         with torch.no_grad():
             try:
                 # Save original references
@@ -1023,6 +1033,10 @@ def use_reference_model(self):
             gc.collect()
             torch.cuda.empty_cache()
 
+        ## re-enable overlap param gather after weight swap
+        if self.should_disable_forward_pre_hook:
+            self.enable_forward_pre_hook()
+
     # Temporary fix, 'data' is a kwarg due to some sort of ray bug
     def get_reference_policy_logprobs(
         self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None