Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions nemo_rl/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,23 @@ def __init__(self, cfg: TensorboardConfig, log_dir: Optional[str] = None):
self.writer = SummaryWriter(log_dir=log_dir)
print(f"Initialized TensorboardLogger at {log_dir}")

@staticmethod
def _coerce_to_scalar(value: Any) -> int | float | bool | str | None:
"""Coerce a value to a Python scalar for TensorBoard logging.

Returns the coerced value, or None if it can't be converted to a scalar.
"""
if isinstance(value, (int, float, bool, str)):
return value
if isinstance(value, (np.floating, np.integer, np.bool_)):
return value.item()
if isinstance(value, np.ndarray) and (value.ndim == 0 or value.size == 1):
return value.item()
if isinstance(value, torch.Tensor) and (value.ndim == 0 or value.numel() == 1):
return value.item()
# dict, list, multi-element arrays/tensors, or incompatible types
return None

def log_metrics(
self,
metrics: dict[str, Any],
Expand All @@ -137,23 +154,19 @@ def log_metrics(
step_metric: Optional step metric name (ignored in TensorBoard)
"""
for name, value in metrics.items():
# NeMo-Gym will add additional metrics like wandb histograms. However, some people will log to Tensorboard instead which may not be compatible
# This logic catches non-compatible objects being logged.
if not isinstance(value, (int, float, bool, str)):
continue

if prefix:
name = f"{prefix}/{name}"

# Skip non-scalar values that TensorBoard can't handle
if isinstance(value, (dict, list)):
scalar = self._coerce_to_scalar(value)
if scalar is None:
print(
f"Warning: Skipping non-scalar metric '{name}' for TensorBoard logging (type: {type(value).__name__})"
f"Warning: Skipping metric '{name}' for TensorBoard logging "
f"(unsupported type: {type(value).__name__})"
)
continue

try:
self.writer.add_scalar(name, value, step)
self.writer.add_scalar(name, scalar, step)
except Exception as e:
print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}")
continue
Expand Down
95 changes: 95 additions & 0 deletions tests/unit/utils/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,101 @@ def test_log_hyperparams(self, mock_summary_writer, temp_dir):
"model.hidden_size": 128,
}

@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_python_primitives(self, mock_summary_writer, temp_dir):
    """Check that plain Python int, float, bool and str values are returned as given."""
    tb_logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)
    coerce = tb_logger._coerce_to_scalar

    # Numbers and strings only need to compare equal to the input ...
    for primitive in (42, 3.14, "hello"):
        assert coerce(primitive) == primitive

    # ... whereas the boolean must come back as the True singleton itself.
    assert coerce(True) is True

@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_numpy_types(self, mock_summary_writer, temp_dir):
    """Check coercion of numpy scalars and single-element arrays, and rejection of larger arrays."""
    import numpy as np

    tb_logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)
    coerce = tb_logger._coerce_to_scalar

    # (input, expected) pairs compared by equality.
    equality_cases = [
        # numpy scalar types
        (np.float32(1.5), 1.5),
        (np.float64(2.5), 2.5),
        (np.int32(10), 10),
        (np.int64(20), 20),
        # 0-d numpy array
        (np.array(3.14), 3.14),
        # 1-element numpy array
        (np.array([42]), 42),
    ]
    for raw, expected in equality_cases:
        assert coerce(raw) == expected

    # A numpy bool must become the Python True singleton.
    assert coerce(np.bool_(True)) is True

    # Arrays with more than one element are not scalars.
    assert coerce(np.array([1, 2, 3])) is None

@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_torch_tensors(self, mock_summary_writer, temp_dir):
    """Check coercion of 0-d and 1-element torch tensors, and rejection of larger tensors."""
    tb_logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)
    coerce = tb_logger._coerce_to_scalar

    # 0-d tensors: the float one is compared approximately, the integer one exactly.
    float_result = coerce(torch.tensor(3.14))
    assert float_result == pytest.approx(3.14)
    int_result = coerce(torch.tensor(42))
    assert int_result == 42

    # A tensor with a single element is unwrapped too.
    single_result = coerce(torch.tensor([99]))
    assert single_result == 99

    # A tensor with several elements cannot be reduced to one scalar.
    multi_result = coerce(torch.tensor([1, 2, 3]))
    assert multi_result is None

@patch("nemo_rl.utils.logger.SummaryWriter")
def test_coerce_to_scalar_incompatible_types(self, mock_summary_writer, temp_dir):
    """Check that values which are not scalars are mapped to None."""
    tb_logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)

    # A dict, a list, None itself and an arbitrary object are all unsupported.
    unsupported_values = ({"key": "value"}, [1, 2, 3], None, object())
    for candidate in unsupported_values:
        assert tb_logger._coerce_to_scalar(candidate) is None

@patch("nemo_rl.utils.logger.SummaryWriter")
def test_log_metrics_coerces_numpy_and_torch(self, mock_summary_writer, temp_dir):
    """Check that log_metrics writes numpy/torch scalars and skips non-scalar values."""
    import numpy as np

    tb_logger = TensorboardLogger({"log_dir": temp_dir}, log_dir=temp_dir)

    # Values expected to reach the writer.
    loggable = {
        "python_float": 1.0,
        "numpy_float32": np.float32(2.0),
        "numpy_float64": np.float64(3.0),
        "torch_scalar": torch.tensor(4.0),
        "numpy_0d": np.array(5.0),
        "torch_1elem": torch.tensor([6.0]),
    }
    # Values expected to be dropped before reaching the writer.
    not_loggable = {
        "skip_list": [1, 2, 3],
        "skip_dict": {"a": 1},
        "skip_multi_tensor": torch.tensor([1.0, 2.0]),
    }
    tb_logger.log_metrics({**loggable, **not_loggable}, step=1)

    add_scalar = mock_summary_writer.return_value.add_scalar
    # Should log 6 scalars, skip 3 incompatible
    assert add_scalar.call_count == 6

    # Map each logged tag (first positional arg) to its value (second positional arg).
    logged = {}
    for recorded in add_scalar.call_args_list:
        positional = recorded[0]
        logged[positional[0]] = positional[1]

    # The plain Python float is compared exactly, the coerced ones approximately.
    assert logged["python_float"] == 1.0
    approx_expectations = (
        ("numpy_float32", 2.0),
        ("numpy_float64", 3.0),
        ("torch_scalar", 4.0),
        ("numpy_0d", 5.0),
        ("torch_1elem", 6.0),
    )
    for tag, expected in approx_expectations:
        assert logged[tag] == pytest.approx(expected)


class TestWandbLogger:
"""Test the WandbLogger class."""
Expand Down
Loading