diff --git a/.gitignore b/.gitignore index 38276ac019..79a00631e6 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ dist/ # Cache uv_cache/ hf_home/ +hf_datasets_cache/ *logs/ datasets/ docker/ diff --git a/docs/design_docs/gpu_logger.md b/docs/design_docs/gpu_logger.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/design_docs/logger.md b/docs/design_docs/logger.md index cf55c442e4..fa81c7c291 100644 --- a/docs/design_docs/logger.md +++ b/docs/design_docs/logger.md @@ -78,3 +78,31 @@ When enabled, the pretty logging will generate formatted text similar to: ![Validation Pretty Logging Example](../assets/val-log.png) +## GPU Metric Logging + +Reinforcer monitors GPU memory and utilization through [system metrics](https://docs.ray.io/en/latest/ray-observability/reference/system-metrics.html#system-metrics) exposed by Ray nodes. While Ray makes these metrics available for tools like Prometheus, Reinforcer directly polls GPU memory and utilization data and logs them to TensorBoard and/or Weights & Biases. + +This approach allows us to offer the same GPU metric tracking on all loggers (not just wandb) and simplifies the implementation greatly. 
+ +This feature is enabled with the `monitor_gpus` configuration parameter and the frequency of collection and flushing to the loggers is controlled by `gpu_monitoring.collection_interval` and `gpu_monitoring.flush_interval` (both in seconds), respectively: + +```yaml +logger: + wandb_enabled: false + tensorboard_enabled: false + monitor_gpus: true + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 +``` + +:::{note} +While monitoring through the remote workers is possible, it requires some delicate implementation details to make sure: +* sending logs back to driver does not incur a large overhead +* metrics are easily interpretable since we may be double counting due to colocated workers +* workers gracefully flush their logs in the event of failure +* the logging is the same for tensorboard and wandb +* some workers which spawn other workers correctly report the total usage of the grandchild worker + +These reasons lead us to the simple implementation of collecting on the driver. +::: diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index ab2fbdf59c..827fd9cbeb 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -77,10 +77,14 @@ logger: num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal wandb_enabled: false tensorboard_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard wandb: project: "grpo-dev" name: "grpo-dev-logger" tensorboard: {} + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: gpus_per_node: 1 diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 2436795abb..1282285fc3 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -44,13 +44,17 @@ data: logger: log_dir: "logs" # Base directory for all logs - wandb_enabled: true 
+ wandb_enabled: false tensorboard_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard wandb: project: "sft-dev" name: "sft-dev-logger" tensorboard: log_dir: "tb_logs" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: gpus_per_node: 8 diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 72c491206f..bc11687ab0 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -195,7 +195,10 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs task_data_processors["math"] = (math_task_spec, openinstructmath2_data_processor) math_env = MathEnvironment.options( - runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} + runtime_env={ + "py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE, + "env_vars": dict(os.environ), # Pass thru all user environment variables + } ).remote(env_configs["math"]) dataset = AllTaskProcessedDataset( data.formatted_ds["train"], diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 7acfccd51b..8e04c9c929 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -137,6 +137,12 @@ def setup( logger_config = master_config["logger"] cluster_config = master_config["cluster"] + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + # ========================== # Checkpointing # ========================== @@ -238,8 +244,6 @@ def setup( ) loss_fn = ClippedPGLossFn(loss_config) - logger = Logger(logger_config) - logger.log_hyperparams(master_config) print("\n" + "=" * 60) print(" " * 18 + "SETUP COMPLETE") diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 
90230a06ab..8504dac007 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -166,4 +166,8 @@ def __call__( num_unmasked_tokens = torch.tensor(1) loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens - return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()} + return loss, { + "loss": loss.item(), + "num_unmasked_tokens": num_unmasked_tokens.item(), + "total_tokens": mask.numel(), + } diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index b216c02724..0d06ad6366 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -61,6 +61,7 @@ class SFTConfig(TypedDict): val_at_start: bool seed: int + class MasterConfig(TypedDict): policy: PolicyConfig data: DataConfig @@ -102,6 +103,12 @@ def setup( cluster_config = master_config["cluster"] sft_config = master_config["sft"] + # ========================== + # Logger + # ========================== + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + # ========================== # Checkpointing # ========================== @@ -179,9 +186,6 @@ def setup( loss_fn = NLLLoss() print(f" ✓ Model initialized") - logger = Logger(logger_config) - logger.log_hyperparams(master_config) - print("\n" + "=" * 60) print(" " * 18 + "SETUP COMPLETE") print("=" * 60 + "\n") diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index a568dbcda6..138c3802d1 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -123,6 +123,7 @@ def masked_mean(values, mask, dim=None): return values[mask.bool()].mean() return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan) + def set_seed(seed: int): """Sets the seed for python, numpy, and pytorch.""" random.seed(seed) diff --git a/nemo_reinforcer/models/policy/hf_policy.py 
b/nemo_reinforcer/models/policy/hf_policy.py index 2eb5598f6b..b27b575c10 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -57,7 +57,10 @@ def __repr__(self): This makes it easier to identify which worker is producing specific log messages. """ - return f"{self.__class__.__name__}[rank={torch.distributed.get_rank()}]" + if torch.distributed.is_initialized(): + return f"{self.__class__.__name__}[rank={torch.distributed.get_rank()}]" + else: + return f"{self.__class__.__name__}" def __init__( self, @@ -123,8 +126,7 @@ def do_fsdp(model): if init_optimizer: optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) self.optimizer = optimizer_cls( - self.model.parameters(), - **self.cfg["optimizer"]["kwargs"] + self.model.parameters(), **self.cfg["optimizer"]["kwargs"] ) else: self.optimizer = None @@ -830,6 +832,11 @@ def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = Non else: print("WARNING: No scheduler checkpoint provided") + def shutdown(self): + """Shutdown the policy.""" + # + pass + class HfPolicy(PolicyInterface, GenerationInterface): def __init__( diff --git a/nemo_reinforcer/utils/logger.py b/nemo_reinforcer/utils/logger.py index ea98451e96..dca4181681 100644 --- a/nemo_reinforcer/utils/logger.py +++ b/nemo_reinforcer/utils/logger.py @@ -16,9 +16,12 @@ import os import re import glob +import time +import threading +import requests from abc import ABC, abstractmethod import logging -from typing import List, Any, Dict, Optional, TypedDict +from typing import List, Any, Dict, Optional, TypedDict, Union import wandb from rich.console import Console from rich.panel import Panel @@ -28,6 +31,10 @@ from nemo_reinforcer.data.interfaces import LLMMessageLogType from torch.utils.tensorboard import SummaryWriter +import ray +from prometheus_client.parser import text_string_to_metric_families +from prometheus_client.samples import Sample + # Flag to track if rich logging has 
been configured _rich_logging_configured = False @@ -41,12 +48,19 @@ class TensorboardConfig(TypedDict): log_dir: str +class GPUMonitoringConfig(TypedDict): + collection_interval: int | float + flush_interval: int | float + + class LoggerConfig(TypedDict): log_dir: str wandb_enabled: bool tensorboard_enabled: bool wandb: WandbConfig tensorboard: TensorboardConfig + monitor_gpus: bool + gpu_monitoring: GPUMonitoringConfig class LoggerInterface(ABC): @@ -130,6 +144,225 @@ def log_hyperparams(self, params: Dict[str, Any]) -> None: self.run.config.update(params) +class RayGpuMonitorLogger: + """Monitor GPU utilization across a Ray cluster and log metrics to a parent logger.""" + + def __init__( + self, + collection_interval: int | float, + flush_interval: int | float, + parent_logger: Optional["Logger"] = None, + ): + """Initialize the GPU monitor. + + Args: + collection_interval: Interval in seconds to collect GPU metrics + flush_interval: Interval in seconds to flush metrics to parent logger + parent_logger: Logger to receive the collected metrics + """ + self.collection_interval = collection_interval + self.flush_interval = flush_interval + self.parent_logger = parent_logger + self.metrics_buffer = [] # Store metrics with timestamps + self.last_flush_time = time.time() + self.is_running = False + self.collection_thread = None + self.lock = threading.Lock() + + def start(self): + """Start the GPU monitoring thread.""" + if not ray.is_initialized(): + raise ValueError( + "Ray must be initialized with nemo_reinforcer.distributed.virtual_cluster.init_ray() before the GPU logging can begin." 
+ ) + + if self.is_running: + return + + self.start_time = time.time() + self.is_running = True + self.collection_thread = threading.Thread( + target=self._collection_loop, + daemon=True, # Make this a daemon thread so it doesn't block program exit + ) + self.collection_thread.start() + print( + f"GPU monitoring started with collection interval={self.collection_interval}s, flush interval={self.flush_interval}s" + ) + + def stop(self): + """Stop the GPU monitoring thread.""" + self.is_running = False + if self.collection_thread: + self.collection_thread.join(timeout=self.collection_interval * 2) + + # Final flush + self.flush() + print("GPU monitoring stopped") + + def _collection_loop(self): + """Main collection loop that runs in a separate thread.""" + while self.is_running: + try: + collection_time = time.time() + relative_time = collection_time - self.start_time + + # Collect metrics with timing information + metrics = self._collect_metrics() + if metrics: + with self.lock: + self.metrics_buffer.append( + { + "step": int( + relative_time + ), # Store the relative time as step + "metrics": metrics, + } + ) + + # Check if it's time to flush + current_time = time.time() + if current_time - self.last_flush_time >= self.flush_interval: + self.flush() + self.last_flush_time = current_time + + time.sleep(self.collection_interval) + except Exception as e: + print(f"Error in GPU monitoring collection loop: {e}") + time.sleep(self.collection_interval) # Continue despite errors + + def _parse_gpu_metric(self, sample: Sample, node_idx: int) -> Dict[str, Any]: + """Parse a GPU metric sample into a standardized format. 
+ + Args: + sample: Prometheus metric sample + node_idx: Index of the node + + Returns: + Dictionary with metric name and value + """ + # TODO: Consider plumbing {'GpuDeviceName': 'NVIDIA H100 80GB HBM3'} + # Expected labels for GPU metrics + expected_labels = ["GpuIndex"] + for label in expected_labels: + if label not in sample.labels: + # This is probably a CPU node + return {} + + metric_name = sample.name + # Rename known metrics to match wandb naming convention + if metric_name == "ray_node_gpus_utilization": + metric_name = "gpu" + elif metric_name == "ray_node_gram_used": + metric_name = "memory" + else: + # Skip unexpected metrics + return {} + + labels = sample.labels + index = labels["GpuIndex"] + value = sample.value + + metric_name = f"node.{node_idx}.gpu.{index}.{metric_name}" + return {metric_name: value} + + def _collect_metrics(self) -> Dict[str, Any]: + """Collect GPU metrics from all Ray nodes. + + Returns: + Dictionary of collected metrics + """ + if not ray.is_initialized(): + print("Ray is not initialized. Cannot collect GPU metrics.") + return {} + + try: + nodes = ray.nodes() + if not nodes: + print("No Ray nodes found.") + return {} + + # Use a dictionary to keep unique metric endpoints and maintain order + unique_metric_addresses = {} + for node in nodes: + node_ip = node["NodeManagerAddress"] + metrics_port = node.get("MetricsExportPort") + if not metrics_port: + continue + metrics_address = f"{node_ip}:{metrics_port}" + unique_metric_addresses[metrics_address] = True + + # Process each node's metrics + collected_metrics = {} + for node_idx, metric_address in enumerate(unique_metric_addresses): + gpu_metrics = self._fetch_and_parse_metrics(node_idx, metric_address) + collected_metrics.update(gpu_metrics) + + return collected_metrics + + except Exception as e: + print(f"Error collecting GPU metrics: {e}") + return {} + + def _fetch_and_parse_metrics(self, node_idx, metric_address): + """Fetch metrics from a node and parse GPU metrics. 
+ + Args: + node_idx: Index of the node + metric_address: Address of the metrics endpoint + + Returns: + Dictionary of GPU metrics + """ + url = f"http://{metric_address}/metrics" + + try: + response = requests.get(url, timeout=5.0) + if response.status_code != 200: + print(f"Error: Status code {response.status_code}") + return {} + + metrics_text = response.text + gpu_metrics = {} + + # Parse the Prometheus format + for family in text_string_to_metric_families(metrics_text): + # Skip non-GPU metrics + if family.name not in ( + "ray_node_gram_used", + "ray_node_gpus_utilization", + ): + continue + + for sample in family.samples: + metrics = self._parse_gpu_metric(sample, node_idx) + gpu_metrics.update(metrics) + + return gpu_metrics + + except Exception as e: + print(f"Error fetching metrics from {metric_address}: {e}") + return {} + + def flush(self): + """Flush collected metrics to the parent logger.""" + if not self.parent_logger: + return + + with self.lock: + if not self.metrics_buffer: + return + + # Log each set of metrics with its original step + for entry in self.metrics_buffer: + step = entry["step"] + metrics = entry["metrics"] + self.parent_logger.log_metrics(metrics, step, prefix="ray") + + # Clear buffer after logging + self.metrics_buffer = [] + + class Logger(LoggerInterface): """Main logger class that delegates to multiple backend loggers.""" @@ -142,6 +375,9 @@ def __init__(self, cfg: LoggerConfig): - tensorboard_enabled - wandb - tensorboard + - monitor_gpus + - gpu_monitoring.collection_interval + - gpu_monitoring.flush_interval """ self.loggers = [] @@ -162,6 +398,16 @@ def __init__(self, cfg: LoggerConfig): ) self.loggers.append(tensorboard_logger) + # Initialize GPU monitoring if requested + self.gpu_monitor = None + if cfg["monitor_gpus"]: + self.gpu_monitor = RayGpuMonitorLogger( + collection_interval=cfg["gpu_monitoring"]["collection_interval"], + flush_interval=cfg["gpu_monitoring"]["flush_interval"], + parent_logger=self, + ) + self.gpu_monitor.start() + if 
not self.loggers: print("No loggers initialized") @@ -187,6 +433,11 @@ def log_hyperparams(self, params: Dict[str, Any]) -> None: for logger in self.loggers: logger.log_hyperparams(params) + def __del__(self): + """Clean up resources when the logger is destroyed.""" + if self.gpu_monitor: + self.gpu_monitor.stop() + def flatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: """Flatten a nested dictionary. diff --git a/tests/unit/environments/test_math_environment.py b/tests/unit/environments/test_math_environment.py index 7f6d8784e5..9b2eb4e21c 100644 --- a/tests/unit/environments/test_math_environment.py +++ b/tests/unit/environments/test_math_environment.py @@ -15,13 +15,17 @@ import ray from nemo_reinforcer.environments.math_environment import MathEnvironment import time +import os @pytest.fixture(scope="module") def math_env(): """Create a MathEnvironment actor for testing.""" env = MathEnvironment.options( - runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} + runtime_env={ + "py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE, + "env_vars": dict(os.environ), + } ).remote({"num_workers": 2}) yield env # Clean up the actor and wait for it to be killed diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 8e013cc1ea..9f6724923a 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -23,15 +23,6 @@ from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig -# Skip all tests if no CUDA or vLLM -pytestmark = [ - pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.device_count() < 1, - reason="CUDA not available or insufficient GPUs", - ) -] - - # Define basic vLLM test config basic_vllm_test_config: VllmConfig = { "model_name": "meta-llama/Llama-3.2-1B", # Small model for testing @@ -65,7 +56,7 @@ def cluster(): bundle_ct_per_node_list=[2], # 1 node with 2 
GPU bundle use_gpus=True, max_colocated_worker_groups=2, - num_gpus_per_node=torch.cuda.device_count(), # Use available GPUs + num_gpus_per_node=2, # Use available GPUs name="vllm-test-cluster", ) yield virtual_cluster @@ -192,6 +183,15 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): "max_new_tokens": 16, "do_sample": False, "precision": "float32", + "optimizer": { + "name": "torch.optim.AdamW", + "kwargs": { + "lr": 5e-6, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-8, + }, + }, } vllm_policy = None @@ -349,10 +349,6 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): def test_vllm_policy_tensor_parallel(cluster, tokenizer): """Test vLLM policy with tensor parallelism > 1.""" - # Skip if less than 2 GPUs are available - if torch.cuda.device_count() < 2: - pytest.skip("Tensor parallelism test requires at least 2 GPUs") - # Configure with tensor_parallel_size=2 tp_config = basic_vllm_test_config.copy() tp_config["tensor_parallel_size"] = 2 @@ -411,12 +407,6 @@ def test_vllm_policy_tensor_parallel(cluster, tokenizer): @pytest.mark.parametrize("tensor_parallel_size", [1, 2]) def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): """Test that weights can be updated from HF to vLLM policy.""" - # Skip if requesting tensor_parallel_size=2 but less than 2 GPUs are available - if tensor_parallel_size > 1 and torch.cuda.device_count() < 2: - pytest.skip( - f"Tensor parallelism test with tp={tensor_parallel_size} requires at least {tensor_parallel_size} GPUs" - ) - # Create HF policy from nemo_reinforcer.models.policy.hf_policy import HfPolicy @@ -439,6 +429,15 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): "max_new_tokens": 16, "do_sample": False, "precision": "float32", + "optimizer": { + "name": "torch.optim.AdamW", + "kwargs": { + "lr": 5e-6, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-8, + }, + }, } hf_policy = HfPolicy(cluster, hf_config) diff --git 
a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 9b825e5302..ae80b3fd1e 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -40,6 +40,15 @@ "top_p": 1.0, "top_k": None, }, + "optimizer": { + "name": "torch.optim.AdamW", + "kwargs": { + "lr": 5e-6, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-8, + }, + }, "scheduler": { "name": "torch.optim.lr_scheduler.CosineAnnealingLR", "kwargs": { diff --git a/tests/unit/utils/test_logger.py b/tests/unit/utils/test_logger.py index d54c8748f5..a8d982266b 100644 --- a/tests/unit/utils/test_logger.py +++ b/tests/unit/utils/test_logger.py @@ -22,6 +22,7 @@ Logger, TensorboardLogger, WandbLogger, + RayGpuMonitorLogger, flatten_dict, ) @@ -199,6 +200,398 @@ def test_log_hyperparams(self, mock_wandb): mock_run.config.update.assert_called_once_with(params) +class TestRayGpuMonitorLogger: + """Test the RayGpuMonitorLogger class.""" + + @pytest.fixture + def temp_dir(self): + """Create a temporary directory for logs.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + @pytest.fixture + def mock_parent_logger(self): + """Create a mock parent logger.""" + + class MockLogger: + def __init__(self): + self.logged_metrics = [] + self.logged_steps = [] + self.logged_prefixes = [] + + def log_metrics(self, metrics, step, prefix=""): + self.logged_metrics.append(metrics) + self.logged_steps.append(step) + self.logged_prefixes.append(prefix) + + return MockLogger() + + @patch("nemo_reinforcer.utils.logger.ray") + def test_init(self, mock_ray): + """Test initialization of RayGpuMonitorLogger.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Initialize the monitor with standard settings + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Verify initialization parameters + assert 
monitor.collection_interval == 10.0 + assert monitor.flush_interval == 60.0 + assert monitor.parent_logger is None + assert monitor.metrics_buffer == [] + assert monitor.is_running is False + assert monitor.collection_thread is None + + @patch("nemo_reinforcer.utils.logger.ray") + @patch("nemo_reinforcer.utils.logger.threading.Thread") + def test_start(self, mock_thread, mock_ray): + """Test start method of RayGpuMonitorLogger.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Initialize the monitor + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Start the monitor + monitor.start() + + # Verify thread was created and started + mock_thread.assert_called_once() + mock_thread.return_value.start.assert_called_once() + + # Verify monitor state + assert monitor.is_running is True + assert monitor.collection_thread is mock_thread.return_value + + @patch("nemo_reinforcer.utils.logger.ray") + def test_start_ray_not_initialized(self, mock_ray): + """Test start method when Ray is not initialized.""" + # Mock ray.is_initialized to return False + mock_ray.is_initialized.return_value = False + + # Initialize the monitor + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Starting should raise a ValueError + with pytest.raises(ValueError): + monitor.start() + + @patch("nemo_reinforcer.utils.logger.ray") + @patch("nemo_reinforcer.utils.logger.threading.Thread") + def test_stop(self, mock_thread, mock_ray): + """Test stop method of RayGpuMonitorLogger.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Initialize the monitor + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Start the monitor + monitor.start() + + # Create a spy for the flush method + with patch.object(monitor, "flush") as mock_flush: + # Stop 
the monitor + monitor.stop() + + # Verify flush was called + mock_flush.assert_called_once() + + # Verify monitor state + assert monitor.is_running is False + + @patch("nemo_reinforcer.utils.logger.ray") + def test_parse_gpu_metric(self, mock_ray): + """Test _parse_gpu_metric method.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Initialize the monitor + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Create a sample with GPU utilization metric + from prometheus_client.samples import Sample + + utilization_sample = Sample( + name="ray_node_gpus_utilization", + labels={"GpuIndex": "0", "GpuDeviceName": "NVIDIA Test GPU"}, + value=75.5, + timestamp=None, + exemplar=None, + ) + + # Parse the sample + result = monitor._parse_gpu_metric(utilization_sample, node_idx=1) + + # Verify the result + assert result == {"node.1.gpu.0.gpu": 75.5} + + # Create a sample with GPU memory metric + memory_sample = Sample( + name="ray_node_gram_used", + labels={"GpuIndex": "0", "GpuDeviceName": "NVIDIA Test GPU"}, + value=4096.0, + timestamp=None, + exemplar=None, + ) + + # Parse the sample + result = monitor._parse_gpu_metric(memory_sample, node_idx=1) + + # Verify the result + assert result == {"node.1.gpu.0.memory": 4096.0} + + # Test with an unexpected metric name + other_sample = Sample( + name="ray_node_cpu_utilization", + labels={"GpuIndex": "0"}, + value=50.0, + timestamp=None, + exemplar=None, + ) + + # Parse the sample + result = monitor._parse_gpu_metric(other_sample, node_idx=1) + + # Verify the result is empty + assert result == {} + + # Test with missing GpuIndex label + invalid_sample = Sample( + name="ray_node_gpus_utilization", + labels={"OtherLabel": "value"}, + value=75.5, + timestamp=None, + exemplar=None, + ) + + # Parse the sample + result = monitor._parse_gpu_metric(invalid_sample, node_idx=1) + + # Verify the result is empty + assert result == {} + + 
@patch("nemo_reinforcer.utils.logger.ray") + @patch("nemo_reinforcer.utils.logger.requests.get") + def test_fetch_and_parse_metrics(self, mock_get, mock_ray): + """Test _fetch_and_parse_metrics method.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Set up mock response with Prometheus metrics + mock_response = mock_get.return_value + mock_response.status_code = 200 + # Simplified Prometheus format text with GPU metrics + mock_response.text = """ +# HELP ray_node_gpus_utilization GPU utilization +# TYPE ray_node_gpus_utilization gauge +ray_node_gpus_utilization{GpuIndex="0",GpuDeviceName="NVIDIA Test GPU"} 75.5 +# HELP ray_node_gram_used GPU memory used +# TYPE ray_node_gram_used gauge +ray_node_gram_used{GpuIndex="0",GpuDeviceName="NVIDIA Test GPU"} 4096.0 + """ + + # Initialize the monitor + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Mock the _parse_gpu_metric method to return expected values + with patch.object(monitor, "_parse_gpu_metric") as mock_parse: + mock_parse.side_effect = [ + {"node.2.gpu.0.gpu": 75.5}, + {"node.2.gpu.0.memory": 4096.0}, + ] + + # Call the method + result = monitor._fetch_and_parse_metrics( + node_idx=2, metric_address="test_ip:test_port" + ) + + # Verify request was made correctly + mock_get.assert_called_once_with( + "http://test_ip:test_port/metrics", timeout=5.0 + ) + + # Verify parsing was done for both metrics + assert mock_parse.call_count == 2 + + # Verify the result combines both metrics + assert result == {"node.2.gpu.0.gpu": 75.5, "node.2.gpu.0.memory": 4096.0} + + @patch("nemo_reinforcer.utils.logger.ray") + def test_collect_metrics(self, mock_ray): + """Test _collect_metrics method.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Mock ray.nodes to return test nodes + mock_ray.nodes.return_value = [ + {"NodeManagerAddress": "10.0.0.1", "MetricsExportPort": 
8080}, + {"NodeManagerAddress": "10.0.0.2", "MetricsExportPort": 8080}, + ] + + # Initialize the monitor + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Mock the _fetch_and_parse_metrics method + with patch.object(monitor, "_fetch_and_parse_metrics") as mock_fetch: + mock_fetch.side_effect = [ + {"node.0.gpu.0.gpu": 75.5, "node.0.gpu.0.memory": 4096.0}, + {"node.1.gpu.0.gpu": 50.0, "node.1.gpu.0.memory": 2048.0}, + ] + + # Call the method + result = monitor._collect_metrics() + + # Verify _fetch_and_parse_metrics was called for each node + assert mock_fetch.call_count == 2 + mock_fetch.assert_any_call(0, "10.0.0.1:8080") + mock_fetch.assert_any_call(1, "10.0.0.2:8080") + + # Verify the result combines metrics from all nodes + assert result == { + "node.0.gpu.0.gpu": 75.5, + "node.0.gpu.0.memory": 4096.0, + "node.1.gpu.0.gpu": 50.0, + "node.1.gpu.0.memory": 2048.0, + } + + @patch("nemo_reinforcer.utils.logger.ray") + def test_flush_empty_buffer(self, mock_ray, mock_parent_logger): + """Test flush method with empty buffer.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Initialize the monitor with parent logger + monitor = RayGpuMonitorLogger( + collection_interval=10.0, + flush_interval=60.0, + parent_logger=mock_parent_logger, + ) + + # Call flush with empty buffer + monitor.flush() + + # Verify parent logger's log_metrics was not called + assert len(mock_parent_logger.logged_metrics) == 0 + + @patch("nemo_reinforcer.utils.logger.ray") + def test_flush(self, mock_ray, mock_parent_logger): + """Test flush method with metrics in buffer.""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Initialize the monitor with parent logger + monitor = RayGpuMonitorLogger( + collection_interval=10.0, + flush_interval=60.0, + parent_logger=mock_parent_logger, + ) + + # Add test metrics to buffer + monitor.metrics_buffer = [ 
+ { + "step": 10, + "metrics": {"node.0.gpu.0.gpu": 75.5, "node.0.gpu.0.memory": 4096.0}, + }, + { + "step": 20, + "metrics": {"node.0.gpu.0.gpu": 80.0, "node.0.gpu.0.memory": 5120.0}, + }, + ] + + # Call flush + monitor.flush() + + # Verify parent logger's log_metrics was called for each entry + assert len(mock_parent_logger.logged_metrics) == 2 + assert mock_parent_logger.logged_metrics[0] == { + "node.0.gpu.0.gpu": 75.5, + "node.0.gpu.0.memory": 4096.0, + } + assert mock_parent_logger.logged_steps[0] == 10 + assert mock_parent_logger.logged_prefixes[0] == "ray" + + assert mock_parent_logger.logged_metrics[1] == { + "node.0.gpu.0.gpu": 80.0, + "node.0.gpu.0.memory": 5120.0, + } + assert mock_parent_logger.logged_steps[1] == 20 + assert mock_parent_logger.logged_prefixes[1] == "ray" + + # Verify buffer was cleared + assert monitor.metrics_buffer == [] + + @patch("nemo_reinforcer.utils.logger.ray") + @patch("nemo_reinforcer.utils.logger.time") + def test_collection_loop(self, mock_time, mock_ray): + """Test _collection_loop method (one iteration).""" + # Mock ray.is_initialized to return True + mock_ray.is_initialized.return_value = True + + # Set up time mocks for a single iteration + mock_time.time.side_effect = [ + 100.0, + 110.0, + 170.0, + 180.0, + ] # start_time, collection_time, flush_check_time, sleep_until + + # Initialize the monitor + monitor = RayGpuMonitorLogger( + collection_interval=10.0, flush_interval=60.0, parent_logger=None + ) + + # Set start time and running flag + monitor.start_time = 100.0 + monitor.is_running = True + + # Create a flag to only run one iteration + monitor.iteration_done = False + + def side_effect(): + if not monitor.iteration_done: + monitor.iteration_done = True + return {"node.0.gpu.0.gpu": 75.5} + else: + monitor.is_running = False + return {} + + # Mock _collect_metrics to return test metrics + with patch.object(monitor, "_collect_metrics", side_effect=side_effect): + # Mock flush method + with patch.object(monitor, 
"flush") as mock_flush: + # Run the collection loop (will stop after one iteration) + monitor._collection_loop() + + # Verify monitor.metrics_buffer has the collected metrics + assert len(monitor.metrics_buffer) == 1 + assert ( + monitor.metrics_buffer[0]["step"] == 10 + ) # relative time (110 - 100) + assert monitor.metrics_buffer[0]["metrics"] == { + "node.0.gpu.0.gpu": 75.5 + } + + # Verify flush was called (flush_interval elapsed) + mock_flush.assert_called_once() + + class TestLogger: """Test the main Logger class.""" @@ -216,6 +609,7 @@ def test_init_no_loggers(self, mock_tb_logger, mock_wandb_logger, temp_dir): cfg = { "wandb_enabled": False, "tensorboard_enabled": False, + "monitor_gpus": False, "log_dir": temp_dir, } logger = Logger(cfg) @@ -231,6 +625,7 @@ def test_init_wandb_only(self, mock_tb_logger, mock_wandb_logger, temp_dir): cfg = { "wandb_enabled": True, "tensorboard_enabled": False, + "monitor_gpus": False, "wandb": {"project": "test-project"}, "log_dir": temp_dir, } @@ -249,6 +644,7 @@ def test_init_tensorboard_only(self, mock_tb_logger, mock_wandb_logger, temp_dir cfg = { "wandb_enabled": False, "tensorboard_enabled": True, + "monitor_gpus": False, "tensorboard": {"log_dir": "test_logs"}, "log_dir": temp_dir, } @@ -267,6 +663,7 @@ def test_init_both_loggers(self, mock_tb_logger, mock_wandb_logger, temp_dir): cfg = { "wandb_enabled": True, "tensorboard_enabled": True, + "monitor_gpus": False, "wandb": {"project": "test-project"}, "tensorboard": {"log_dir": "test_logs"}, "log_dir": temp_dir, @@ -289,6 +686,7 @@ def test_log_metrics(self, mock_tb_logger, mock_wandb_logger, temp_dir): cfg = { "wandb_enabled": True, "tensorboard_enabled": True, + "monitor_gpus": False, "wandb": {"project": "test-project"}, "tensorboard": {"log_dir": "test_logs"}, "log_dir": temp_dir, @@ -314,6 +712,7 @@ def test_log_hyperparams(self, mock_tb_logger, mock_wandb_logger, temp_dir): cfg = { "wandb_enabled": True, "tensorboard_enabled": True, + "monitor_gpus": False, 
"wandb": {"project": "test-project"}, "tensorboard": {"log_dir": "test_logs"}, "log_dir": temp_dir, @@ -330,3 +729,38 @@ def test_log_hyperparams(self, mock_tb_logger, mock_wandb_logger, temp_dir): # Check that log_hyperparams was called on both loggers mock_wandb_instance.log_hyperparams.assert_called_once_with(params) mock_tb_instance.log_hyperparams.assert_called_once_with(params) + + @patch("nemo_reinforcer.utils.logger.WandbLogger") + @patch("nemo_reinforcer.utils.logger.TensorboardLogger") + @patch("nemo_reinforcer.utils.logger.RayGpuMonitorLogger") + def test_init_with_gpu_monitoring( + self, mock_gpu_monitor, mock_tb_logger, mock_wandb_logger, temp_dir + ): + """Test initialization with GPU monitoring enabled.""" + cfg = { + "wandb_enabled": True, + "tensorboard_enabled": True, + "monitor_gpus": True, + "gpu_monitoring": { + "collection_interval": 15.0, + "flush_interval": 45.0, + }, + "wandb": {"project": "test-project"}, + "tensorboard": {"log_dir": "test_logs"}, + "log_dir": temp_dir, + } + logger = Logger(cfg) + + # Check that regular loggers were initialized + assert len(logger.loggers) == 2 + mock_wandb_logger.assert_called_once() + mock_tb_logger.assert_called_once() + + # Check that GPU monitor was initialized with correct parameters + mock_gpu_monitor.assert_called_once_with( + collection_interval=15.0, flush_interval=45.0, parent_logger=logger + ) + + # Check that GPU monitor was started + mock_gpu_instance = mock_gpu_monitor.return_value + mock_gpu_instance.start.assert_called_once()