diff --git a/docs/api-guide/router_replay.md b/docs/api-guide/router_replay.md new file mode 100644 index 00000000000..300a50db127 --- /dev/null +++ b/docs/api-guide/router_replay.md @@ -0,0 +1,176 @@ +# Design Document: MoE Router Replay Feature + +### 1. Overview + +This document provides a detailed description of the "Router Replay" feature implemented within the Megatron-LM Core for Mixture-of-Experts (MoE) models. + +This feature is designed to enhance determinism and analyzability in MoE model training and inference. It enables the model to load routing decisions from a predefined file and enforce their use during the forward pass, thereby bypassing the real-time routing computation. + +### 2. Motivation + +* **Determinism & Reproducibility**: In distributed training, MoE routing decisions can exhibit minor variations due to factors like floating-point precision. By replaying a fixed routing table, the MoE computation path is guaranteed to be identical across runs, which facilitates debugging and reproducing experimental results. +* **Performance Profiling**: The router's own computation (e.g., logits calculation, top-k selection) incurs overhead. In replay mode, this part of the computation can be completely skipped, allowing for more precise isolation and profiling of performance bottlenecks within the Expert Layers themselves. +* **Debugging Aid**: When issues arise in the model, fixing the routing decisions helps to isolate variables, making it easier to determine whether the problem lies with the routing mechanism or the expert computations. + +### 3. Design and Architecture + +The design follows the principles of being non-intrusive and on-demand, with the core idea of activating the replay logic only when explicitly requested by the user. + +* **Core Components**: + * `RouterReplay` (located in `megatron/core/transformer/moe/router_replay.py`): A utility class for replaying MoE routing decisions. 
When enabled via the `moe_enable_routing_replay` flag, a separate instance of `RouterReplay` is created for each MoE layer's router. Each instance is responsible for loading routing data and providing the deterministic routing decisions for its corresponding layer during the forward pass.
+  * `moe_enable_routing_replay` (located in `megatron/core/transformer/transformer_config.py`): A boolean configuration flag that serves as the sole entry point for enabling this feature.
+
+* **Workflow**:
+  The feature supports different modes, such as recording and replaying, controlled by a `RouterReplayAction`.
+
+  1. **Enabling the Feature**: The user sets `moe_enable_routing_replay` to `True` in the model configuration.
+  2. **Initialization**: When `moe_enable_routing_replay` is true, each `TopKRouter` creates its own `RouterReplay` instance.
+  3. **Mode Configuration**: The user must programmatically set the desired router replay action (`RECORD`, `REPLAY_FORWARD`, or `REPLAY_BACKWARD`) on the `RouterReplay` instances.
+  4. **Execution Flow (within a mini-batch)**:
+     * **Forward Pass**:
+       * For each micro-batch, `topk_routing_with_score_function` checks the `router_replay_action`.
+       * **In `RECORD` mode**: The dynamically computed top-k expert indices are captured and stored.
+       * **In `REPLAY_FORWARD` mode**: The function retrieves pre-loaded expert indices from `target_topk_idx`. These indices are used for the forward computation and are also appended to the `replay_backward_list` to prepare for the backward pass.
+     * **Backward Pass**:
+       * For each micro-batch (processed in reverse order in pipeline parallelism), the `router_replay_action` is checked again.
+       * **In `REPLAY_BACKWARD` mode**: The function retrieves the expert indices for the corresponding micro-batch by popping them from the `replay_backward_list`.
This mode is intended for training recomputation (e.g., activation checkpointing and pipeline recompute), so that the same routing decisions are used during recompute/backward as during the forward pass, ensuring determinism and correctness.
+
+### 4. Implementation Details
+
+The implementation cleanly separates the replay logic from the router's core computation.
+
+* **`megatron/core/transformer/transformer_config.py`**:
+  * Adds the configuration option `moe_enable_routing_replay: bool = False`.
+
+* **`megatron/core/transformer/moe/router_replay.py`**:
+  * Introduces the `RouterReplay` class to manage the state for recording and replaying routing decisions for a single MoE layer.
+    * `target_topk_idx`: An attribute holding the expert indices for the current micro-batch during forward replay mode.
+    * `recorded_topk_idx`: An attribute for storing the computed expert indices when in record mode.
+    * `replay_backward_list`: A list that accumulates the top-k indices used during the forward passes of a mini-batch. This list is consumed in FIFO order during the backward pass to ensure correctness under pipeline parallelism.
+    * `set_target_indices()`: A method to load the replay indices into `target_topk_idx` for the forward pass; it also queues them in `replay_backward_list` for later recomputation.
+    * `record_indices()`: A method to save the computed indices.
+
+* **`megatron/core/transformer/moe/moe_utils.py`**:
+  * Modifies `topk_routing_with_score_function` to contain the core logic. The function accepts an optional `router_replay` argument, checks its `router_replay_action`, and accordingly performs one of the following actions: computes and records indices, replays indices from `target_topk_idx` (for forward), replays indices from `replay_backward_list` (for backward), or falls through to the default dynamic routing.
+
+#### Training recompute usage
+- During forward replay, `set_target_indices()` prepares `replay_backward_list` so that each micro-batch's indices are available for recomputation.
+- During recompute/backward, set the action to `REPLAY_BACKWARD` so that indices are consumed in FIFO order to mirror the forward sequence.
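The dispatch described above can be summarized in plain Python. The following is a simplified, dependency-free sketch of the `get_replay_topk` control flow — plain lists stand in for tensors and `sorted` for `torch.topk`, so `ReplaySketch` and `dense_topk` are illustrative names, not the actual Megatron classes:

```python
from enum import Enum

class Action(Enum):
    RECORD = "record"
    REPLAY_FORWARD = "replay_forward"
    REPLAY_BACKWARD = "replay_backward"

def dense_topk(scores, k):
    # Stand-in for torch.topk: indices of the k largest scores.
    idx = sorted(range(len(scores)), key=lambda i: -scores[i])[:k]
    return [scores[i] for i in idx], idx

class ReplaySketch:
    """Simplified per-layer replay state (mirrors RouterReplay's fields)."""
    def __init__(self):
        self.action = None          # router_replay_action
        self.recorded = None        # recorded_topk_idx
        self.target = None          # target_topk_idx
        self.backward_list = []     # replay_backward_list (consumed FIFO)

    def topk(self, scores, k):
        if self.action == Action.RECORD:
            probs, idx = dense_topk(scores, k)
            self.recorded = idx                 # capture for later replay
            return probs, idx
        if self.action == Action.REPLAY_FORWARD:
            idx = self.target                   # bypass dynamic routing
            self.backward_list.append(idx)      # queue for recompute
            return [scores[i] for i in idx], idx
        if self.action == Action.REPLAY_BACKWARD:
            idx = self.backward_list.pop(0)     # FIFO mirrors forward order
            return [scores[i] for i in idx], idx
        return dense_topk(scores, k)            # default dynamic routing

rr = ReplaySketch()
rr.action = Action.RECORD
_, recorded = rr.topk([0.1, 0.9, 0.5], k=2)     # records [1, 2]

rr.action = Action.REPLAY_FORWARD
rr.target = recorded
_, replayed = rr.topk([0.3, 0.2, 0.8], k=2)     # still [1, 2], despite new scores
```

The key point the sketch illustrates is that replay mode ignores the current scores entirely: the same indices come back regardless of what the router would have computed.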
+
+### 5. Usage Guide
+
+1. **Enable & Instantiate**
+   - Set `moe_enable_routing_replay=True` so that each `TopKRouter` creates its own `RouterReplay` instance when the model is built.
+   - Optionally use the global helpers to set/clear actions across all layers.
+2. **Record Routing Decisions**
+   - Set the action: `RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)`.
+   - Run the model; retrieve the per-layer indices via `RouterReplay.get_recorded_data()` and persist them.
+3. **Forward Replay**
+   - Load the indices and distribute them: `RouterReplay.set_replay_data(list_of_tensors)`.
+   - Set the action: `RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)`.
+   - Run the model; dynamic top-k selection is bypassed and the target indices are used.
+4. **Backward Replay**
+   - For training recomputation (activation checkpointing or pipeline recompute), set the action to `REPLAY_BACKWARD` during recomputation.
+   - Per micro-batch indices are consumed from `replay_backward_list` in FIFO order.
+5. **Cleanup**
+   - Use `RouterReplay.clear_global_indices()`, `RouterReplay.clear_global_router_replay_action()`, and `RouterReplay.clear_global_router_replay_instances()` to restore default behavior and prevent memory leaks.
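As a mental model for the global helpers used in steps 1–5, here is a small, dependency-free sketch of a class-level instance registry. `ReplayRegistry` and its method names are illustrative stand-ins, not the Megatron API; they only model how per-layer instances register themselves and how the global set/clear calls fan out:

```python
class ReplayRegistry:
    """Toy model of RouterReplay's class-level instance list."""
    instances = []

    def __init__(self):
        self.action = None
        self.target = None
        self.recorded = None
        ReplayRegistry.instances.append(self)   # registered at construction

    @classmethod
    def set_global_action(cls, action):
        for inst in cls.instances:
            inst.action = action

    @classmethod
    def set_replay_data(cls, per_layer_indices):
        # One entry per layer, in instantiation order.
        if len(per_layer_indices) != len(cls.instances):
            raise ValueError("replay data must match the number of layers")
        for inst, idx in zip(cls.instances, per_layer_indices):
            inst.target = idx

    @classmethod
    def get_recorded_data(cls):
        return [inst.recorded for inst in cls.instances]

    @classmethod
    def clear_all(cls):
        # Mirrors the three clear_global_* helpers: reset per-instance
        # state, then drop the instances so references cannot leak.
        for inst in cls.instances:
            inst.action = inst.target = inst.recorded = None
        cls.instances.clear()

layer0, layer1 = ReplayRegistry(), ReplayRegistry()
ReplayRegistry.set_replay_data([[0, 1], [1, 2]])
ReplayRegistry.set_global_action("replay_forward")
```

Because instances register themselves in construction order, the order of tensors passed to `set_replay_data` must match the order in which the routers were built — the same constraint the real `RouterReplay.set_replay_data` enforces.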
+ +#### Quick usage with `topk_routing_with_score_function` + +```python +import torch +from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction +from megatron.core.transformer.moe.moe_utils import topk_routing_with_score_function + +rr = RouterReplay() + +# Record +RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) +logits = torch.randn(8, 16) +probs_rec, routing_map_rec = topk_routing_with_score_function( + logits=logits, topk=2, use_pre_softmax=False, score_function="softmax", router_replay=rr, +) +recorded = rr.get_recorded_indices() +torch.save(recorded, "/tmp/replay.pt") + +# Forward replay +rr.clear_router_replay_action() +rr.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) +target = torch.load("/tmp/replay.pt") +rr.set_target_indices(target) +probs_rep, routing_map_rep = topk_routing_with_score_function( + logits=logits, topk=2, use_pre_softmax=False, score_function="softmax", router_replay=rr, +) + +RouterReplay.clear_global_router_replay_action() +RouterReplay.clear_global_indices() +RouterReplay.clear_global_router_replay_instances() +``` + +### 6. 
Minimal Demo
+
+Here is a minimal, illustrative example of using RouterReplay for recording and replaying. The exact `TransformerConfig` fields and the parallel-state setup required to construct a `TopKRouter` vary across Megatron-LM versions, so treat this as a sketch rather than a copy-paste script:
+
+```python
+import torch
+import torch.distributed as dist
+
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.moe.router import TopKRouter
+from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
+
+# Initialize distributed training (single-process group shown here)
+if not dist.is_initialized():
+    dist.init_process_group(backend="nccl")
+
+# Create a transformer config with routing replay enabled
+config = TransformerConfig(
+    num_layers=1,
+    hidden_size=8,
+    num_attention_heads=1,
+    num_moe_experts=8,
+    moe_router_topk=2,
+    expert_model_parallel_size=1,
+    moe_enable_routing_replay=True,
+)
+
+# Create a TopKRouter instance; it registers its own RouterReplay
+router = TopKRouter(config)
+
+# Sample input hidden states: (sequence_length, batch_size, hidden_size)
+hidden_states = torch.randn(32, 16, config.hidden_size).to(torch.cuda.current_device())
+
+# -----------------
+# 1. Recording Mode
+# -----------------
+print("=== Recording Mode ===")
+RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
+
+# Perform routing; the computed top-k indices are recorded per layer
+probs, routing_map = router(hidden_states)
+recorded = RouterReplay.get_recorded_data()  # one tensor per MoE layer
+print(f"Recorded top-k indices shape: {recorded[0].shape}")
+
+# ----------------------
+# 2. Forward Replay Mode
+# ----------------------
+print("\n=== Forward Replay Mode ===")
+# Persist the recorded indices, then reload and distribute them
+torch.save(recorded, "/tmp/replay.pt")
+RouterReplay.set_replay_data(torch.load("/tmp/replay.pt"))
+
+RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
+
+# Perform routing again - the recorded indices are replayed
+replay_probs, replay_routing_map = router(hidden_states)
+print(f"Are routing maps identical? {torch.equal(routing_map, replay_routing_map)}")
+
+# Clean up
+RouterReplay.clear_global_router_replay_action()
+RouterReplay.clear_global_indices()
+RouterReplay.clear_global_router_replay_instances()
+if dist.is_initialized():
+    dist.destroy_process_group()
+```
diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 5fdeda23dea..90e724580ed 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -14,6 +14,7 @@ from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name from megatron.core.transformer.cuda_graphs import is_graph_capturing from megatron.core.transformer.enums import CudaGraphScope +from megatron.core.transformer.moe.router_replay import RouterReplay from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import internal_api @@ -567,6 +568,7 @@ def topk_routing_with_score_function( score_function: str = "softmax", expert_bias: Optional[torch.Tensor] = None, fused: bool = False, + router_replay: Optional['RouterReplay'] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute the routing
probabilities and map for top-k selection with score function. @@ -584,6 +586,11 @@ def topk_routing_with_score_function( expert_bias (torch.Tensor, optional): The bias added to logits for expert routing. Defaults to None. fused (bool, optional): Whether to use the fused version. Defaults to False. + router_replay (Optional['RouterReplay']): For debugging and development, allows for + deterministic routing by replaying a previously + recorded routing sequence. + + Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: @@ -611,7 +618,7 @@ def topk_routing_with_score_function( expert_bias=expert_bias, ) - def compute_topk( + def _compute_topk( scores: torch.Tensor, topk: int, num_groups: Optional[int] = None, @@ -642,6 +649,16 @@ def compute_topk( else: return torch.topk(scores, k=topk, dim=1) + def compute_topk(scores, topk, num_groups=None, group_topk=None): + # Default behavior if no replay is active + + if router_replay is None: + return _compute_topk(scores, topk, num_groups=num_groups, group_topk=group_topk) + else: + return router_replay.get_replay_topk( + scores, topk, num_groups, group_topk, _compute_topk + ) + if score_function == "softmax": if use_pre_softmax: scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index c22ca4e8446..4e3a08d66d8 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -21,6 +21,7 @@ topk_routing_with_score_function, z_loss_func, ) +from megatron.core.transformer.moe.router_replay import RouterReplay from megatron.core.transformer.transformer_config import TransformerConfig @@ -201,6 +202,10 @@ def __init__( self.global_tokens_per_expert = None self.ga_steps = None + self.router_replay = None + if self.config.moe_enable_routing_replay: + self.router_replay = RouterReplay() + def _maintain_float32_expert_bias(self): """ Maintain the expert bias in float32. 
@@ -523,6 +528,7 @@ def routing(self, logits: torch.Tensor): score_function=self.score_function, expert_bias=self.expert_bias, fused=self.config.moe_router_fusion, + router_replay=self.router_replay, ) # Apply token dropping to probs and routing_map. diff --git a/megatron/core/transformer/moe/router_replay.py b/megatron/core/transformer/moe/router_replay.py new file mode 100644 index 00000000000..b6b8e26a0a6 --- /dev/null +++ b/megatron/core/transformer/moe/router_replay.py @@ -0,0 +1,161 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +from enum import Enum +from typing import Callable, List, Optional, Tuple + +import torch + + +class RouterReplayAction(Enum): + """ + An Enum to define the actions for router replay. + """ + + RECORD = "record" # Record the topk indices for replay + REPLAY_FORWARD = "replay_forward" # Replay the recorded topk indices for forward pass + REPLAY_BACKWARD = "replay_backward" # Replay topk indices for re-compute during backward pass + + +class RouterReplay: + """ + A class to manage the recording and replaying of MoE routing decisions. + It holds all router replay instances and provides static methods to globally + control recording and replaying. + """ + + # Static variable to hold all router instances, one per MoE layer. + global_router_replay_instances: List['RouterReplay'] = [] + + @staticmethod + def set_replay_data(all_layers_topk_indices: List[torch.Tensor]): + """ + Distributes the topk indices for all layers to their respective RouterReplay instances. + :param all_layers_topk_indices: A list of tensors, where each tensor contains the + topk indices for a specific layer. The order + must match the instantiation order of the routers. + """ + if len(all_layers_topk_indices) != len(RouterReplay.global_router_replay_instances): + raise ValueError( + f"The number of replay tensors ({len(all_layers_topk_indices)}) " + f"does not match instances ({len(RouterReplay.global_router_replay_instances)})."
+ ) + for i, router_instance in enumerate(RouterReplay.global_router_replay_instances): + router_instance.set_target_indices(all_layers_topk_indices[i]) + + @staticmethod + def get_recorded_data() -> List[torch.Tensor]: + """ + Collects the recorded topk indices from all RouterReplay instances. + :return: A list of tensors, each containing the recorded topk indices for a layer. + """ + return [ + router.get_recorded_indices() for router in RouterReplay.global_router_replay_instances + ] + + @staticmethod + def clear_global_indices(): + """Clears the recorded and target topk indices in all instances.""" + for router in RouterReplay.global_router_replay_instances: + router.clear_indices() + + @staticmethod + def set_global_router_replay_action(router_replay_action: RouterReplayAction): + """Sets the router replay action for all router instances.""" + for router in RouterReplay.global_router_replay_instances: + router.set_router_replay_action(router_replay_action) + + @staticmethod + def clear_global_router_replay_action(): + """Clears the router replay action for all router instances.""" + for router in RouterReplay.global_router_replay_instances: + router.clear_router_replay_action() + + @staticmethod + def clear_global_router_replay_instances(): + """Clear the global list of router replay instances to prevent memory leaks.""" + RouterReplay.global_router_replay_instances.clear() + + def __init__(self): + """Initializes a RouterReplay instance for a specific layer.""" + self.target_topk_idx: Optional[torch.Tensor] = None # Target topk indices for replay + self.recorded_topk_idx: Optional[torch.Tensor] = None # Recorded topk indices for replay + self.router_replay_action: Optional[RouterReplayAction] = ( + None # Router replay action for this layer + ) + self.replay_backward_list: List[torch.Tensor] = ( + [] + ) # List of tensors for backward pass replay + RouterReplay.global_router_replay_instances.append(self) + + def set_target_indices(self, topk_indices: 
torch.Tensor): + """Sets the target topk indices for replay.""" + self.target_topk_idx = topk_indices + self.replay_backward_list.append(topk_indices) + + def get_recorded_indices(self) -> Optional[torch.Tensor]: + """Returns the recorded topk indices.""" + return self.recorded_topk_idx + + def record_indices(self, topk_indices: torch.Tensor): + """Records the topk indices.""" + self.recorded_topk_idx = topk_indices + + def clear_indices(self): + """Clears the recorded and target topk indices.""" + self.recorded_topk_idx = None + self.target_topk_idx = None + self.replay_backward_list = [] + + def set_router_replay_action(self, router_replay_action: RouterReplayAction): + """Sets the router replay action for this layer.""" + self.router_replay_action = router_replay_action + + def clear_router_replay_action(self): + """Clears the router replay action for this layer.""" + self.router_replay_action = None + + def get_replay_topk( + self, + scores: torch.Tensor, + topk: int, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + default_compute_topk: Callable[ + [torch.Tensor, int, Optional[int], Optional[int]], Tuple[torch.Tensor, torch.Tensor] + ] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A wrapper for top-k computation that handles different replay actions. + + Args: + scores (torch.Tensor): The scores to compute top-k on. + topk (int): The number of top elements to select. + num_groups (Optional[int]): Number of expert groups for group-limited routing. + group_topk (Optional[int]): Number of groups to select for each token. + default_compute_topk (Callable): The default top-k computation function, which + should return a tuple of (values, indices). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the top-k values and indices. 
+ """ + if self.router_replay_action == RouterReplayAction.RECORD: + probs, top_indices = default_compute_topk( + scores, topk, num_groups=num_groups, group_topk=group_topk + ) + self.record_indices(top_indices) + return probs, top_indices + elif self.router_replay_action == RouterReplayAction.REPLAY_FORWARD: + top_indices = self.target_topk_idx + # Ensure indices are on the correct device + top_indices = top_indices.to(scores.device) + # Gather the scores for the replayed indices to get the probabilities + probs = scores.gather(1, top_indices) + return probs, top_indices + elif self.router_replay_action == RouterReplayAction.REPLAY_BACKWARD: + top_indices = self.replay_backward_list.pop(0) + # Ensure indices are on the correct device + top_indices = top_indices.to(scores.device) + # Gather the scores for the replayed indices to get the probabilities + probs = scores.gather(1, top_indices) + return probs, top_indices + else: + return default_compute_topk(scores, topk, num_groups, group_topk) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index cabad4e15d7..1e0508c5368 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -549,6 +549,9 @@ class TransformerConfig(ModelParallelConfig): moe_router_topk: int = 2 """Number of experts to route to for each token.""" + moe_enable_routing_replay: bool = False + """If True, enable the routing replay feature for MoE layers.""" + moe_router_topk_limited_devices: Optional[int] = None """Number of EP ranks to consider for each token in group-limited routing, DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk. 
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 3c2a0e52f71..9034987ef49 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -3368,6 +3368,9 @@ def _add_moe_args(parser): help='Score function for MoE TopK routing. Can be "softmax" or "sigmoid".') group.add_argument('--moe-router-topk', type=int, default=2, help='Number of experts to route to for each token. The default is 2.') + group.add_argument('--moe-enable-routing-replay', action='store_true', + help='Enable routing replay for MoE routers. When enabled, each router creates ' + 'a RouterReplay instance that can record or replay routing decisions.') group.add_argument('--moe-router-pre-softmax', action='store_true', help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.') group.add_argument('--moe-router-num-groups', type=int, default=None, diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 5680751f63f..3c7ae93a17c 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -191,6 +191,7 @@ "moe_token_dropping": False, "moe_use_legacy_grouped_gemm": False, "moe_z_loss_coeff": None, + "moe_enable_routing_replay": False, "mrope_section": None, "mtp_loss_scaling_factor": 0.1, "mtp_num_layers": None, diff --git a/tests/unit_tests/transformer/moe/test_router_replay.py b/tests/unit_tests/transformer/moe/test_router_replay.py new file mode 100644 index 00000000000..840fc0fd269 --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_router_replay.py @@ -0,0 +1,95 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+import pytest +import torch + +from megatron.core.transformer.moe.moe_utils import topk_routing_with_score_function +from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction + + +def setup_function(): + RouterReplay.global_router_replay_instances.clear() + + +def teardown_function(): + RouterReplay.global_router_replay_instances.clear() + + +def test_record_mode_with_topk_routing_softmax_post(): + rr = RouterReplay() + rr.set_router_replay_action(RouterReplayAction.RECORD) + logits = torch.randn(4, 6) + probs, routing_map = topk_routing_with_score_function( + logits=logits, topk=2, use_pre_softmax=False, router_replay=rr, score_function="softmax" + ) + recorded = rr.get_recorded_indices() + expected_idx = torch.topk(logits, k=2, dim=1).indices + assert recorded is not None + assert torch.equal(recorded, expected_idx) + assert probs.shape == (4, 6) + assert routing_map.shape == (4, 6) + assert routing_map.sum(dim=1).eq(2).all() + + +def test_replay_forward_with_topk_routing_softmax_pre(): + rr = RouterReplay() + rr.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + logits = torch.randn(3, 5) + target = torch.tensor([[1, 2], [0, 3], [2, 4]], dtype=torch.long) + rr.set_target_indices(target) + probs, routing_map = topk_routing_with_score_function( + logits=logits, topk=2, use_pre_softmax=True, router_replay=rr, score_function="softmax" + ) + assert routing_map.sum(dim=1).eq(2).all() + scores = torch.softmax(logits, dim=-1) + assert torch.equal(probs.gather(1, target), scores.gather(1, target)) + + +def test_replay_forward_with_topk_routing_softmax_post(): + rr = RouterReplay() + rr.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + logits = torch.randn(3, 6) + target = torch.tensor([[1, 2], [0, 5], [3, 4]], dtype=torch.long) + rr.set_target_indices(target) + probs, routing_map = topk_routing_with_score_function( + logits=logits, topk=2, use_pre_softmax=False, router_replay=rr, score_function="softmax" + ) + 
selected = torch.softmax(logits.gather(1, target), dim=-1) + assert torch.equal(probs.gather(1, target), selected) + assert routing_map.sum(dim=1).eq(2).all() + + +def test_global_set_get_clear_indices(): + r1 = RouterReplay() + r2 = RouterReplay() + t1 = torch.tensor([[0, 1]], dtype=torch.long) + t2 = torch.tensor([[1, 0]], dtype=torch.long) + RouterReplay.set_replay_data([t1, t2]) + assert torch.equal(r1.target_topk_idx, t1) + assert torch.equal(r2.target_topk_idx, t2) + r1.record_indices(t1) + r2.record_indices(t2) + rec = RouterReplay.get_recorded_data() + assert len(rec) == 2 + assert torch.equal(rec[0], t1) + assert torch.equal(rec[1], t2) + RouterReplay.clear_global_indices() + assert r1.target_topk_idx is None and r2.target_topk_idx is None + assert r1.get_recorded_indices() is None and r2.get_recorded_indices() is None + + +def test_global_action_set_and_clear(): + r1 = RouterReplay() + r2 = RouterReplay() + RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD) + assert r1.router_replay_action == RouterReplayAction.REPLAY_FORWARD + assert r2.router_replay_action == RouterReplayAction.REPLAY_FORWARD + RouterReplay.clear_global_router_replay_action() + assert r1.router_replay_action is None and r2.router_replay_action is None + + +def test_set_replay_data_length_mismatch(): + _ = RouterReplay() + with pytest.raises(ValueError): + RouterReplay.set_replay_data( + [torch.tensor([[0, 1]], dtype=torch.long), torch.tensor([[1, 0]], dtype=torch.long)] + )