- 
                Notifications
    You must be signed in to change notification settings 
- Fork 47
fix - Metric logging work with new monarch API #451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3c69e63
              7c4e9f5
              47b89f4
              56f6c2d
              44c9883
              9d18451
              00ccf0c
              b5060f7
              f6b95ce
              2a94504
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -6,9 +6,17 @@ | |
|  | ||
| import asyncio | ||
| import logging | ||
| import uuid | ||
| from typing import Any, Union | ||
|  | ||
| from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc | ||
| from monarch.actor import ( | ||
| Actor, | ||
| context, | ||
| endpoint, | ||
| get_or_spawn_controller, | ||
| ProcMesh, | ||
| this_proc, | ||
| ) | ||
|  | ||
| from forge.env import FORGE_DISABLE_METRICS | ||
| from forge.observability.metrics import ( | ||
|  | @@ -27,36 +35,35 @@ | |
|  | ||
| async def get_or_create_metric_logger( | ||
| proc_mesh: ProcMesh | None = None, | ||
| process_name: str | None = None, | ||
| ) -> "GlobalLoggingActor": | ||
| """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), | ||
| if not already initialized, registers it with the GlobalLoggingActor and returns the | ||
| GlobalLoggingActor instance. | ||
| """Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized), | ||
| registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor. | ||
|  | ||
| There are primarily two ways to use this function: | ||
| 1. In the main process, call `get_or_create_metric_logger()` to get the global logger. | ||
| 2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the | ||
| local fetcher with the global logger. | ||
| Usage: | ||
| 1. Main process: call `get_or_create_metric_logger()` to get the global logger | ||
| 2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name)` to register the | ||
| map(proc_mesh,local fetcher) with the global logger, so it knows to broadcast to all ranks. | ||
|  | ||
| Args: | ||
| proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, | ||
| uses `monarch.actor.this_proc()`. | ||
| proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`. | ||
| process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None. | ||
|  | ||
| Returns: | ||
| GlobalLoggingActor: The global logging controller. | ||
|  | ||
| Raises: | ||
| ValueError: If the logging state is inconsistent, i.e. the fetcher is already | ||
| registered, but only in the process or the global logger. | ||
| ValueError: If the logging state is inconsistent. | ||
|  | ||
| Example: | ||
| from forge.observability.metric_actors import get_or_create_metric_logger | ||
| from forge.observability.metrics import record_metric | ||
|  | ||
| # Main process setup | ||
| mlogger = await get_or_create_metric_logger() | ||
| mlogger = await get_or_create_metric_logger(process_name="Controller") | ||
|  | ||
| # Initialize logging backends | ||
| await mlogger.init_backends({ | ||
| await mlogger.init_backends.call_one({ | ||
| "console": {"reduce_across_ranks": True}, | ||
| "wandb": {"project": "my_project", "reduce_across_ranks": False} | ||
| }) | ||
|  | @@ -66,12 +73,12 @@ async def get_or_create_metric_logger( | |
|  | ||
| # Training loop | ||
| for step in range(max_steps): | ||
| record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) | ||
| record_metric("loss", 1.2, reduction_type=Reduce.MEAN) | ||
| # ... training code with record_metric() calls ... | ||
| await mlogger.flush(step) # Log metrics for this step | ||
| await mlogger.flush.call_one(step) # Log metrics for this step | ||
|  | ||
| # Shutdown | ||
| await mlogger.shutdown() | ||
| await mlogger.shutdown.call_one() | ||
| """ | ||
| # Get or create the singleton global logger | ||
| global _global_logger | ||
|  | @@ -85,9 +92,15 @@ async def get_or_create_metric_logger( | |
| # Determine process context | ||
| proc = proc_mesh if proc_mesh is not None else this_proc() | ||
|  | ||
| # Auto-detect process_name from proc mesh if not provided | ||
| if process_name is None: | ||
| ctx = context() | ||
| process_name = ctx.actor_instance.actor_id.actor_name | ||
|  | ||
| # Check current state for consistency | ||
| proc_has_local_fetcher = hasattr(proc, "_local_fetcher") | ||
| global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) | ||
| proc_id = proc._uid if proc_has_local_fetcher else None | ||
| global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc_id) | ||
|  | ||
| # Consistency check: both should be in sync | ||
| if proc_has_local_fetcher != global_logger_has_local_fetcher: | ||
|  | @@ -102,24 +115,32 @@ async def get_or_create_metric_logger( | |
| # Setup local_fetcher_actor if needed (unless disabled by environment flag) | ||
| if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value(): | ||
| local_fetcher_actor = proc.spawn( | ||
| "local_fetcher_actor", LocalFetcherActor, global_logger | ||
| "local_fetcher_actor", LocalFetcherActor, global_logger, process_name | ||
| ) | ||
| await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) | ||
| # Generate a unique ID to map procmesh to fetcher | ||
| proc._uid = str(uuid.uuid4()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the fix! LGTM There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you mind approving it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little concerned about the broken CI. Wouldn't it cause all the subsequent commits to break as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think those errors are related to this PR. But let me confirm by opening a dummy PR | ||
| proc._local_fetcher = local_fetcher_actor # pyre-ignore | ||
|  | ||
| await global_logger.register_fetcher.call_one(local_fetcher_actor, proc._uid) | ||
|  | ||
| return global_logger | ||
|  | ||
|  | ||
| class LocalFetcherActor(Actor): | ||
| """Thin per-process actor used to trigger MetricCollector singleton | ||
| operations without direct access. It is what GlobalLoggingActor | ||
| uses to broadcast inits/flushes across ranks. | ||
| """Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh | ||
| and accesses each rank's local MetricCollector. | ||
|  | ||
| GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector | ||
| Flow: | ||
| GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger | ||
| """ | ||
|  | ||
| def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None: | ||
| def __init__( | ||
| self, | ||
| global_logger: Union["GlobalLoggingActor", None] = None, | ||
| process_name: str | None = None, | ||
| ) -> None: | ||
| self.global_logger = global_logger | ||
| self.process_name = process_name | ||
| _is_initialized = False | ||
|  | ||
| @endpoint | ||
|  | @@ -146,10 +167,22 @@ async def init_backends( | |
| self, | ||
| metadata_per_primary_backend: dict[str, dict[str, Any]], | ||
| config: dict[str, Any], | ||
| global_step: int = 0, | ||
| ) -> None: | ||
| """Init local (per-rank) logger backends and MetricCollector.""" | ||
| """Init per-rank logger backends and MetricCollector. | ||
|  | ||
| Args: | ||
| metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. | ||
| config (dict[str, Any]): Backend configurations with logging modes and settings. | ||
| global_step (int): Initial step for metrics. | ||
| """ | ||
| collector = MetricCollector() | ||
| await collector.init_backends(metadata_per_primary_backend, config) | ||
| await collector.init_backends( | ||
| metadata_per_primary_backend, | ||
| config, | ||
| global_step, | ||
| process_name=self.process_name, | ||
| ) | ||
|  | ||
| @endpoint | ||
| async def shutdown(self) -> None: | ||
|  | @@ -158,22 +191,17 @@ async def shutdown(self) -> None: | |
|  | ||
|  | ||
| class GlobalLoggingActor(Actor): | ||
| """Coordinates metric logging across all ranks for every training step. | ||
| """Coordinates metric logging across all ProcMeshes and their ranks. | ||
|  | ||
| Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), | ||
| for per-rank and/or global reduction logging modes. | ||
| with per-rank and/or global reduction logging modes. | ||
|  | ||
| If a backend config has flag `reduce_across_ranks=False`, an instance of the backend | ||
| is initialized per-rank, otherwise it is done once globally. | ||
|  | ||
| This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor | ||
| is automatically spawned per-rank in `forge.controller.provisioner.py` and registered | ||
| with this actor. The LocalFetcherActor is responsible for instantiating | ||
| the per-rank MetricCollector. | ||
|  | ||
| In summary, the flow is: | ||
| - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector | ||
| - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush | ||
| Flow: | ||
| GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger | ||
| """ | ||
|  | ||
| def __init__(self): | ||
|  | @@ -209,7 +237,7 @@ async def init_backends(self, config: dict[str, Any]) -> None: | |
|  | ||
| for backend_name, backend_config in config.items(): | ||
| backend = get_logger_backend_class(backend_name)(backend_config) | ||
| await backend.init(role=BackendRole.GLOBAL) | ||
| await backend.init(role=BackendRole.GLOBAL, name="global_reduce") | ||
|  | ||
| # Extract metadata from primary logger to be shared with secondary loggers | ||
| # and store it | ||
|  | @@ -237,30 +265,31 @@ async def init_backends(self, config: dict[str, Any]) -> None: | |
| await asyncio.gather(*tasks, return_exceptions=True) | ||
|  | ||
| @endpoint | ||
| async def register_fetcher( | ||
| self, fetcher: LocalFetcherActor, name: str | ProcMesh | ||
| ) -> None: | ||
| """Registers a fetcher with the global actor. Each key represents a process mesh. | ||
| If there are 2 processes, each with 2 replicas with N gpus, we would | ||
| have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" | ||
| self.fetchers[name] = fetcher # pyre-ignore | ||
| async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> None: | ||
| """Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh. | ||
|  | ||
| Args: | ||
| fetcher: The LocalFetcherActor instance for a ProcMesh | ||
| proc_id: Unique identifier for the ProcMesh | ||
| """ | ||
| self.fetchers[proc_id] = fetcher | ||
|  | ||
| # Self-init for respawned actors | ||
| if self.config: | ||
| logger.debug(f"Initializing new LocalFetcherActor {name}") | ||
| logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}") | ||
| await fetcher.init_backends.call( | ||
| self.metadata_per_primary_backend, self.config | ||
| ) | ||
|  | ||
| @endpoint | ||
| async def deregister_fetcher(self, name: str | ProcMesh) -> None: | ||
| if name not in self.fetchers: | ||
| async def deregister_fetcher(self, proc_id: str) -> None: | ||
| if proc_id not in self.fetchers: | ||
| logger.warning( | ||
| f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." | ||
| f"Fetcher {proc_id} not registered in GlobalLoggingActor. Cannot deregister." | ||
| f"Available fetchers: {self.fetchers.keys()}" | ||
| ) | ||
| return | ||
| del self.fetchers[name] | ||
| del self.fetchers[proc_id] | ||
|  | ||
| @endpoint | ||
| async def flush(self, global_step: int) -> None: | ||
|  | @@ -333,9 +362,9 @@ async def flush(self, global_step: int) -> None: | |
| await logger_backend.log(reduced_metrics, global_step) | ||
|  | ||
| @endpoint | ||
| def has_fetcher(self, name: str | ProcMesh) -> bool: | ||
| """Check if a fetcher is registered with the given name.""" | ||
| return name in self.fetchers | ||
| def has_fetcher(self, proc_id: str) -> bool: | ||
| """Check if a fetcher is registered with the given proc_id.""" | ||
| return proc_id in self.fetchers | ||
|  | ||
| @endpoint | ||
| def get_fetcher_count(self) -> int: | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible for a proc_mesh to have
`_local_fetcher` but not `_uid`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, they should always have both. I guess I was being extra safe here. Is it confusing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand you wrote it like this to be safe. But I just worry it may hide some potential errors. How about raising an error if it has
`_local_fetcher` but not `_uid`?