diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index ed431986dc..f329dd70c7 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -121,6 +121,23 @@ def __init__(self, cfg: TensorboardConfig, log_dir: Optional[str] = None): self.writer = SummaryWriter(log_dir=log_dir) print(f"Initialized TensorboardLogger at {log_dir}") + @staticmethod + def _coerce_to_scalar(value: Any) -> int | float | bool | str | None: + """Coerce a value to a Python scalar for TensorBoard logging. + + Returns the coerced value, or None if it can't be converted to a scalar. + """ + if isinstance(value, (int, float, bool, str)): + return value + if isinstance(value, (np.floating, np.integer, np.bool_)): + return value.item() + if isinstance(value, np.ndarray) and (value.ndim == 0 or value.size == 1): + return value.item() + if isinstance(value, torch.Tensor) and (value.ndim == 0 or value.numel() == 1): + return value.item() + # dict, list, multi-element arrays/tensors, or incompatible types + return None + def log_metrics( self, metrics: dict[str, Any], @@ -137,23 +154,19 @@ def log_metrics( step_metric: Optional step metric name (ignored in TensorBoard) """ for name, value in metrics.items(): - # NeMo-Gym will add additional metrics like wandb histograms. However, some people will log to Tensorboard instead which may not be compatible - # This logic catches non-compatible objects being logged. - if not isinstance(value, (int, float, bool, str)): - continue - if prefix: name = f"{prefix}/{name}" - # Skip non-scalar values that TensorBoard can't handle - if isinstance(value, (dict, list)): + scalar = self._coerce_to_scalar(value) + if scalar is None: print( - f"Warning: Skipping non-scalar metric '{name}' for TensorBoard logging (type: {type(value).__name__})" + f"Warning: Skipping metric '{name}' for TensorBoard logging " + f"(unsupported type: {type(value).__name__})" ) continue try: - self.writer.add_scalar(name, value, step) + self.writer.add_scalar(name, scalar, step) except Exception as e: print(f"Warning: Failed to log metric '{name}' to TensorBoard: {e}") continue diff --git a/tests/unit/utils/test_logger.py b/tests/unit/utils/test_logger.py index 679ccf9a8e..d88137746a 100644 --- a/tests/unit/utils/test_logger.py +++ b/tests/unit/utils/test_logger.py @@ -128,6 +128,101 @@ def test_log_hyperparams(self, mock_summary_writer, temp_dir): "model.hidden_size": 128, } + @patch("nemo_rl.utils.logger.SummaryWriter") + def test_coerce_to_scalar_python_primitives(self, mock_summary_writer, temp_dir): + """Test that Python primitives pass through unchanged.""" + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + assert logger._coerce_to_scalar(42) == 42 + assert logger._coerce_to_scalar(3.14) == 3.14 + assert logger._coerce_to_scalar(True) is True + assert logger._coerce_to_scalar("hello") == "hello" + + @patch("nemo_rl.utils.logger.SummaryWriter") + def test_coerce_to_scalar_numpy_types(self, mock_summary_writer, temp_dir): + """Test that numpy scalar types are coerced to Python primitives.""" + import numpy as np + + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + # numpy scalar types + assert logger._coerce_to_scalar(np.float32(1.5)) == 1.5 + assert logger._coerce_to_scalar(np.float64(2.5)) == 2.5 + assert logger._coerce_to_scalar(np.int32(10)) == 10 + assert logger._coerce_to_scalar(np.int64(20)) == 20 + assert logger._coerce_to_scalar(np.bool_(True)) is True + + # 0-d numpy arrays + assert logger._coerce_to_scalar(np.array(3.14)) == 3.14 + # 1-element numpy arrays + assert logger._coerce_to_scalar(np.array([42])) == 42 + + # Multi-element arrays should return None + assert logger._coerce_to_scalar(np.array([1, 2, 3])) is None + + @patch("nemo_rl.utils.logger.SummaryWriter") + def test_coerce_to_scalar_torch_tensors(self, mock_summary_writer, temp_dir): + """Test that torch scalar tensors are coerced to Python primitives.""" + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + # 0-d tensors + assert logger._coerce_to_scalar(torch.tensor(3.14)) == pytest.approx(3.14) + assert logger._coerce_to_scalar(torch.tensor(42)) == 42 + + # 1-element tensors + assert logger._coerce_to_scalar(torch.tensor([99])) == 99 + + # Multi-element tensors should return None + assert logger._coerce_to_scalar(torch.tensor([1, 2, 3])) is None + + @patch("nemo_rl.utils.logger.SummaryWriter") + def test_coerce_to_scalar_incompatible_types(self, mock_summary_writer, temp_dir): + """Test that incompatible types return None.""" + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + assert logger._coerce_to_scalar({"key": "value"}) is None + assert logger._coerce_to_scalar([1, 2, 3]) is None + assert logger._coerce_to_scalar(None) is None + assert logger._coerce_to_scalar(object()) is None + + @patch("nemo_rl.utils.logger.SummaryWriter") + def test_log_metrics_coerces_numpy_and_torch(self, mock_summary_writer, temp_dir): + """Test that log_metrics correctly logs numpy/torch scalars.""" + import numpy as np + + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + metrics = { + "python_float": 1.0, + "numpy_float32": np.float32(2.0), + "numpy_float64": np.float64(3.0), + "torch_scalar": torch.tensor(4.0), + "numpy_0d": np.array(5.0), + "torch_1elem": torch.tensor([6.0]), + "skip_list": [1, 2, 3], + "skip_dict": {"a": 1}, + "skip_multi_tensor": torch.tensor([1.0, 2.0]), + } + logger.log_metrics(metrics, step=1) + + mock_writer = mock_summary_writer.return_value + # Should log 6 scalars, skip 3 incompatible + assert mock_writer.add_scalar.call_count == 6 + + # Verify each scalar was logged with correct value + calls = {c[0][0]: c[0][1] for c in mock_writer.add_scalar.call_args_list} + assert calls["python_float"] == 1.0 + assert calls["numpy_float32"] == pytest.approx(2.0) + assert calls["numpy_float64"] == pytest.approx(3.0) + assert calls["torch_scalar"] == pytest.approx(4.0) + assert calls["numpy_0d"] == pytest.approx(5.0) + assert calls["torch_1elem"] == pytest.approx(6.0) + class TestWandbLogger: """Test the WandbLogger class."""