diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 0d49ef9ca..85872681f 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
         provisioner = await init_provisioner()
 
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
-    mlogger = await get_or_create_metric_logger()
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
     await mlogger.init_backends.call_one(metric_logging_cfg)
 
     # ---- Setup services ---- #
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 5549a8cce..cb36b2568 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -310,11 +310,12 @@ def bootstrap(env: dict[str, str]):
 
         self._proc_host_map[procs] = host_mesh
 
-        # Spawn local fetcher actor on each process and register with global logger
+        # Spawn a LocalFetcherActor for this ProcMesh and register it with the GlobalLoggingActor.
+        # When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
        if not FORGE_DISABLE_METRICS.get_value():
             from forge.observability.metric_actors import get_or_create_metric_logger
 
-            _ = await get_or_create_metric_logger(procs)
+            _ = await get_or_create_metric_logger(procs, process_name=mesh_name)
 
         return procs
 
     async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
@@ -333,14 +334,14 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
             )
             return
         async with self._lock:
-            # Deregister local logger from global logger
-            if hasattr(proc_mesh, "_local_fetcher"):
+            # Deregister the LocalFetcherActor from the GlobalLoggingActor
+            if hasattr(proc_mesh, "_local_fetcher") and hasattr(proc_mesh, "_uid"):
                 from forge.observability.metric_actors import (
                     get_or_create_metric_logger,
                 )
 
                 global_logger = await get_or_create_metric_logger(proc_mesh)
-                await global_logger.deregister_fetcher.call_one(proc_mesh)
+                await global_logger.deregister_fetcher.call_one(proc_mesh._uid)
 
             if hasattr(proc_mesh, "_gpu_ids"):
                 gpu_manager = self._host_gpu_map[proc_mesh._host._host_id]
diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py
index b970e57fa..8efd3dace 100644
--- a/src/forge/observability/__init__.py
+++ b/src/forge/observability/__init__.py
@@ -12,8 +12,6 @@ from .metrics import (
     BackendRole,
     ConsoleBackend,
-    get_actor_name_with_rank,
-    get_logger_backend_class,
     LoggerBackend,
     MaxAccumulator,
     MeanAccumulator,
@@ -29,12 +27,12 @@ from .metrics import (
     WandbBackend,
 )
 from .perf_tracker import trace, Tracer
+from .utils import get_proc_name_with_rank
 
 __all__ = [
     # Main API functions
     "record_metric",
     "reduce_metrics_states",
-    "get_actor_name_with_rank",
     "get_logger_backend_class",
     "get_or_create_metric_logger",
     # Performance tracking
@@ -45,6 +43,8 @@ from .metrics import (
     "BackendRole",
     # Enums
     "Reduce",
+    # Utility functions
+    "get_proc_name_with_rank",
     # Actor classes
     "GlobalLoggingActor",
     "LocalFetcherActor",
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py
index e9105afc6..45e08c418 100644
--- a/src/forge/observability/metric_actors.py
+++ b/src/forge/observability/metric_actors.py
@@ -6,9 +6,17 @@
 
 import asyncio
 import logging
+import uuid
 from typing import Any, Union
 
-from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
+from monarch.actor import (
+    Actor,
+    context,
+    endpoint,
+    get_or_spawn_controller,
+    ProcMesh,
+    this_proc,
+)
 
 from forge.env import FORGE_DISABLE_METRICS
 from forge.observability.metrics import (
@@ -27,36 +35,35 @@ async def get_or_create_metric_logger(
     proc_mesh: ProcMesh | None = None,
+    process_name: str | None = None,
 ) -> "GlobalLoggingActor":
-    """Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
-    if not already initialized, registers it with the GlobalLoggingActor and returns the
-    GlobalLoggingActor instance.
+    """Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized),
+    registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor.
 
-    There are primarily two ways to use this function:
-    1. In the main process, call `get_or_create_metric_logger()` to get the global logger.
-    2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the
-       local fetcher with the global logger.
+    Usage:
+    1. Main process: call `get_or_create_metric_logger()` to get the global logger.
+    2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name)` to register the
+       (proc_mesh, local_fetcher) pair with the global logger, so it knows to broadcast to all ranks.
 
     Args:
-        proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
-            uses `monarch.actor.this_proc()`.
+        proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`.
+        process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None.
 
     Returns:
         GlobalLoggingActor: The global logging controller.
 
     Raises:
-        ValueError: If the logging state is inconsistent, i.e. the fetcher is already
-            registered, but only in the process or the global logger.
+        ValueError: If the logging state is inconsistent, i.e. the fetcher is registered
+            in only one of the process or the global logger.
 
     Example:
         from forge.observability.metric_actors import get_or_create_metric_logger
         from forge.observability.metrics import record_metric
 
         # Main process setup
-        mlogger = await get_or_create_metric_logger()
+        mlogger = await get_or_create_metric_logger(process_name="Controller")
 
         # Initialize logging backends
-        await mlogger.init_backends({
+        await mlogger.init_backends.call_one({
             "console": {"reduce_across_ranks": True},
             "wandb": {"project": "my_project", "reduce_across_ranks": False}
         })
@@ -66,12 +73,12 @@ async def get_or_create_metric_logger(
 
         # Training loop
         for step in range(max_steps):
-            record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
+            record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
             # ... training code with record_metric() calls ...
-        await mlogger.flush(step)  # Log metrics for this step
+        await mlogger.flush.call_one(step)  # Log metrics for this step
 
         # Shutdown
-        await mlogger.shutdown()
+        await mlogger.shutdown.call_one()
     """
     # Get or create the singleton global logger
     global _global_logger
@@ -85,9 +92,15 @@ async def get_or_create_metric_logger(
 
     # Determine process context
     proc = proc_mesh if proc_mesh is not None else this_proc()
 
+    # Auto-detect process_name from proc mesh if not provided
+    if process_name is None:
+        ctx = context()
+        process_name = ctx.actor_instance.actor_id.actor_name
+
     # Check current state for consistency
     proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
-    global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
+    proc_id = proc._uid if proc_has_local_fetcher else None
+    global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc_id)
 
     # Consistency check: both should be in sync
     if proc_has_local_fetcher != global_logger_has_local_fetcher:
@@ -102,24 +115,32 @@ async def get_or_create_metric_logger(
     # Setup local_fetcher_actor if needed (unless disabled by environment flag)
     if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
         local_fetcher_actor = proc.spawn(
-            "local_fetcher_actor", LocalFetcherActor, global_logger
+            "local_fetcher_actor", LocalFetcherActor, global_logger, process_name
         )
-        await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
+        # Generate a unique ID to map procmesh to fetcher
+        proc._uid = str(uuid.uuid4())
         proc._local_fetcher = local_fetcher_actor  # pyre-ignore
+        await global_logger.register_fetcher.call_one(local_fetcher_actor, proc._uid)
+
     return global_logger
 
 
 class LocalFetcherActor(Actor):
-    """Thin per-process actor used to trigger MetricCollector singleton
-    operations without direct access. It is what GlobalLoggingActor
-    uses to broadcast inits/flushes across ranks.
+    """Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh
+    and accesses each rank's local MetricCollector.
 
-    GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
+    Flow:
+    GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
     """
 
-    def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
+    def __init__(
+        self,
+        global_logger: Union["GlobalLoggingActor", None] = None,
+        process_name: str | None = None,
+    ) -> None:
         self.global_logger = global_logger
+        self.process_name = process_name
         _is_initialized = False
 
     @endpoint
@@ -146,10 +167,22 @@ async def init_backends(
         self,
         metadata_per_primary_backend: dict[str, dict[str, Any]],
         config: dict[str, Any],
+        global_step: int = 0,
     ) -> None:
-        """Init local (per-rank) logger backends and MetricCollector."""
+        """Init per-rank logger backends and MetricCollector.
+
+        Args:
+            metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
+            config (dict[str, Any]): Backend configurations with logging modes and settings.
+            global_step (int): Initial step for metrics.
+ """ collector = MetricCollector() - await collector.init_backends(metadata_per_primary_backend, config) + await collector.init_backends( + metadata_per_primary_backend, + config, + global_step, + process_name=self.process_name, + ) @endpoint async def shutdown(self) -> None: @@ -158,22 +191,17 @@ async def shutdown(self) -> None: class GlobalLoggingActor(Actor): - """Coordinates metric logging across all ranks for every training step. + """Coordinates metric logging across all ProcMeshes and their ranks. Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), - for per-rank and/or global reduction logging modes. + with per-rank and/or global reduction logging modes. If a backend config has flag `reduce_across_ranks=False`, an instance of the backend is initialized per-rank, otherwise it is done once globally. - This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor - is automatically spawned per-rank in `forge.controller.provisioner.py` and registered - with this actor. The LocalFetcherActor is responsible for instantiating - the per-rank MetricCollector. - In summary, the flow is: - - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector - - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + Flow: + GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger """ def __init__(self): @@ -209,7 +237,7 @@ async def init_backends(self, config: dict[str, Any]) -> None: for backend_name, backend_config in config.items(): backend = get_logger_backend_class(backend_name)(backend_config) - await backend.init(role=BackendRole.GLOBAL) + await backend.init(role=BackendRole.GLOBAL, name="global_reduce") # Extract metadata from primary logger to be shared with secondary loggers # and store it @@ -237,30 +265,31 @@ async def init_backends(self, config: dict[str, Any]) -> None: await asyncio.gather(*tasks, return_exceptions=True) @endpoint - async def register_fetcher( - self, fetcher: LocalFetcherActor, name: str | ProcMesh - ) -> None: - """Registers a fetcher with the global actor. Each key represents a process mesh. - If there are 2 processes, each with 2 replicas with N gpus, we would - have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" - self.fetchers[name] = fetcher # pyre-ignore + async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> None: + """Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh. + + Args: + fetcher: The LocalFetcherActor instance for a ProcMesh + proc_id: Unique identifier for the ProcMesh + """ + self.fetchers[proc_id] = fetcher # Self-init for respawned actors if self.config: - logger.debug(f"Initializing new LocalFetcherActor {name}") + logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}") await fetcher.init_backends.call( self.metadata_per_primary_backend, self.config ) @endpoint - async def deregister_fetcher(self, name: str | ProcMesh) -> None: - if name not in self.fetchers: + async def deregister_fetcher(self, proc_id: str) -> None: + if proc_id not in self.fetchers: logger.warning( - f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." + f"Fetcher {proc_id} not registered in GlobalLoggingActor. Cannot deregister." 
f"Available fetchers: {self.fetchers.keys()}" ) return - del self.fetchers[name] + del self.fetchers[proc_id] @endpoint async def flush(self, global_step: int) -> None: @@ -333,9 +362,9 @@ async def flush(self, global_step: int) -> None: await logger_backend.log(reduced_metrics, global_step) @endpoint - def has_fetcher(self, name: str | ProcMesh) -> bool: - """Check if a fetcher is registered with the given name.""" - return name in self.fetchers + def has_fetcher(self, proc_id: str) -> bool: + """Check if a fetcher is registered with the given proc_id.""" + return proc_id in self.fetchers @endpoint def get_fetcher_count(self) -> int: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 3ce849ad2..980bb89fc 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -15,6 +15,8 @@ import pytz from monarch.actor import context, current_rank +from forge.observability.utils import get_proc_name_with_rank + from forge.util.logging import log_once logger = logging.getLogger(__name__) @@ -438,11 +440,14 @@ def __init__(self) -> None: self.rank = current_rank().rank self.logger_backends: list[LoggerBackend] = [] self._is_initialized = False + self.proc_name_with_rank: str | None = None async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]] | None, config: dict[str, Any], + global_step: int = 0, + process_name: str | None = None, ) -> None: """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated @@ -452,11 +457,16 @@ async def init_backends( metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary logger backend, e.g., {"wandb": {"run_id": "abc123"}}. config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + global_step (int, default 0): Initial step for metrics. + process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return + self.global_step = global_step + self.proc_name_with_rank = get_proc_name_with_rank(process_name) + # instantiate local backends if any for backend_name, backend_config in config.items(): if backend_config.get("reduce_across_ranks", True): @@ -470,7 +480,9 @@ async def init_backends( # instantiate local backend logger_backend = get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata + role=BackendRole.LOCAL, + primary_logger_metadata=primary_metadata, + name=self.proc_name_with_rank, ) self.logger_backends.append(logger_backend) @@ -495,7 +507,8 @@ def push(self, metric: Metric) -> None: logger, level=logging.WARNING, msg=( - "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." + f"Skipping metric collection for {get_proc_name_with_rank()}." + " Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." " To disable this warning, please call in your main file:\n" "`mlogger = await get_or_create_metric_logger()`\n" @@ -534,7 +547,8 @@ async def flush( log_once( logger, level=logging.WARNING, - msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." + msg=f"Cannot flush collected metrics for {get_proc_name_with_rank()}. 
" + " MetricCollector.flush() called before init_backends()." "\nPlease call in your main file:\n" "`mlogger = await get_or_create_metric_logger()`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" @@ -544,7 +558,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" + f"Collector {self.proc_name_with_rank}: No metrics to flush for global_step {global_step}" ) return {} @@ -569,7 +583,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" + f"Collector for {self.proc_name_with_rank} not initialized. Skipping shutdown" ) return @@ -593,6 +607,7 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -602,6 +617,7 @@ async def init( Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. + name (str | None): Name for logging. Raises: ValueError if missing metadata for shared local init. """ @@ -618,6 +634,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: """ pass + @abstractmethod async def finish(self) -> None: pass @@ -636,12 +653,10 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + name: str | None = None, ) -> None: - self.prefix = ( - get_actor_name_with_rank() - if self.logger_backend_config.get("reduce_across_ranks", True) - else "Controller" - ) + + self.name = name async def log(self, metrics: list[Metric], global_step: int) -> None: metrics_str = "\n".join( @@ -649,7 +664,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: for metric in sorted(metrics, key=lambda m: m.key) ) logger.info( - f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" + f"=== [{self.name}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" ) async def finish(self) -> None: @@ -689,16 +704,13 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + name: str | None = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - self.name = ( - get_actor_name_with_rank() - if role == BackendRole.LOCAL - else "global_controller" - ) + self.name = name # Default global mode: only inits on controller if self.reduce_across_ranks: diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py new file mode 100644 index 000000000..811bbfe41 --- /dev/null +++ b/src/forge/observability/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +from monarch.actor import context, current_rank + +logger = logging.getLogger(__name__) + + +def get_proc_name_with_rank(proc_name: Optional[str] = None) -> str: + """ + Returns a unique identifier for the current rank from Monarch actor context. 
+
+    Multiple ranks from the same ProcMesh share the same ProcMesh hash suffix,
+    but have different rank numbers.
+
+    Format: "{ProcessName}_{ProcMeshHash}_r{rank}", where:
+    - ProcessName: The provided proc_name (e.g., "TrainActor"), or extracted from actor_name if None.
+    - ProcMeshHash: Hash suffix identifying the ProcMesh (e.g., "1abc2def").
+    - rank: Local rank within the ProcMesh (0, 1, 2, ...).
+
+    Note: If called from the main process (e.g. main.py), returns "client_r0".
+
+    Args:
+        proc_name: Optional override for the process name. If None, uses actor_id.actor_name.
+
+    Returns:
+        str: Unique identifier per rank (e.g., "TrainActor_1abc2def_r0" or "client_r0").
+    """
+    ctx = context()
+    actor_id = ctx.actor_instance.actor_id
+    actor_name = actor_id.actor_name
+    rank = current_rank().rank
+
+    # If proc_name is provided, extract the ProcMesh hash from actor_name and combine them
+    if proc_name is not None:
+        parts = actor_name.split("_")
+        if len(parts) > 1:
+            replica_hash = parts[-1]  # e.g., "MyActor_1abc2def" -> "1abc2def"
+            return f"{proc_name}_{replica_hash}_r{rank}"
+        else:
+            # In a direct process (e.g. called from main), actor_name == "client", so len(parts) == 1
+            return f"{proc_name}_r{rank}"
+
+    # No proc_name override: use the full actor_name with rank
+    return f"{actor_name}_r{rank}"
diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py
index 57ccd97b5..eae50c2db 100644
--- a/tests/sandbox/toy_rl/toy_metrics/main.py
+++ b/tests/sandbox/toy_rl/toy_metrics/main.py
@@ -95,12 +95,16 @@ async def main():
     }
 
     service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
-    mlogger = await get_or_create_metric_logger()
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
     await mlogger.init_backends.call_one(config)
 
     # Spawn services first (triggers registrations via provisioner hook)
-    trainer = await TrainActor.options(**service_config).as_service()
-    generator = await GeneratorActor.options(**service_config).as_service()
+    trainer = await TrainActor.options(
+        **service_config, mesh_name="TrainActor"
+    ).as_service()
+    generator = await GeneratorActor.options(
+        **service_config, mesh_name="GeneratorActor"
+    ).as_service()
 
     for i in range(3):
         print(f"\n=== Global Step {i} ===")
diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py
index e9b001aa5..425352340 100644
--- a/tests/sandbox/vllm/main.py
+++ b/tests/sandbox/vllm/main.py
@@ -33,7 +33,7 @@ async def run(cfg: DictConfig):
             ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
         )
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
-    mlogger = await get_or_create_metric_logger()
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
     await mlogger.init_backends.call_one(metric_logging_cfg)
 
     if (prompt := cfg.get("prompt")) is None:
diff --git a/tests/unit_tests/observability/__init__.py b/tests/unit_tests/observability/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/tests/unit_tests/observability/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
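
Reviewer note (not part of the patch): a minimal sketch of the naming convention that `get_proc_name_with_rank` above implements, condensed from its docstring. The hash suffix "1abc2def" is an illustrative placeholder, not a real value.

    from forge.observability.utils import get_proc_name_with_rank

    # On a rank spawned with mesh_name="TrainActor", Monarch reports an
    # actor_name like "TrainActor_1abc2def", so:
    get_proc_name_with_rank()           # -> "TrainActor_1abc2def_r0"
    get_proc_name_with_rank("Trainer")  # -> "Trainer_1abc2def_r0" (override keeps the mesh hash)

    # In the main process there is no ProcMesh and actor_name == "client":
    get_proc_name_with_rank()              # -> "client_r0"
    get_proc_name_with_rank("Controller")  # -> "Controller_r0"
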
diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py
index e8900392c..d256c5d7c 100644
--- a/tests/unit_tests/observability/conftest.py
+++ b/tests/unit_tests/observability/conftest.py
@@ -9,32 +9,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
-from forge.observability.metrics import LoggerBackend, MetricCollector
-
-
-class MockBackend(LoggerBackend):
-    """Mock backend for testing metrics logging without external dependencies."""
-
-    def __init__(self, logger_backend_config=None):
-        super().__init__(logger_backend_config or {})
-        self.logged_metrics = []
-        self.init_called = False
-        self.finish_called = False
-        self.metadata = {}
-
-    async def init(self, role="local", primary_logger_metadata=None):
-        self.init_called = True
-        self.role = role
-        self.primary_logger_metadata = primary_logger_metadata or {}
-
-    async def log(self, metrics, step):
-        self.logged_metrics.append((metrics, step))
-
-    async def finish(self):
-        self.finish_called = True
-
-    def get_metadata_for_secondary_ranks(self):
-        return self.metadata
+from forge.observability.metrics import MetricCollector
 
 
 @pytest.fixture(autouse=True)
diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py
new file mode 100644
index 000000000..1c315b2e9
--- /dev/null
+++ b/tests/unit_tests/observability/test_metric_actors.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Unit tests for metric actor functionality."""
+
+import pytest
+
+from forge.observability.metric_actors import (
+    get_or_create_metric_logger,
+    GlobalLoggingActor,
+    LocalFetcherActor,
+)
+from monarch.actor import this_host
+
+
+@pytest.fixture
+def global_logger():
+    """Create a GlobalLoggingActor for testing."""
+    p = this_host().spawn_procs(per_host={"cpus": 1})
+    return p.spawn("TestGlobalLogger", GlobalLoggingActor)
+
+
+@pytest.fixture
+def local_fetcher(global_logger):
+    """Create a LocalFetcherActor linked to the global logger."""
+    p = this_host().spawn_procs(per_host={"cpus": 1})
+    return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger)
+
+
+class TestBasicOperations:
+    """Test basic operations for actors."""
+
+    @pytest.mark.asyncio
+    async def test_local_fetcher_flush(self, local_fetcher):
+        """Test LocalFetcherActor flush operations."""
+        result_with_state = await local_fetcher.flush.call_one(
+            global_step=1, return_state=True
+        )
+        assert result_with_state == {}
+
+        result_without_state = await local_fetcher.flush.call_one(
+            global_step=1, return_state=False
+        )
+        assert result_without_state == {}
+
+    @pytest.mark.asyncio
+    async def test_global_logger_basic_ops(self, global_logger):
+        """Test GlobalLoggingActor basic operations."""
+        count = await global_logger.get_fetcher_count.call_one()
+        assert count >= 0
+
+        has_fetcher = await global_logger.has_fetcher.call_one("nonexistent")
+        assert has_fetcher is False
+
+        # Global logger flush (should not raise an error)
+        await global_logger.flush.call_one(global_step=1)
+
+    @pytest.mark.asyncio
+    async def test_backend_init(self, local_fetcher):
+        """Test backend initialization and shutdown."""
+        metadata = {"wandb": {"shared_run_id": "test123"}}
+        config = {"console": {"reduce_across_ranks": False}}
+
+        await local_fetcher.init_backends.call_one(metadata, config, global_step=5)
+        await local_fetcher.shutdown.call_one()
+
+
+class TestRegistrationLifecycle:
+    """Test the registration lifecycle."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_registration_lifecycle(self, global_logger, local_fetcher):
+        """Test the complete registration/deregistration lifecycle."""
+        proc_name = "lifecycle_test_proc"
+
+        # Initial state
+        initial_count = await global_logger.get_fetcher_count.call_one()
+        assert await global_logger.has_fetcher.call_one(proc_name) is False
+
+        # Register
+        await global_logger.register_fetcher.call_one(local_fetcher, proc_name)
+
+        # Verify registered
+        new_count = await global_logger.get_fetcher_count.call_one()
+        assert new_count == initial_count + 1
+        assert await global_logger.has_fetcher.call_one(proc_name) is True
+
+        # Deregister
+        await global_logger.deregister_fetcher.call_one(proc_name)
+
+        # Verify deregistered
+        final_count = await global_logger.get_fetcher_count.call_one()
+        assert final_count == initial_count
+        assert await global_logger.has_fetcher.call_one(proc_name) is False
+
+
+class TestBackendConfiguration:
+    """Test backend configuration validation."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_valid_backend_configs(self, global_logger):
+        """Test valid backend configurations."""
+        # Empty config
+        await global_logger.init_backends.call_one({})
+
+        # Valid configs for the different reduce_across_ranks modes
+        for reduce_across_ranks in [True, False]:
+            config = {"console": {"reduce_across_ranks": reduce_across_ranks}}
+            await global_logger.init_backends.call_one(config)
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_invalid_backend_configs(self, global_logger):
+        """Test that minimal or partial backend configurations are handled gracefully."""
+        # Empty config should work
+        await global_logger.init_backends.call_one({})
+
+        # Config with only a project should work
+        config_with_project = {"console": {"project": "test_project"}}
+        await global_logger.init_backends.call_one(config_with_project)
+
+        # Config with reduce_across_ranks should work
+        config_with_reduce = {"console": {"reduce_across_ranks": True}}
+        await global_logger.init_backends.call_one(config_with_reduce)
+
+
+class TestErrorHandling:
+    """Test error handling scenarios."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_deregister_nonexistent_fetcher(self, global_logger):
+        """Test that deregistering a non-existent fetcher doesn't crash."""
+        await global_logger.deregister_fetcher.call_one("nonexistent_proc")
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_shutdown(self, global_logger):
+        """Test shutdown without issues."""
+        await global_logger.shutdown.call_one()
+
+
+class TestGetOrCreateMetricLogger:
+    """Test the integration function."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_get_or_create_functionality(self):
+        """Test get_or_create_metric_logger basic functionality."""
+        result = await get_or_create_metric_logger(process_name="TestController")
+
+        # Should return a GlobalLoggingActor mesh
+        assert result is not None
+
+        # Should be able to call basic methods
+        count = await result.get_fetcher_count.call_one()
+        assert count >= 0
diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py
index 701bda2dc..b4a8ffcdf 100644
--- a/tests/unit_tests/observability/test_metrics.py
+++ b/tests/unit_tests/observability/test_metrics.py
@@ -80,12 +80,9 @@ def test_new_enums_and_constants(self):
         assert isinstance(BackendRole.LOCAL, BackendRole)
         assert isinstance(BackendRole.GLOBAL, BackendRole)
 
-    @patch("forge.observability.metrics.get_actor_name_with_rank")
     @pytest.mark.asyncio
-    async def test_backend_role_usage(self, mock_actor_name):
+    async def test_backend_role_usage(self):
         """Test that BackendRole constants are actually used instead of string literals."""
-        mock_actor_name.return_value = "TestActor_abcd_r0"
-
         # Test ConsoleBackend
         console_backend = ConsoleBackend({})
         await console_backend.init(role=BackendRole.LOCAL)
@@ -295,10 +292,8 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank):
         mock_collector_class.assert_called_once()
         mock_collector.push.assert_called_once()
 
-    @patch("forge.observability.metrics.get_actor_name_with_rank")
-    def test_wandb_backend_creation(self, mock_actor_name):
+    def test_wandb_backend_creation(self):
         """Test WandbBackend creation and basic setup without WandB dependency."""
-        mock_actor_name.return_value = "TestActor_abcd_r0"
 
         config = {
             "project": "test_project",
@@ -316,12 +311,9 @@ def test_wandb_backend_creation(self):
         metadata = backend.get_metadata_for_secondary_ranks()
         assert metadata == {}  # Should be empty when no run
 
-    @patch("forge.observability.metrics.get_actor_name_with_rank")
     @pytest.mark.asyncio
-    async def test_console_backend(self, mock_actor_name):
+    async def test_console_backend(self):
         """Test ConsoleBackend basic operations."""
-        mock_actor_name.return_value = "TestActor_abcd_r0"
-
         backend = ConsoleBackend({})
         await backend.init(role=BackendRole.LOCAL)
@@ -425,28 +417,33 @@ async def _test_fetcher_registration(self, env_var_value, should_register_fetchers):
         if hasattr(procs, "_local_fetcher"):
             delattr(procs, "_local_fetcher")
 
-        # Test functionality
-        global_logger = await get_or_create_metric_logger(proc_mesh=procs)
+        # Test functionality - pass explicit process_name since test bypasses provisioner
+        global_logger = await get_or_create_metric_logger(
+            proc_mesh=procs, process_name="TestProcess"
+        )
 
         # Get results to check
         proc_has_fetcher = hasattr(procs, "_local_fetcher")
-        global_has_fetcher = await global_logger.has_fetcher.call_one(procs)
+        proc_id = procs._uid if hasattr(procs, "_uid") else None
+        global_has_fetcher = (
+            await global_logger.has_fetcher.call_one(proc_id) if proc_id else False
+        )
 
         # Assert based on expected behavior
         if should_register_fetchers:
             assert (
                 proc_has_fetcher
-            ), f"Expected process to have _local_fetcher when {env_var_value=}"
+            ), f"Expected process to have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
             assert (
                 global_has_fetcher
-            ), f"Expected global logger to have fetcher registered when {env_var_value=}"
+            ), f"Expected global logger to have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
         else:
             assert (
                 not proc_has_fetcher
-            ), f"Expected process to NOT have _local_fetcher when {env_var_value=}"
+            ), f"Expected process to NOT have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
             assert (
                 not global_has_fetcher
-            ), f"Expected global logger to NOT have fetcher registered when {env_var_value=}"
+            ), f"Expected global logger to NOT have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize(
diff --git a/tests/unit_tests/observability/test_utils.py b/tests/unit_tests/observability/test_utils.py
new file mode 100644
index 000000000..6b173cc42
--- /dev/null
+++ b/tests/unit_tests/observability/test_utils.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Tests for observability utility functions."""
+
+from forge.controller.actor import ForgeActor
+
+from forge.observability.utils import get_proc_name_with_rank
+from monarch.actor import endpoint
+
+
+class UtilActor(ForgeActor):
+    """Actor for testing get_proc_name_with_rank in a spawned context."""
+
+    @endpoint
+    async def get_name(self) -> str:
+        return get_proc_name_with_rank()
+
+
+class TestGetProcNameWithRank:
+    """Tests for the get_proc_name_with_rank utility."""
+
+    def test_direct_proc(self):
+        """A direct proc should return 'client_r0'."""
+        assert get_proc_name_with_rank() == "client_r0"
+
+    def test_direct_proc_with_override(self):
+        """A direct proc with an override should use the provided name."""
+        result = get_proc_name_with_rank(proc_name="MyProcess")
+        assert result == "MyProcess_r0"
+
+    # TODO (felipemello): currently not working with the CI wheel, but passes locally.
+    # Reactivate once the wheel is updated with the new monarch version.
+    # @pytest.mark.timeout(10)
+    # @pytest.mark.asyncio
+    # async def test_replicas(self):
+    #     """Test that a service with replicas returns unique names and hashes per replica."""
+    #     actor = await UtilActor.options(
+    #         procs=1, num_replicas=2, with_gpus=False
+    #     ).as_service()
+    #     results = await actor.get_name.fanout()
+
+    #     assert len(results) == 2
+    #     assert len(set(results)) == 2  # All names are unique
+    #     for name in results:
+    #         assert name.startswith("UtilActor")
+    #         assert name.endswith("_r0")
+
+    #     # Extract hashes from names (format: ActorName_replicaIdx_hash_r0)
+    #     hashes = [name.split("_")[-2] for name in results]
+    #     assert hashes[0] != hashes[1]  # Hashes are different between replicas
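
Reviewer note (not part of the patch): an end-to-end controller-side sketch of the API after this change, condensed from the updated `get_or_create_metric_logger` docstring. The console backend config and step count are illustrative values.

    import asyncio

    from forge.observability.metric_actors import get_or_create_metric_logger
    from forge.observability.metrics import record_metric, Reduce


    async def main():
        # Controller-side setup; process_name is the new argument introduced by this patch
        mlogger = await get_or_create_metric_logger(process_name="Controller")
        await mlogger.init_backends.call_one({"console": {"reduce_across_ranks": True}})

        for step in range(3):
            # Metrics accumulate per rank, then are reduced and logged on flush
            record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
            await mlogger.flush.call_one(step)

        await mlogger.shutdown.call_one()


    asyncio.run(main())
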