Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ policy:
precision: "bfloat16"
fsdp_offload_enabled: false
activation_checkpointing_enabled: false
refit_buffer_size_gb: 4 # size (in GB) of the streaming buffer used when refitting the inference engine

dtensor_cfg:
enabled: false
Expand Down
1 change: 1 addition & 0 deletions examples/configs/grpo_math_8B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ policy:
precision: "bfloat16"
fsdp_offload_enabled: false
activation_checkpointing_enabled: false
refit_buffer_size_gb: 4 # size (in GB) of the streaming buffer used when refitting the inference engine

optimizer:
name: "torch.optim.AdamW"
Expand Down
42 changes: 36 additions & 6 deletions nemo_reinforcer/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,34 @@ def setup(
def refit_policy_generation(
    policy: "PolicyInterface",
    policy_generation: "GenerationInterface",
    refit_buffer_size_gb: int,  # GB
):
    """Refit the policy generation interface with the latest policy weights.

    Weights are streamed to the inference engine in groups of at most
    ``refit_buffer_size_gb`` gigabytes to cap peak memory during the refit.

    Args:
        policy: Training policy that owns the up-to-date weights.
        policy_generation: Inference engine to refresh.
        refit_buffer_size_gb: Per-group transfer budget, in gigabytes.
    """
    policy.offload_before_refit()
    # Wake only the weight buffers now; the kv_cache is restored at the end so
    # it does not spike memory while the training framework is still awake.
    policy_generation.prepare_for_generation(tags=["weights"])

    # Streaming update of weights to save memory.
    state_dict_info = policy.prepare_weights_for_ipc()

    # Group keys so each group fits within the refit buffer budget.
    buffer_bytes = refit_buffer_size_gb * (1024**3)
    grouped_keys = []
    current_group = []
    available_bytes = buffer_bytes
    for key, size_in_bytes in state_dict_info:
        # Start a new group once the next tensor would exceed the remaining budget.
        if size_in_bytes > available_bytes:
            if current_group:
                grouped_keys.append(current_group)
                current_group = []
            available_bytes = buffer_bytes
        current_group.append(key)
        available_bytes -= size_in_bytes
    if current_group:
        grouped_keys.append(current_group)

    # Transfer each group to the inference engine via IPC handles.
    for group in grouped_keys:
        ipc_handles = policy.get_weights_ipc_handles(group)
        policy_generation.update_weights(ipc_handles)

    policy.offload_after_refit()
    policy_generation.prepare_for_generation(tags=["kv_cache"])


# ===============================================================================
Expand Down Expand Up @@ -321,12 +342,13 @@ def grpo_train(
consumed_samples = grpo_save_state["consumed_samples"]
val_period = master_config["grpo"]["val_period"]
val_at_start = master_config["grpo"]["val_at_start"]
refit_buffer_size_gb = master_config["policy"]["refit_buffer_size_gb"]

# Run validation at the start if configured
if val_at_start and step == 0:
print("\n🔍 Running initial validation...")
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(policy, policy_generation)
refit_policy_generation(policy, policy_generation, refit_buffer_size_gb)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
Expand Down Expand Up @@ -368,7 +390,11 @@ def grpo_train(
print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
with timer.time("prepare_for_generation"):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(policy, policy_generation)
refit_policy_generation(
policy,
policy_generation,
refit_buffer_size_gb,
)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
Expand Down Expand Up @@ -476,7 +502,11 @@ def grpo_train(
# Run validation if it's a validation step
if val_period > 0 and (step + 1) % val_period == 0:
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(policy, policy_generation)
refit_policy_generation(
policy,
policy_generation,
refit_buffer_size_gb,
)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
Expand Down
12 changes: 9 additions & 3 deletions nemo_reinforcer/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,14 @@ def sleep(self):
gc.collect()
torch.cuda.empty_cache()

def wake_up(self):
self.llm.wake_up()
def wake_up(self, **kwargs):
    """Wake the underlying vLLM engine.

    Calling with tags=["weights"] during a refit restores only the weight
    buffers, avoiding a kv_cache memory spike while the training framework
    is awake; tags=["kv_cache"] restores the cache afterwards.
    """
    try:
        tags = kwargs["tags"]
    except KeyError:
        self.llm.wake_up()
    else:
        self.llm.wake_up(tags=tags)


class VllmGeneration(GenerationInterface):
Expand Down Expand Up @@ -622,7 +628,7 @@ def prepare_for_generation(self, *args, **kwargs):
try:
# Use run_all_workers_single_data for methods that don't need data
futures = self.worker_group.run_all_workers_single_data(
"wake_up", only_on="tied_leader"
"wake_up", only_on="tied_leader", **kwargs
)
# Wait for all futures to complete
results = ray.get(futures)
Expand Down
2 changes: 1 addition & 1 deletion nemo_reinforcer/models/generation/vllm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def update_weights_from_ipc_handles(self, ipc_handles):
weights = []

# Process each handle to get the tensor
for name, handle in handles.items():
for name, handle in handles:
func, args = handle
list_args = list(args)
# Update device ID to match the current device
Expand Down
1 change: 1 addition & 0 deletions nemo_reinforcer/models/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ class PolicyConfig(TypedDict):
max_grad_norm: Optional[Union[float, int]]
fsdp_offload_enabled: bool
activation_checkpointing_enabled: bool
refit_buffer_size_gb: int
84 changes: 44 additions & 40 deletions nemo_reinforcer/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def __init__(
if self.cpu_offload:
self.model = self.move_buffer_to_device(self.model, "cpu")

self._held_model_params = None
# used for streaming update inference engine weights
self._held_sharded_state_dict_reference = None
self._held_streamed_param_reference = None

if init_reference_model:
self.reference_model_state_dict = get_cpu_state_dict(
Expand Down Expand Up @@ -235,6 +237,9 @@ def __init__(
def is_alive(self):
    """Liveness probe: completing this remote call shows the worker is responsive."""
    return True

def reset_peak_memory_stats(self):
    """Reset CUDA peak-memory counters so later readings reflect only new activity."""
    torch.cuda.reset_peak_memory_stats()

def get_gpu_info(self):
    """Return information about the GPU being used by this worker."""
    # Delegates to the module-level get_gpu_info helper, passing the wrapped model.
    return get_gpu_info(self.model)
Expand Down Expand Up @@ -554,50 +559,45 @@ def report_device_id(self) -> str:
return get_device_uuid(device_idx)

@torch.no_grad()
def prepare_weights_for_ipc(self):
    """Stage the model for streamed IPC weight transfer.

    Moves the model to GPU, holds a reference to its state dict (released
    later in offload_after_refit), and returns per-tensor sizes so the
    caller can batch transfers under the refit buffer budget.

    Returns:
        list[tuple[str, int]]: (parameter name, size in bytes) pairs.
    """
    self.model = self.move_to_cuda(self.model)
    self._held_sharded_state_dict_reference = self.model.state_dict()

    # Collect per-tensor sizes for streaming multiple tensors per transfer.
    # NOTE: a DTensor's numel() reports the complete (unsharded) tensor, so
    # the size below is the full-tensor footprint, not the local shard's.
    state_dict_info = []
    for name, tensor in self._held_sharded_state_dict_reference.items():
        size_in_bytes = tensor.element_size() * tensor.numel()
        state_dict_info.append((name, size_in_bytes))
    return state_dict_info

# Create a copy of parameters in the desired dtype (bfloat16 or float32)
dtype_params = {}
for name, param in params.items():
if isinstance(param, DTensor):
param = param.full_tensor()
@torch.no_grad()
def get_weights_ipc_handles(self, keys):
from torch.multiprocessing.reductions import reduce_tensor

converted_params = {}
for key in keys:
# Get full_tensor for dtensor (GPU > 1)
tensor = self._held_sharded_state_dict_reference[key]
if isinstance(tensor, DTensor):
full_tensor = tensor.full_tensor()
else:
full_tensor = tensor
# Convert parameters to the configured dtype
dtype_params[name] = param.to(
device="cuda", dtype=self.dtype, non_blocking=True
)
converted_params[key] = full_tensor.to(self.dtype, non_blocking=True)

for name, buffer in self.model.named_buffers():
if isinstance(buffer, DTensor):
buffer = buffer.full_tensor()
# Temporary record the full tensor for cleanup
# It is needed for cleanup the last full_tensor in the refit process
self._held_streamed_param_reference = converted_params

dtype_params[name] = buffer.to(
device="cuda", dtype=self.dtype, non_blocking=True
)

torch.cuda.synchronize()

# Replace the original params with the converted ones
params = dtype_params

# hold on to the params so we can explicitly delete them after refit
self._held_model_params = params

data = {}
# Get device UUID for IPC
device_uuid = self.report_device_id()
for name, p in params.items():
data[name] = reduce_tensor(p.detach())

if offload_model or self.cpu_offload:
self.model = self.move_to_cpu(self.model)
gc.collect()
torch.cuda.empty_cache()
# Create handles for the tensors
all_handles = []
for key, p in converted_params.items():
handle = reduce_tensor(p.detach())
all_handles.append((key, handle))

return {device_uuid: data}
return {device_uuid: all_handles}

def prepare_for_lp_inference(self):
if not self.cpu_offload:
Expand Down Expand Up @@ -655,9 +655,13 @@ def offload_after_refit(self):
torch.randn(1).cuda() # wake up torch allocator
self.offload_before_refit() # rerun the old offload function

if self._held_model_params is not None:
del self._held_model_params
self._held_model_params = None
# Clean up the held tensors
if self._held_sharded_state_dict_reference is not None:
del self._held_sharded_state_dict_reference
self._held_sharded_state_dict_reference = None
if self._held_streamed_param_reference is not None:
del self._held_streamed_param_reference
self._held_streamed_param_reference = None

gc.collect()
torch.cuda.empty_cache()
Expand Down
88 changes: 61 additions & 27 deletions nemo_reinforcer/models/policy/fsdp1_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,11 @@ def do_fsdp(model):
self.reference_model = do_fsdp(self.reference_model)
self.reference_model = self.manual_offload_to_cpu(self.reference_model)
self.model = self.manual_load_to_gpu(self.model)
self._held_reference_model_params = None

# used for streaming update inference engine weights
self._held_sharded_state_dict_reference = None
self._held_streamed_param_reference = None

# register_fsdp_forward_method(self.model, "generate")
if init_optimizer:
optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"])
Expand Down Expand Up @@ -205,6 +209,9 @@ def do_fsdp(model):
def is_alive(self):
    """Liveness probe: completing this remote call shows the worker is responsive."""
    return True

def reset_peak_memory_stats(self):
    """Reset CUDA peak-memory counters so later readings reflect only new activity."""
    torch.cuda.reset_peak_memory_stats()

def get_gpu_info(self):
    """Return information about the GPU being used by this worker."""
    # Delegates to the module-level get_gpu_info helper, passing the wrapped model.
    return get_gpu_info(self.model)
Expand Down Expand Up @@ -720,38 +727,61 @@ def report_device_id(self) -> str:
return get_device_uuid(device_idx)

@torch.no_grad()
def prepare_weights_for_ipc(self):
    """Stage the FSDP1 model for streamed IPC weight transfer.

    Caches a reference to the model's (sharded) state dict and returns
    per-tensor sizes so the caller can batch transfers under the refit
    buffer budget.

    Returns:
        list[tuple[str, int]]: (parameter name, size in bytes) pairs.
    """
    from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType

    # If the model is not FSDP, we need to manually move it to the GPU.
    # For an FSDP model, model.state_dict() will move the params to the GPU.
    if not isinstance(self.model, FullyShardedDataParallel):
        self.model = self.manual_load_to_gpu(self.model)
        self._held_sharded_state_dict_reference = self.model.state_dict()
    else:
        # Use a sharded state dict instead of a full state dict for FSDP1 to
        # avoid an all-gather of full tensors here.
        with FullyShardedDataParallel.state_dict_type(
            self.model,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(),
        ):
            self._held_sharded_state_dict_reference = self.model.state_dict()

    # Collect per-tensor sizes for streaming multiple tensors per transfer.
    # NOTE: a sharded/distributed tensor's numel() reports the complete
    # tensor, not the local shard.
    state_dict_info = []
    for name, tensor in self._held_sharded_state_dict_reference.items():
        size_in_bytes = tensor.element_size() * tensor.numel()
        state_dict_info.append((name, size_in_bytes))

    return state_dict_info

# TODO @sahilj: do this without an allgather (maybe FSDP2)
params = self.model.state_dict()
@torch.no_grad()
def get_weights_ipc_handles(self, keys):
from torch.distributed.tensor import DTensor
from torch.multiprocessing.reductions import reduce_tensor

# Create a copy of parameters in the desired dtype (bfloat16 or float32)
dtype_params = {}
for name, param in params.items():
converted_params = {}
for key in keys:
# Get full_tensor for dtensor (GPU > 1)
tensor = self._held_sharded_state_dict_reference[key]
if isinstance(tensor, DTensor):
full_tensor = tensor.full_tensor()
else:
full_tensor = tensor
# Convert parameters to the configured dtype
dtype_params[name] = param.to(self.dtype, non_blocking=True)

# Replace the original params with the converted ones
params = dtype_params
# For FSDP1, params may get GC'ed before sending to vllm,
# so we need to hold a reference to them
self._held_reference_model_params = params
data = {}
converted_params[key] = full_tensor.to(self.dtype, non_blocking=True)

# Temporary record the full tensor for cleanup
# It is needed for cleanup the last full_tensor in the refit process
self._held_streamed_param_reference = converted_params

# Get device UUID for IPC
device_uuid = self.report_device_id()
for name, p in params.items():
data[name] = reduce_tensor(p.detach())
# Create handles for the tensors
all_handles = []
for key, p in converted_params.items():
handle = reduce_tensor(p.detach())
all_handles.append((key, handle))

if offload_model:
self.model = self.manual_offload_to_cpu(self.model)
gc.collect()
torch.cuda.empty_cache()
return {device_uuid: data}
return {device_uuid: all_handles}

def prepare_for_lp_inference(self):
self.model = self.manual_load_to_gpu(self.model)
Expand Down Expand Up @@ -802,9 +832,13 @@ def offload_after_refit(self):
torch.randn(1).cuda() # wake up torch allocator
self.offload_before_refit() # rerun the old offload function

if self._held_reference_model_params is not None:
del self._held_reference_model_params
self._held_reference_model_params = None
# Clean up the held tensors
if self._held_sharded_state_dict_reference is not None:
del self._held_sharded_state_dict_reference
self._held_sharded_state_dict_reference = None
if self._held_streamed_param_reference is not None:
del self._held_streamed_param_reference
self._held_streamed_param_reference = None

gc.collect()
torch.cuda.empty_cache()
Expand Down
Loading