Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dist/
# Cache
uv_cache/
hf_home/
hf_datasets_cache/
*logs/
datasets/
docker/
Expand Down
Empty file.
28 changes: 28 additions & 0 deletions docs/design_docs/logger.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,31 @@ When enabled, the pretty logging will generate formatted text similar to:

![Validation Pretty Logging Example](../assets/val-log.png)

## GPU Metric Logging

Reinforcer monitors GPU memory and utilization through [system metrics](https://docs.ray.io/en/latest/ray-observability/reference/system-metrics.html#system-metrics) exposed by Ray nodes. While Ray makes these metrics available for tools like Prometheus, Reinforcer directly polls GPU memory and utilization data and logs them to TensorBoard and/or Weights & Biases.

This approach allows us to offer the same GPU metric tracking on all loggers (not just wandb) and simplifies the implementation greatly.

This feature is enabled with the `monitor_gpus` configuration parameter, and the frequency of collection and flushing to the loggers is controlled by `gpu_monitoring.collection_interval` and `gpu_monitoring.flush_interval` (both in seconds), respectively:

```python
logger:
wandb_enabled: false
tensorboard_enabled: false
monitor_gpus: true
gpu_monitoring:
collection_interval: 10
flush_interval: 10
```

:::{note}
While monitoring through the remote workers is possible, it requires delicate implementation to ensure that:
* sending logs back to driver does not incur a large overhead
* metrics remain easily interpretable, since colocated workers could otherwise lead to double counting
* workers gracefully flush their logs in the event of failure
* the logging is the same for tensorboard and wandb
* some workers which spawn other workers correctly report the total usage of the grandchild worker

These reasons lead us to the simple implementation of collecting on the driver.
:::
4 changes: 4 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,14 @@ logger:
num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
wandb_enabled: false
tensorboard_enabled: false
monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
wandb:
project: "grpo-dev"
name: "grpo-dev-logger"
tensorboard: {}
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
gpus_per_node: 1
Expand Down
6 changes: 5 additions & 1 deletion examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,17 @@ data:

logger:
log_dir: "logs" # Base directory for all logs
wandb_enabled: true
wandb_enabled: false
tensorboard_enabled: false
monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
wandb:
project: "sft-dev"
name: "sft-dev-logger"
tensorboard:
log_dir: "tb_logs"
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
gpus_per_node: 8
Expand Down
5 changes: 4 additions & 1 deletion examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,10 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs
task_data_processors["math"] = (math_task_spec, openinstructmath2_data_processor)

math_env = MathEnvironment.options(
runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE}
runtime_env={
"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE,
"env_vars": dict(os.environ), # Pass thru all user environment variables
}
).remote(env_configs["math"])
dataset = AllTaskProcessedDataset(
data.formatted_ds["train"],
Expand Down
8 changes: 6 additions & 2 deletions nemo_reinforcer/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ def setup(
logger_config = master_config["logger"]
cluster_config = master_config["cluster"]

# ==========================
# Logger
# ==========================
logger = Logger(logger_config)
logger.log_hyperparams(master_config)

# ==========================
# Checkpointing
# ==========================
Expand Down Expand Up @@ -238,8 +244,6 @@ def setup(
)

loss_fn = ClippedPGLossFn(loss_config)
logger = Logger(logger_config)
logger.log_hyperparams(master_config)

print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
Expand Down
6 changes: 5 additions & 1 deletion nemo_reinforcer/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,8 @@ def __call__(
num_unmasked_tokens = torch.tensor(1)
loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens

return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()}
return loss, {
"loss": loss.item(),
"num_unmasked_tokens": num_unmasked_tokens.item(),
"total_tokens": mask.numel(),
}
10 changes: 7 additions & 3 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class SFTConfig(TypedDict):
val_at_start: bool
seed: int


class MasterConfig(TypedDict):
policy: PolicyConfig
data: DataConfig
Expand Down Expand Up @@ -102,6 +103,12 @@ def setup(
cluster_config = master_config["cluster"]
sft_config = master_config["sft"]

# ==========================
# Logger
# ==========================
logger = Logger(logger_config)
logger.log_hyperparams(master_config)

# ==========================
# Checkpointing
# ==========================
Expand Down Expand Up @@ -179,9 +186,6 @@ def setup(
loss_fn = NLLLoss()
print(f" ✓ Model initialized")

logger = Logger(logger_config)
logger.log_hyperparams(master_config)

print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
print("=" * 60 + "\n")
Expand Down
1 change: 1 addition & 0 deletions nemo_reinforcer/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def masked_mean(values, mask, dim=None):
return values[mask.bool()].mean()
return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)


def set_seed(seed: int):
"""Sets the seed for python, numpy, and pytorch."""
random.seed(seed)
Expand Down
13 changes: 10 additions & 3 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def __repr__(self):

This makes it easier to identify which worker is producing specific log messages.
"""
return f"{self.__class__.__name__}[rank={torch.distributed.get_rank()}]"
if torch.distributed.is_initialized():
return f"{self.__class__.__name__}[rank={torch.distributed.get_rank()}]"
else:
return f"{self.__class__.__name__}"

def __init__(
self,
Expand Down Expand Up @@ -123,8 +126,7 @@ def do_fsdp(model):
if init_optimizer:
optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"])
self.optimizer = optimizer_cls(
self.model.parameters(),
**self.cfg["optimizer"]["kwargs"]
self.model.parameters(), **self.cfg["optimizer"]["kwargs"]
)
else:
self.optimizer = None
Expand Down Expand Up @@ -830,6 +832,11 @@ def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non
else:
print("WARNING: No scheduler checkpoint provided")

def shutdown(self):
"""Shutdown the policy."""
#
pass


class HfPolicy(PolicyInterface, GenerationInterface):
def __init__(
Expand Down
Loading
Loading