src/forge/actors/_torchstore_utils.py (6 additions, 0 deletions)

@@ -10,6 +10,7 @@
 import torch
 import torch.distributed.checkpoint as dcp
 from torch.distributed.checkpoint.metadata import Metadata as DcpMeta
+from torchstore.transport.buffers import rdma_available

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -69,3 +70,8 @@ def extract_param_name(key: str) -> str:

 def get_dcp_whole_state_dict_key(policy_version: int) -> str:
     return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"
+
+
+def rdma_enabled() -> bool:
+    """Return whether TorchStore thinks we're using RDMA."""
+    return rdma_available()
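
The new helper simply forwards TorchStore's probe, so forge call sites can branch on RDMA support without importing transport internals themselves. A minimal sketch of the intended call pattern, mirroring the generator and trainer changes below (the transport names here are illustrative, not an API):

from forge.actors._torchstore_utils import rdma_enabled

def choose_weight_sync_transport() -> str:
    # Hypothetical helper: use RDMA when TorchStore reports it is usable,
    # otherwise fall back to DCP checkpoint files on shared storage.
    return "rdma" if rdma_enabled() else "dcp"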
src/forge/actors/generator.py (21 additions, 16 deletions)

@@ -46,6 +46,7 @@
     get_param_key,
     get_param_prefix,
     load_tensor_from_dcp,
+    rdma_available,
 )

 from forge.controller import (
@@ -56,7 +57,6 @@
 )
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
-from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
@@ -112,7 +112,7 @@ def __post_init__(self):
         self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

         if self.use_dcp_for_weight_sync is None:
-            self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value()
+            self.use_dcp_for_weight_sync = not rdma_available()
         logger.debug(f"{self.use_dcp_for_weight_sync=}")

     @endpoint
@@ -492,14 +492,16 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
         await stop_proc_mesh(actor._generator_proc)

     @endpoint
-    async def save_model_params(self):
-        """Used for debugging purpose. Save model parameters before weight update."""
-        await self.worker.save_model_params.call()
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update; used for testing purposes only."""
+        logger.info("[Generator] save model parameters for testing.")
+        await self.worker._test_save_model_params.call()

     @endpoint
-    async def validate_model_params(self, validate_fn):
-        """Used for debugging purpose. Validate saved params using validate_fn."""
-        return await self.worker.validate_model_params.call(validate_fn)
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[Generator] start validating model parameters.")
+        return await self.worker._test_validate_model_params.call(validate_fn)


 @dataclass
@@ -512,6 +514,8 @@ class GeneratorWorker(ForgeActor):
     """

     vllm_config: VllmConfig
+    # TODO: Remove below param
+    _test_prev_params = {}

     @endpoint
     async def setup(self):
@@ -601,19 +605,20 @@ async def update_weights(self, version: int) -> None:
         t.stop()

     @endpoint
-    async def save_model_params(self):
-        """Used for debugging purposes. Save model parameters before weight update."""
-        self._debug_saved_params = {}
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update; used for testing purposes only."""
+        logger.info("[GeneratorWorker] save model parameters for testing.")
         for name, param in self.worker.model_runner.model.named_parameters():
-            self._debug_saved_params[name] = param.detach().cpu()
+            self._test_prev_params[name] = param.detach().cpu()
         logger.info(
             "[GeneratorWorker] finished saving model parameters, len = %d",
-            len(self._debug_saved_params),
+            len(self._test_prev_params),
         )

     @endpoint
-    async def validate_model_params(self, validate_fn):
-        """Used for debugging purposes. Validate saved params using validate_fn."""
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[GeneratorWorker] start validating model parameters.")
         return validate_fn(
-            self._debug_saved_params, self.worker.model_runner.model, logger
+            self._test_prev_params, self.worker.model_runner.model, logger
         )
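
The worker hands validate_fn exactly three arguments: the dict of parameter copies saved by _test_save_model_params (name -> detached CPU tensor), the live vLLM model, and the module logger. A hedged sketch of one possible validator; only that call signature comes from the diff above, the check itself is illustrative:

import torch

def validate_fn(prev_params, model, logger):
    # Compare post-update weights against the CPU snapshots taken before
    # the update; warn about any parameter that did not change.
    for name, param in model.named_parameters():
        if torch.equal(prev_params[name], param.detach().cpu()):
            logger.warning("parameter %s unchanged after weight update", name)
    return True

Such a function would then be passed through the endpoint, e.g. await generator._test_validate_model_params.call(validate_fn), following the call pattern shown in the diff.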
src/forge/actors/trainer.py (2 additions, 4 deletions)

@@ -41,11 +41,11 @@
     DcpHandle,
     get_dcp_whole_state_dict_key,
     get_param_key,
+    rdma_available,
 )

 from forge.controller import ForgeActor
 from forge.data.utils import batch_to_device
-from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer

@@ -131,9 +131,7 @@ class RLTrainer(ForgeActor):
     # Non JobConfig-related fields
     loss: Callable = lambda logits, **targets: logits
     state_dict_key: str = "model_state_dict"
-    use_dcp: bool = (
-        TORCHSTORE_USE_RDMA.get_value() == 0
-    )  # torchstore currently only accepts 0 or 1
+    use_dcp: bool = not rdma_available()
     dcp_path: str = "forge_dcp_tmp"

     def __post_init__(self):
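
One behavioral note on the new trainer default: use_dcp: bool = not rdma_available() is a plain dataclass default, so the probe runs once when the class body executes at import time, not per trainer instance. A small self-contained sketch of that difference (names are illustrative):

from dataclasses import dataclass, field

def probe() -> bool:
    # Stand-in for rdma_available(); imagine a runtime capability check.
    return False

@dataclass
class Example:
    eager: bool = not probe()  # evaluated once, at class-definition time
    lazy: bool = field(default_factory=lambda: not probe())  # evaluated per instance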
src/forge/env.py (1 addition, 1 deletion)

@@ -101,7 +101,7 @@ def get_value(self) -> Any:

 TORCHSTORE_USE_RDMA = EnvVar(
     name="TORCHSTORE_RDMA_ENABLED",
-    default=0,
+    default=1,
     description="Whether or not to use RDMA in TorchStore.",
 )
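
With the default flipped to 1, RDMA is now opt-out rather than opt-in. A sketch of opting a single process out, assuming (as the variable's description suggests, but the diff does not prove) that TorchStore's rdma_available() honors TORCHSTORE_RDMA_ENABLED:

import os

# Must be set before forge (and TorchStore) read the flag.
os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"

from forge.env import TORCHSTORE_USE_RDMA

# Mirrors the comparison the old trainer code used.
use_dcp = TORCHSTORE_USE_RDMA.get_value() == 0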
