2 changes: 1 addition & 1 deletion apps/grpo/main.py
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
provisioner = await init_provisioner()

metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)

# ---- Setup services ---- #
11 changes: 6 additions & 5 deletions src/forge/controller/provisioner.py
@@ -310,11 +310,12 @@ def bootstrap(env: dict[str, str]):

self._proc_host_map[procs] = host_mesh

# Spawn local fetcher actor on each process and register with global logger
# Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor.
# When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
if not FORGE_DISABLE_METRICS.get_value():
from forge.observability.metric_actors import get_or_create_metric_logger

_ = await get_or_create_metric_logger(procs)
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
return procs

async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
@@ -333,14 +334,14 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
)
return
async with self._lock:
# Deregister local logger from global logger
if hasattr(proc_mesh, "_local_fetcher"):
# Deregister LocalFetcherActor from GlobalLoggingActor
if hasattr(proc_mesh, "_local_fetcher") and hasattr(proc_mesh, "_uid"):
Member: Is it possible for a proc_mesh that has _local_fetcher but not _uid?

Contributor Author: No, they should always have both. I guess I was being extra safe here. Is it confusing?

Member: I understand you wrote it like this to be safe, but I worry it may hide some potential errors. How about raising an error if it has _local_fetcher but not _uid?
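A minimal sketch of the fail-fast variant suggested in this thread (not part of the PR; it assumes, as stated above, that _local_fetcher and _uid are always set together):

# Hypothetical fail-fast deregistration guard, per the review suggestion.
if hasattr(proc_mesh, "_local_fetcher"):
    if not hasattr(proc_mesh, "_uid"):
        # Surface the inconsistency instead of silently skipping deregistration.
        raise RuntimeError(
            "proc_mesh has a _local_fetcher but no _uid; "
            "fetcher registration state is inconsistent"
        )
    from forge.observability.metric_actors import get_or_create_metric_logger

    global_logger = await get_or_create_metric_logger(proc_mesh)
    await global_logger.deregister_fetcher.call_one(proc_mesh._uid)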

from forge.observability.metric_actors import (
get_or_create_metric_logger,
)

global_logger = await get_or_create_metric_logger(proc_mesh)
await global_logger.deregister_fetcher.call_one(proc_mesh)
await global_logger.deregister_fetcher.call_one(proc_mesh._uid)

if hasattr(proc_mesh, "_gpu_ids"):
gpu_manager = self._host_gpu_map[proc_mesh._host._host_id]
6 changes: 3 additions & 3 deletions src/forge/observability/__init__.py
@@ -12,8 +12,6 @@
from .metrics import (
BackendRole,
ConsoleBackend,
get_actor_name_with_rank,
get_logger_backend_class,
LoggerBackend,
MaxAccumulator,
MeanAccumulator,
@@ -29,12 +27,12 @@
WandbBackend,
)
from .perf_tracker import trace, Tracer
from .utils import get_proc_name_with_rank

__all__ = [
# Main API functions
"record_metric",
"reduce_metrics_states",
"get_actor_name_with_rank",
"get_logger_backend_class",
"get_or_create_metric_logger",
# Performance tracking
@@ -45,6 +43,8 @@
"BackendRole",
# Enums
"Reduce",
# Utility functions
"get_proc_name_with_rank",
# Actor classes
"GlobalLoggingActor",
"LocalFetcherActor",
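A hypothetical downstream import, reflecting the export swap above (get_actor_name_with_rank is no longer exported; get_proc_name_with_rank replaces it):

# Before this PR:
# from forge.observability import get_actor_name_with_rank
# After this PR:
from forge.observability import get_proc_name_with_rank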
133 changes: 81 additions & 52 deletions src/forge/observability/metric_actors.py
@@ -6,9 +6,17 @@

import asyncio
import logging
import uuid
from typing import Any, Union

from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
from monarch.actor import (
Actor,
context,
endpoint,
get_or_spawn_controller,
ProcMesh,
this_proc,
)

from forge.env import FORGE_DISABLE_METRICS
from forge.observability.metrics import (
@@ -27,36 +35,35 @@

async def get_or_create_metric_logger(
proc_mesh: ProcMesh | None = None,
process_name: str | None = None,
) -> "GlobalLoggingActor":
"""Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
if not already initialized, registers it with the GlobalLoggingActor and returns the
GlobalLoggingActor instance.
"""Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized),
registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor.

There are primarily two ways to use this function:
1. In the main process, call `get_or_create_metric_logger()` to get the global logger.
2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the
local fetcher with the global logger.
Usage:
1. Main process: call `get_or_create_metric_logger()` to get the global logger
2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name)` to register the
proc_mesh -> local fetcher mapping with the global logger, so it knows to broadcast to all ranks.

Args:
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
uses `monarch.actor.this_proc()`.
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`.
process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None.

Returns:
GlobalLoggingActor: The global logging controller.

Raises:
ValueError: If the logging state is inconsistent, i.e. the fetcher is already
registered, but only in the process or the global logger.
ValueError: If the logging state is inconsistent.

Example:
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric

# Main process setup
mlogger = await get_or_create_metric_logger()
mlogger = await get_or_create_metric_logger(process_name="Controller")

# Initialize logging backends
await mlogger.init_backends({
await mlogger.init_backends.call_one({
"console": {"reduce_across_ranks": True},
"wandb": {"project": "my_project", "reduce_across_ranks": False}
})
@@ -66,12 +73,12 @@ async def get_or_create_metric_logger(

# Training loop
for step in range(max_steps):
record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
# ... training code with record_metric() calls ...
await mlogger.flush(step) # Log metrics for this step
await mlogger.flush.call_one(step) # Log metrics for this step

# Shutdown
await mlogger.shutdown()
await mlogger.shutdown.call_one()
"""
# Get or create the singleton global logger
global _global_logger
@@ -85,9 +92,15 @@
# Determine process context
proc = proc_mesh if proc_mesh is not None else this_proc()

# Auto-detect process_name from proc mesh if not provided
if process_name is None:
ctx = context()
process_name = ctx.actor_instance.actor_id.actor_name

# Check current state for consistency
proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
proc_id = proc._uid if proc_has_local_fetcher else None
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc_id)

# Consistency check: both should be in sync
if proc_has_local_fetcher != global_logger_has_local_fetcher:
@@ -102,24 +115,32 @@
# Setup local_fetcher_actor if needed (unless disabled by environment flag)
if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
local_fetcher_actor = proc.spawn(
"local_fetcher_actor", LocalFetcherActor, global_logger
"local_fetcher_actor", LocalFetcherActor, global_logger, process_name
)
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
# Generate a unique ID to map procmesh to fetcher
proc._uid = str(uuid.uuid4())
Member: Thanks for the fix! LGTM

Contributor Author: Do you mind approving it?

Member: I am a little concerned about the broken CI. Wouldn't it cause all the subsequent commits to break as well?

Contributor Author (felipemello1, Oct 17, 2025): I don't think those errors are related to this PR, but let me confirm by opening a dummy PR.
proc._local_fetcher = local_fetcher_actor # pyre-ignore

await global_logger.register_fetcher.call_one(local_fetcher_actor, proc._uid)

return global_logger


class LocalFetcherActor(Actor):
"""Thin per-process actor used to trigger MetricCollector singleton
operations without direct access. It is what GlobalLoggingActor
uses to broadcast inits/flushes across ranks.
"""Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh
and accesses each rank's local MetricCollector.

GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
Flow:
GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
"""

def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
def __init__(
self,
global_logger: Union["GlobalLoggingActor", None] = None,
process_name: str | None = None,
) -> None:
self.global_logger = global_logger
self.process_name = process_name
_is_initialized = False

@endpoint
@@ -146,10 +167,22 @@ async def init_backends(
self,
metadata_per_primary_backend: dict[str, dict[str, Any]],
config: dict[str, Any],
global_step: int = 0,
) -> None:
"""Init local (per-rank) logger backends and MetricCollector."""
"""Init per-rank logger backends and MetricCollector.

Args:
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
config (dict[str, Any]): Backend configurations with logging modes and settings.
global_step (int): Initial step for metrics.
"""
collector = MetricCollector()
await collector.init_backends(metadata_per_primary_backend, config)
await collector.init_backends(
metadata_per_primary_backend,
config,
global_step,
process_name=self.process_name,
)

@endpoint
async def shutdown(self) -> None:
@@ -158,22 +191,17 @@ async def shutdown(self) -> None:


class GlobalLoggingActor(Actor):
"""Coordinates metric logging across all ranks for every training step.
"""Coordinates metric logging across all ProcMeshes and their ranks.

Supports multiple logging backends (e.g., WandB, TensorBoard, etc.),
for per-rank and/or global reduction logging modes.
with per-rank and/or global reduction logging modes.

If a backend config has flag `reduce_across_ranks=False`, an instance of the backend
is initialized per-rank, otherwise it is done once globally.

This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor
is automatically spawned per-rank in `forge.controller.provisioner.py` and registered
with this actor. The LocalFetcherActor is responsible for instantiating
the per-rank MetricCollector.

In summary, the flow is:
- GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector
- GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush
Flow:
GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
"""

def __init__(self):
@@ -209,7 +237,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:

for backend_name, backend_config in config.items():
backend = get_logger_backend_class(backend_name)(backend_config)
await backend.init(role=BackendRole.GLOBAL)
await backend.init(role=BackendRole.GLOBAL, name="global_reduce")

# Extract metadata from primary logger to be shared with secondary loggers
# and store it
@@ -237,30 +265,31 @@ async def init_backends(self, config: dict[str, Any]) -> None:
await asyncio.gather(*tasks, return_exceptions=True)

@endpoint
async def register_fetcher(
self, fetcher: LocalFetcherActor, name: str | ProcMesh
) -> None:
"""Registers a fetcher with the global actor. Each key represents a process mesh.
If there are 2 processes, each with 2 replicas with N gpus, we would
have 4 keys, i.e. 2 proces meshes, each with 2 replicas."""
self.fetchers[name] = fetcher # pyre-ignore
async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> None:
"""Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh.

Args:
fetcher: The LocalFetcherActor instance for a ProcMesh
proc_id: Unique identifier for the ProcMesh
"""
self.fetchers[proc_id] = fetcher

# Self-init for respawned actors
if self.config:
logger.debug(f"Initializing new LocalFetcherActor {name}")
logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}")
await fetcher.init_backends.call(
self.metadata_per_primary_backend, self.config
)

@endpoint
async def deregister_fetcher(self, name: str | ProcMesh) -> None:
if name not in self.fetchers:
async def deregister_fetcher(self, proc_id: str) -> None:
if proc_id not in self.fetchers:
logger.warning(
f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister."
f"Fetcher {proc_id} not registered in GlobalLoggingActor. Cannot deregister."
f"Available fetchers: {self.fetchers.keys()}"
)
return
del self.fetchers[name]
del self.fetchers[proc_id]

@endpoint
async def flush(self, global_step: int) -> None:
@@ -333,9 +362,9 @@ async def flush(self, global_step: int) -> None:
await logger_backend.log(reduced_metrics, global_step)

@endpoint
def has_fetcher(self, name: str | ProcMesh) -> bool:
"""Check if a fetcher is registered with the given name."""
return name in self.fetchers
def has_fetcher(self, proc_id: str) -> bool:
"""Check if a fetcher is registered with the given proc_id."""
return proc_id in self.fetchers

@endpoint
def get_fetcher_count(self) -> int:
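The core bookkeeping change in this file is keying the fetcher registry by a generated string ID instead of by the ProcMesh object itself. A self-contained sketch of that pattern (plain Python, no Monarch; all names here are illustrative stand-ins, not the PR's actual classes):

import uuid


class FetcherRegistry:
    """Illustrative stand-in for GlobalLoggingActor's fetcher bookkeeping."""

    def __init__(self) -> None:
        self.fetchers: dict[str, object] = {}

    def register(self, fetcher: object) -> str:
        # Generate a unique ID to map a ProcMesh to its fetcher, as the PR does.
        proc_id = str(uuid.uuid4())
        self.fetchers[proc_id] = fetcher
        return proc_id

    def has(self, proc_id: str) -> bool:
        return proc_id in self.fetchers

    def deregister(self, proc_id: str) -> None:
        # Mirror the PR's behavior: tolerate unknown IDs rather than raising.
        self.fetchers.pop(proc_id, None)


registry = FetcherRegistry()
uid = registry.register(object())  # stands in for a LocalFetcherActor
assert registry.has(uid)
registry.deregister(uid)
assert not registry.has(uid)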