diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py
index eb96ac0579..e9693432da 100644
--- a/benchmarks/routines/gemm.py
+++ b/benchmarks/routines/gemm.py
@@ -1038,13 +1038,12 @@ def testMmFp4(args):
input_fp4, input_inv_s = flashinfer.mxfp4_quantize(input)
mat2_fp4, mat2_inv_s = flashinfer.mxfp4_quantize(mat2)
- if "trtllm" in backends:
- mat2_fp4_trtllm, mat2_inv_s_trtllm = flashinfer.nvfp4_quantize(
- mat2,
- global_sf_mat2,
- sfLayout=flashinfer.SfLayout.layout_128x4,
- do_shuffle=True,
- )
+ mat2_fp4_trtllm, mat2_inv_s_trtllm = flashinfer.nvfp4_quantize(
+ mat2,
+ global_sf_mat2,
+ sfLayout=flashinfer.SfLayout.layout_128x4,
+ do_shuffle=True,
+ )
if args.verbose >= 2:
print(f"[VVERBOSE] {input_fp4.shape = }")
diff --git a/docs/logging.rst b/docs/logging.rst
index c3c2c83d8f..b5b866b128 100644
--- a/docs/logging.rst
+++ b/docs/logging.rst
@@ -12,7 +12,7 @@ Enable logging using two environment variables:
.. code-block:: bash
- # Set logging level (0-5)
+ # Set logging level (0-10)
export FLASHINFER_LOGLEVEL=3
# Set log destination (default is stdout)
@@ -45,6 +45,10 @@ Logging Levels
- Statistics
- Level 3 + tensor statistics (min, max, mean, NaN/Inf counts)
- Numerical analysis
+ * - **10**
+ - Flight Recorder - Full Input/Output Dumps
+ - Level 5 + dumps all input/output tensors to ``.pt`` (or ``.safetensors``) files
+ - Full Reproducibility / Debugging
Environment Variables
---------------------
@@ -63,12 +67,156 @@ Main Configuration
* - ``FLASHINFER_LOGLEVEL``
- int
- 0
- - Logging level (0, 1, 3, 5)
+ - Logging level (0, 1, 3, 5, 10)
* - ``FLASHINFER_LOGDEST``
- str
- ``stdout``
- Log destination: ``stdout``, ``stderr``, or file path
+Dump Configuration (Level 10)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When ``FLASHINFER_LOGLEVEL`` is set to 10, the following environment variables can be used to configure the dump behavior:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 30 15 15 40
+
+ * - Variable
+ - Type
+ - Default
+ - Description
+ * - ``FLASHINFER_DUMP_DIR``
+ - str
+ - ``flashinfer_dumps``
+ - Directory to save dump files
+ * - ``FLASHINFER_DUMP_MAX_SIZE_GB``
+ - float
+ - 20
+ - Maximum size of dump directory in GB
+ * - ``FLASHINFER_DUMP_MAX_COUNT``
+ - int
+ - 1000
+ - Maximum number of API calls to dump
+ * - ``FLASHINFER_DUMP_INCLUDE``
+ - str
+ - (empty)
+ - Comma-separated patterns to include (fnmatch-style)
+ * - ``FLASHINFER_DUMP_EXCLUDE``
+ - str
+ - (empty)
+ - Comma-separated patterns to exclude (fnmatch-style)
+ * - ``FLASHINFER_DUMP_SAFETENSORS``
+ - int
+ - 0
+ - Set to 1 to use safetensors format (no pickle, but loses stride info)
+
+SafeTensors Format (Optional)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+By default, tensors are saved using ``torch.save()`` which preserves tensor stride and contiguity information.
+For faster, pickle-free serialization, you can enable safetensors format:
+
+.. code-block:: bash
+
+ export FLASHINFER_DUMP_SAFETENSORS=1
+
+.. warning::
+ SafeTensors does NOT preserve tensor strides or non-contiguity.
+ All tensors are saved as contiguous. Use the default ``torch.save`` format
+ if stride preservation is important for your debugging.
+
+**Comparison**:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 25 35 40
+
+ * - Aspect
+ - torch.save (default)
+ - safetensors
+ * - Speed
+ - Standard
+ - Faster
+ * - Safety
+ - Uses pickle
+ - No pickle (safer)
+ * - Stride preservation
+ - ✅ Yes
+ - ❌ No (contiguous only)
+ * - File extension
+ - ``.pt``
+ - ``.safetensors``
+ * - Dependency
+   - ``torch``
+ - Requires ``pip install safetensors``
+
+**Replay is format-agnostic**: The replay command automatically detects the format based on file extension.
+
+Dump Filtering (Include/Exclude)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Use ``FLASHINFER_DUMP_INCLUDE`` and ``FLASHINFER_DUMP_EXCLUDE`` to control which API calls are dumped.
+This is especially useful when running end-to-end inference with many API calls but you only care about specific ones.
+
+**Pattern Syntax** (fnmatch-style):
+
+- ``*`` matches any number of characters
+- ``?`` matches a single character
+- Matching is case-sensitive
+- For class methods, the function name is formatted as ``ClassName.method_name``
+
+**Filter Logic**:
+
+1. If ``FLASHINFER_DUMP_INCLUDE`` is set, only APIs matching at least one pattern are dumped
+2. If ``FLASHINFER_DUMP_EXCLUDE`` is set, APIs matching any pattern are skipped
+3. Both can be combined: include filter is applied first, then exclude filter
+
+**Examples**:
+
+.. code-block:: bash
+
+ # Only dump decode-related APIs
+ export FLASHINFER_DUMP_INCLUDE="*decode*"
+
+ # Dump everything except __init__ and plan methods
+ export FLASHINFER_DUMP_EXCLUDE="*.__init__,*.plan"
+
+ # Only dump run() methods from wrapper classes
+ export FLASHINFER_DUMP_INCLUDE="*Wrapper.run"
+
+ # Dump all single_* APIs except prefill
+ export FLASHINFER_DUMP_INCLUDE="single_*"
+ export FLASHINFER_DUMP_EXCLUDE="*prefill*"
+
+ # Only dump a specific wrapper's run method
+ export FLASHINFER_DUMP_INCLUDE="BatchDecodeWithPagedKVCacheWrapper.run"
+
+ # Dump FP8 APIs but not quantization steps
+ export FLASHINFER_DUMP_INCLUDE="*fp8*,*FP8*"
+ export FLASHINFER_DUMP_EXCLUDE="*quantize*"
+
+**Common Patterns**:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 40 60
+
+ * - Pattern
+ - Matches
+ * - ``*decode*``
+ - ``single_decode_with_kv_cache``, ``BatchDecodeWithPagedKVCacheWrapper.run``
+ * - ``*Wrapper.run``
+ - ``BatchDecodeWithPagedKVCacheWrapper.run``, ``BatchPrefillWithPagedKVCacheWrapper.run``
+ * - ``*.__init__``
+ - All wrapper ``__init__`` methods
+ * - ``*.plan``
+ - All wrapper ``plan`` methods
+ * - ``mm_fp8``
+ - Exact match for ``mm_fp8`` function
+ * - ``single_*``
+ - ``single_decode_with_kv_cache``, ``single_prefill_with_kv_cache``
+
Process ID Substitution
^^^^^^^^^^^^^^^^^^^^^^^^
@@ -116,3 +264,415 @@ Level 0 has zero overhead
^^^^^^^^^^^^^^^^^^^^^^^^^^^
At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead.
+
+Flight Recorder & Replay
+------------------------
+
+FlashInfer includes a "Flight Recorder" mode (Level 10) that captures inputs/outputs for reproducibility.
+
+Dump Directory Structure
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+When Level 10 logging is enabled, FlashInfer creates the following structure:
+
+.. code-block:: text
+
+ FLASHINFER_DUMP_DIR/
+ ├── session.jsonl # Central log: one line per event (quick scanning)
+ ├── 20250108_143216_802_pid12345_mm_fp8_call0001/
+ │ ├── metadata.jsonl # Per-dump metadata (JSONL format)
+ │ ├── inputs.pt # Input tensors (or .safetensors if enabled)
+ │ └── outputs.pt # Output tensors (or .safetensors if enabled)
+ ├── 20250108_143216_868_pid12345_single_decode_call0001/
+ │ ├── metadata.jsonl
+ │ ├── inputs.pt # (or .safetensors)
+ │ └── outputs.pt # (or .safetensors)
+ └── ...
+
+**JSONL Format**: Both ``session.jsonl`` and ``metadata.jsonl`` use `JSON Lines <https://jsonlines.org/>`_ format
+(one JSON object per line). This enables:
+
+- **Crash-safe logging**: Each API call appends two lines (inputs_saved, then completed)
+- **Quick scanning**: Use ``session.jsonl`` to browse all recorded calls without reading subdirectories
+- **Streaming reads**: Process records line-by-line for large sessions
+
+**Per-dump metadata.jsonl**:
+
+- Line 1: Written **before** execution (``execution_status: "inputs_saved"``)
+- Line 2: Appended **after** successful execution (``execution_status: "completed"``)
+
+If a crash occurs, only line 1 will be present, preserving the inputs for debugging.
+
+**Central session.jsonl**:
+
+One-stop log of all API calls. Use standard tools to filter and analyze:
+
+.. code-block:: bash
+
+ # Enable Flight Recorder (Metadata + Tensors)
+ export FLASHINFER_LOGLEVEL=10
+ export FLASHINFER_DUMP_DIR=./my_dumps
+
+ # Run your application
+ python3 benchmarks/flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command --use_cupti --no_cuda_graph --num_iters 5
+ ... output redacted ...
+
+ # Replay recorded calls
+ export FLASHINFER_LOGLEVEL=0 # 1 for more detailed replay results.
+ flashinfer replay --dir ./my_dumps
+ # or
+ python -m flashinfer replay --dir ./my_dumps
+
+ [1] nvfp4_quantize (20251204_143216_802_pid12345_nvfp4_quantize_call0001): ✅ Passed
+ [2] fp4_quantize (20251204_143216_868_pid12345_fp4_quantize_call0001): ✅ Passed
+ [3] nvfp4_quantize (20251204_143216_949_pid12345_nvfp4_quantize_call0002): ✅ Passed
+ [4] fp4_quantize (20251204_143217_003_pid12345_fp4_quantize_call0002): ✅ Passed
+ [5] mm_fp4 (20251204_143217_178_pid12345_mm_fp4_call0001): ✅ Passed
+ [6] mm_fp4 (20251204_143217_346_pid12345_mm_fp4_call0002): ✅ Passed
+ [7] mm_fp4 (20251204_143217_427_pid12345_mm_fp4_call0003): ✅ Passed
+ [8] mm_fp4 (20251204_143217_475_pid12345_mm_fp4_call0004): ✅ Passed
+ [9] mm_fp4 (20251204_143217_510_pid12345_mm_fp4_call0005): ✅ Passed
+ [10] mm_fp4 (20251204_143217_551_pid12345_mm_fp4_call0006): ✅ Passed
+ [11] mm_fp4 (20251204_143217_591_pid12345_mm_fp4_call0007): ✅ Passed
+ [12] mm_fp4 (20251204_143217_631_pid12345_mm_fp4_call0008): ✅ Passed
+ [13] mm_fp4 (20251204_143217_672_pid12345_mm_fp4_call0009): ✅ Passed
+ [14] mm_fp4 (20251204_143217_708_pid12345_mm_fp4_call0010): ✅ Passed
+ [15] mm_fp4 (20251204_143217_769_pid12345_mm_fp4_call0011): ✅ Passed
+ [16] mm_fp4 (20251204_143217_812_pid12345_mm_fp4_call0012): ✅ Passed
+ [17] mm_fp4 (20251204_143217_852_pid12345_mm_fp4_call0013): ✅ Passed
+ [18] mm_fp4 (20251204_143217_904_pid12345_mm_fp4_call0014): ✅ Passed
+ [19] mm_fp4 (20251204_143218_153_pid12345_mm_fp4_call0015): ✅ Passed
+ [20] mm_fp4 (20251204_143218_390_pid12345_mm_fp4_call0016): ✅ Passed
+ [21] mm_fp4 (20251204_143218_627_pid12345_mm_fp4_call0017): ✅ Passed
+ [22] mm_fp4 (20251204_143218_862_pid12345_mm_fp4_call0018): ✅ Passed
+
+ Summary: 22 passed, 0 failed/mismatch
+
+Python-Based Replay Examples
+----------------------------
+
+The following examples demonstrate how to use Level 10 logging to dump and replay API calls programmatically using Python.
+
+Example 1: ``bmm_fp8`` - Simple Function Call
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+**Producer Script** (``bmm_fp8_producer.py``):
+
+This script initializes tensors, calls ``bmm_fp8``, and dumps the inputs/outputs to disk.
+
+.. code-block:: python
+
+ """
+ Producer script: Run bmm_fp8 with Level 10 logging to dump tensors.
+
+ Usage:
+ FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./bmm_fp8_dumps python bmm_fp8_producer.py
+ """
+ import torch
+ from flashinfer import bmm_fp8
+
+ def to_float8(x, dtype=torch.float8_e4m3fn):
+ """Convert tensor to FP8 with per-tensor scaling."""
+ finfo = torch.finfo(dtype)
+ min_val, max_val = x.aminmax()
+ amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+ scale = finfo.max / amax
+ x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+ return x_scl_sat.to(dtype), scale.float().reciprocal()
+
+ # Parameters
+ b, m, n, k = 4, 64, 128, 256
+ input_dtype = torch.float8_e4m3fn
+ mat2_dtype = torch.float8_e4m3fn
+ res_dtype = torch.bfloat16
+
+ # Create input tensors
+ input_bf16 = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
+ input_fp8, input_inv_s = to_float8(input_bf16, dtype=input_dtype)
+
+ # mat2: row major -> column major (transposed)
+ mat2_bf16 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+ mat2_fp8, mat2_inv_s = to_float8(mat2_bf16, dtype=mat2_dtype)
+
+ # Pre-allocate output
+ res = torch.empty([b, m, n], device="cuda", dtype=res_dtype)
+
+ # Call bmm_fp8 - this will be logged/dumped at Level 10
+ bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res, backend="cublas")
+
+ # Print a small portion of the output for verification
+ print("Output shape:", res.shape)
+ print("Output[0, :3, :3]:")
+ print(res[0, :3, :3])
+
+**Reproducer Script** (``bmm_fp8_reproducer.py``):
+
+This script loads the dumped tensors and replays the ``bmm_fp8`` call.
+
+.. code-block:: python
+
+ """
+ Reproducer script: Load dumped tensors and replay bmm_fp8.
+
+ Usage:
+ python bmm_fp8_reproducer.py
+ """
+ import torch
+ from pathlib import Path
+ from flashinfer import bmm_fp8
+ from flashinfer.api_logging import replay_from_dump
+
+ DUMP_DIR = "./bmm_fp8_dumps"
+
+ # Find the bmm_fp8 dump directory (should be the only one or the latest)
+ dump_path = Path(DUMP_DIR)
+ bmm_dumps = sorted([d for d in dump_path.iterdir() if d.is_dir() and "bmm_fp8" in d.name])
+ latest_dump = bmm_dumps[-1] # Use the latest dump
+ print(f"Loading dump from: {latest_dump}")
+
+ # Use replay_from_dump to load inputs and optionally execute
+ result = replay_from_dump(
+ str(latest_dump),
+ compare_outputs=True, # Load expected outputs for comparison
+ device="cuda",
+ run=False, # We'll call the function manually below
+ )
+
+ # Extract the loaded arguments - args contains all positional args including the output tensor
+ args = result["args"]
+ kwargs = result["kwargs"]
+ expected_tensors = result.get("expected_tensors", {})
+
+ # Replay the call - args already contains (input, mat2, input_inv_s, mat2_inv_s, dtype, out)
+ res = bmm_fp8(*args, **kwargs)
+
+ # Print the same portion for comparison
+ print("Replayed output shape:", res.shape)
+ print("Replayed output[0, :3, :3]:")
+ print(res[0, :3, :3])
+
+ # Compare with expected output if available
+ if "result" in expected_tensors:
+ expected = expected_tensors["result"]
+ if torch.allclose(res, expected, rtol=1e-3, atol=1e-3):
+ print("\n✅ Output matches expected result!")
+ else:
+ diff = (res - expected).abs().max().item()
+ print(f"\n❌ Output mismatch! Max diff: {diff}")
+
+Example 2: ``BatchDecodeWithPagedKVCacheWrapper`` - Stateful Wrapper Class
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+**Producer Script** (``batch_decode_producer.py``):
+
+This script demonstrates logging with a stateful wrapper class that requires ``__init__``, ``plan``, and ``run`` calls.
+
+.. code-block:: python
+
+ """
+ Producer script: Run BatchDecodeWithPagedKVCacheWrapper with Level 10 logging.
+
+ Usage:
+ FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./batch_decode_dumps python batch_decode_producer.py
+ """
+ import torch
+ import flashinfer
+
+ # Parameters
+ batch_size = 4
+ kv_len = 512
+ page_size = 16
+ num_kv_heads = 4
+ num_qo_heads = 32
+ head_dim = 128
+ kv_layout = "NHD"
+
+ # Create query tensor
+ q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
+
+ # Create paged KV cache
+ num_pages_per_seq = (kv_len + page_size - 1) // page_size
+ total_num_pages = num_pages_per_seq * batch_size
+ kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] # NHD layout
+ kv_data = torch.randn(*kv_shape, device="cuda", dtype=torch.float16)
+
+ # Create index tensors
+ kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * num_pages_per_seq
+ kv_indices = torch.arange(0, total_num_pages, device="cuda", dtype=torch.int32)
+ kv_last_page_len = torch.full(
+ (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda"
+ )
+
+ # Create workspace and wrapper - __init__ will be logged
+ workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
+ wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout)
+
+ # Plan - will be logged
+ wrapper.plan(
+ kv_indptr,
+ kv_indices,
+ kv_last_page_len,
+ num_qo_heads,
+ num_kv_heads,
+ head_dim,
+ page_size,
+ data_type=torch.float16,
+ q_data_type=torch.float16,
+ )
+
+ # Run - will be logged
+ output, lse = wrapper.run(q, kv_data, return_lse=True)
+
+ # Print a small portion of the output
+ print("Output shape:", output.shape)
+ print("Output[0, :3, :3]:")
+ print(output[0, :3, :3])
+ print("\nLSE shape:", lse.shape)
+ print("LSE[0, :5]:", lse[0, :5])
+
+**Reproducer Script** (``batch_decode_reproducer.py``):
+
+This script demonstrates replaying a sequence of stateful API calls.
+
+.. code-block:: python
+
+ """
+ Reproducer script: Replay BatchDecodeWithPagedKVCacheWrapper calls.
+
+ Usage:
+ python batch_decode_reproducer.py
+ """
+ import torch
+ from pathlib import Path
+ from flashinfer.api_logging import replay_sequence
+
+ DUMP_DIR = "./batch_decode_dumps"
+
+ # replay_sequence handles stateful objects automatically via object_registry
+ # It will:
+ # 1. Replay __init__ to create the wrapper instance
+ # 2. Replay plan() on the same instance
+ # 3. Replay run() on the same instance and compare outputs
+ results = replay_sequence(DUMP_DIR, device="cuda")
+
+ # Print summary
+ passed = 0
+ failed = 0
+ for i, res in enumerate(results):
+ func_name = res.get("metadata", {}).get("function_name", "unknown")
+ dump_dir = Path(res.get("dump_dir", "")).name
+
+ if "error" in res:
+ print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Error: {res['error']}")
+ failed += 1
+ elif res.get("comparison_match", True):
+ print(f"[{i+1}] {func_name} ({dump_dir}): ✅ Passed")
+ passed += 1
+ else:
+ print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Mismatch")
+ failed += 1
+
+ print(f"\nSummary: {passed} passed, {failed} failed")
+
+ # For manual inspection, you can also access individual results
+ # Find the 'run' call result (usually the last non-init, non-plan call)
+ for res in results:
+ func_name = res.get("metadata", {}).get("function_name", "")
+ if "run" in func_name and "execution_result" in res:
+ output = res["execution_result"]
+ if isinstance(output, tuple):
+ output_tensor, lse = output
+ print("\nReplayed output[0, :3, :3]:")
+ print(output_tensor[0, :3, :3])
+ print("Replayed LSE[0, :5]:", lse[0, :5])
+ break
+
+Manual Replay Without ``replay_from_dump``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For more control, you can manually load the dumped tensors:
+
+.. note::
+ This example assumes the default ``torch.save`` format (``.pt`` files).
+ If dumps were created with ``FLASHINFER_DUMP_SAFETENSORS=1``, use
+ ``safetensors.torch.load_file()`` instead of ``torch.load()``.
+
+.. code-block:: python
+
+ """
+ Manual replay: Load tensors directly from .pt files.
+ """
+ import json
+ import torch
+ from pathlib import Path
+ from flashinfer import bmm_fp8
+
+ # Path is an example, replace with the actual path.
+ dump_dir = Path("./bmm_fp8_dumps/20250108_103217_012_pid12345_bmm_fp8_call0001")
+
+ # Load metadata from JSONL (read last line for most complete state)
+ with open(dump_dir / "metadata.jsonl") as f:
+ lines = [line.strip() for line in f if line.strip()]
+ metadata = json.loads(lines[-1]) # Last line has completed state
+
+ print(f"Function: {metadata['function_name']}")
+ print(f"Module: {metadata['module']}")
+ print(f"Status: {metadata['execution_status']}")
+ print(f"Input tensors: {metadata['tensor_info']['input_tensor_keys']}")
+
+ # Load input tensors
+ inputs = torch.load(dump_dir / "inputs.pt", map_location="cuda")
+
+ # Load expected outputs (if execution completed successfully)
+ outputs_path = dump_dir / "outputs.pt"
+ if outputs_path.exists():
+ expected = torch.load(outputs_path, map_location="cuda")
+ print(f"Output tensors: {list(expected.keys())}")
+
+ # Tensors are ready to use - reconstruct the call as needed
+ for key, tensor in inputs.items():
+ print(f" {key}: shape={tensor.shape}, dtype={tensor.dtype}")
+
+Scanning Session History
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+Use the central ``session.jsonl`` to quickly scan all recorded API calls:
+
+.. code-block:: python
+
+ """
+ Scan session.jsonl for quick overview of recorded calls.
+ """
+ import json
+ from pathlib import Path
+ from collections import Counter
+
+ dump_root = Path("./my_dumps")
+ session_file = dump_root / "session.jsonl"
+
+ # Read all records
+ records = []
+ with open(session_file) as f:
+ for line in f:
+ if line.strip():
+ records.append(json.loads(line))
+
+ # Filter to completed calls only
+ completed = [r for r in records if r["execution_status"] == "completed"]
+ print(f"Total completed calls: {len(completed)}")
+
+ # Count by function name
+ func_counts = Counter(r["function_name"] for r in completed)
+ print("\nCalls by function:")
+ for func, count in func_counts.most_common():
+ print(f" {func}: {count}")
+
+ # Find calls that didn't complete (potential crashes)
+ inputs_only = [r for r in records if r["execution_status"] == "inputs_saved"]
+ # Group by dump_dir to find incomplete calls
+ completed_dirs = {r["dump_dir"] for r in completed}
+ incomplete = [r for r in inputs_only if r["dump_dir"] not in completed_dirs]
+ if incomplete:
+ print(f"\n⚠️ Found {len(incomplete)} incomplete calls (potential crashes):")
+ for r in incomplete:
+ print(f" - {r['function_name']} at {r['dump_dir']}")
diff --git a/flashinfer/__main__.py b/flashinfer/__main__.py
index 4e822da19f..eed6671359 100644
--- a/flashinfer/__main__.py
+++ b/flashinfer/__main__.py
@@ -383,5 +383,78 @@ def export_compile_commands_cmd(path, output):
click.secho(f"❌ Failed to write compile commands: {e}", fg="red")
+@cli.command("replay")
+@click.option(
+ "--dir",
+ "dump_dir",
+ required=True,
+ help="Directory containing dump files (or root directory of session)",
+)
+def replay_cmd(dump_dir):
+ """Replay API calls from dump directory"""
+ from .api_logging import replay_sequence, replay_from_dump
+
+ device = "cuda"
+
+ if not os.path.exists(dump_dir):
+ click.secho(f"❌ Directory not found: {dump_dir}", fg="red")
+ return
+
+ # Check if this is a single dump or a session / sequence root
+ is_single_dump = os.path.exists(os.path.join(dump_dir, "metadata.jsonl"))
+
+ try:
+ if is_single_dump:
+ click.secho(f"Replaying single dump from {dump_dir}...", fg="cyan")
+ result = replay_from_dump(
+ dump_dir, compare_outputs=True, device=device, run=True
+ )
+ if result.get("comparison_match"):
+ click.secho("✅ Replay passed (outputs matched)", fg="green")
+ elif result.get("execution_error"):
+ click.secho(
+ f"❌ Execution failed: {result['execution_error']}", fg="red"
+ )
+ else:
+ click.secho("⚠️ Replay finished but outputs did not match", fg="yellow")
+ else:
+ # Session / sequence replay
+ click.secho(f"Replaying session from {dump_dir}...", fg="cyan")
+ results = replay_sequence(dump_dir, device=device)
+
+ passed = 0
+ failed = 0
+
+ for i, res in enumerate(results):
+ dump_name = (
+ os.path.basename(res.get("dump_dir", ""))
+ if "dump_dir" in res
+ else f"call_{i + 1}"
+ )
+ # If replay_from_dump returned successfully, metadata might have the name
+ if "metadata" in res and "function_name" in res["metadata"]:
+ func_name = res["metadata"]["function_name"]
+ dump_name = f"{func_name} ({dump_name})"
+
+ if "error" in res:
+ click.secho(
+ f"[{i + 1}] {dump_name}: ❌ Error: {res['error']}", fg="red"
+ )
+ failed += 1
+ elif res.get("comparison_match"):
+ click.secho(f"[{i + 1}] {dump_name}: ✅ Passed", fg="green")
+ passed += 1
+ else:
+ click.secho(f"[{i + 1}] {dump_name}: ⚠️ Mismatch", fg="yellow")
+ failed += 1
+
+ click.secho(
+ f"\nSummary: {passed} passed, {failed} failed/mismatch", fg="white"
+ )
+
+ except Exception as e:
+ click.secho(f"❌ Replay failed: {e}", fg="red")
+
+
if __name__ == "__main__":
cli()
diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py
index 734d6bae28..bfd4571531 100644
--- a/flashinfer/api_logging.py
+++ b/flashinfer/api_logging.py
@@ -15,13 +15,18 @@
"""
import enum
+import fnmatch
import functools
import inspect
+import json
import logging
import os
import sys
-from typing import Any, Callable
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable, Dict, Tuple, Optional
import contextlib
+import importlib
import torch
@@ -42,6 +47,27 @@ def _substitute_process_id(path: str) -> str:
_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))
_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout"))
+# Configuration for Level 10 tensor dumping
+_DUMP_DIR = os.environ.get("FLASHINFER_DUMP_DIR", "flashinfer_dumps")
+_DUMP_MAX_SIZE_GB = float(os.environ.get("FLASHINFER_DUMP_MAX_SIZE_GB", "20"))
+_DUMP_MAX_COUNT = int(os.environ.get("FLASHINFER_DUMP_MAX_COUNT", "1000"))
+
+# Dump filtering: include/exclude patterns (fnmatch-style, comma-separated)
+# Examples: "*decode*,*prefill*" or "BatchDecodeWrapper.run,mm_fp8"
+_DUMP_INCLUDE = os.environ.get("FLASHINFER_DUMP_INCLUDE", "")
+_DUMP_EXCLUDE = os.environ.get("FLASHINFER_DUMP_EXCLUDE", "")
+_DUMP_INCLUDE_PATTERNS = [p.strip() for p in _DUMP_INCLUDE.split(",") if p.strip()]
+_DUMP_EXCLUDE_PATTERNS = [p.strip() for p in _DUMP_EXCLUDE.split(",") if p.strip()]
+
+# SafeTensors format option (default: use torch.save which preserves stride/contiguity)
+_DUMP_SAFETENSORS = os.environ.get("FLASHINFER_DUMP_SAFETENSORS", "0") == "1"
+
+# Global tracking for dump limits (reset per process)
+_dump_count = 0
+_dump_total_size_bytes = 0
+_dump_call_counter = {} # Track call count per function
+_session_jsonl_initialized = False # Track if session.jsonl header was written
+
# Create logger using Python's logging library
_logger = logging.getLogger("flashinfer.api")
@@ -82,11 +108,947 @@ def _setup_logger():
def _get_timestamp() -> str:
"""Get current timestamp in the format [YYYY-MM-DD HH:MM:SS]."""
- from datetime import datetime
-
return datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
+def _warn_dump():
+ """Warn users about security implications of Level 10 logging."""
+ if _API_LOG_LEVEL >= 10:
+ print("=" * 80)
+ print(
+ "WARNING: FlashInfer API Logging is set to Level 10 (Tensor Dumping).\n"
+ "This will dump ALL input and outputs including tensors for FlashInfer APIs to disk in\n"
+ "the configured dump directory. Ensure that you are NOT processing sensitive data\n"
+ "or that the dump directory is secure. To disable dumping, unset FLASHINFER_LOGLEVEL or\n"
+ "set it to below 10. For more information, see https://docs.flashinfer.ai/logging.html"
+ )
+ print(f"Current dump directory is: {_DUMP_DIR}")
+ if _DUMP_SAFETENSORS:
+ print(
+ "⚠️ SAFETENSORS mode enabled: tensor stride/non-contiguity will NOT be preserved.\n"
+ " Tensors will be saved as contiguous. Use torch.save (default) to preserve strides."
+ )
+ if _DUMP_INCLUDE_PATTERNS:
+ print(f"Include filter: {_DUMP_INCLUDE_PATTERNS}")
+ if _DUMP_EXCLUDE_PATTERNS:
+ print(f"Exclude filter: {_DUMP_EXCLUDE_PATTERNS}")
+ print("=" * 80)
+
+
+def _should_dump_function(func_name: str) -> bool:
+ """
+ Check if a function should be dumped based on include/exclude filters.
+
+ Uses fnmatch-style patterns (wildcards: * for any chars, ? for single char).
+ Matching is case-sensitive.
+
+ Parameters
+ ----------
+ func_name : str
+ The function name to check. For class methods, this is formatted as
+ "ClassName.method_name" (e.g., "BatchDecodeWrapper.run").
+
+ Returns
+ -------
+ bool
+ True if the function should be dumped, False otherwise.
+
+ Filter Logic
+ ------------
+ 1. If FLASHINFER_DUMP_INCLUDE is set:
+ - Function must match at least one include pattern
+ - If it doesn't match any, return False (skip dump)
+ 2. If FLASHINFER_DUMP_EXCLUDE is set:
+ - If function matches any exclude pattern, return False (skip dump)
+ 3. Otherwise, return True (dump the function)
+ """
+ # If include patterns are specified, func must match at least one
+ if _DUMP_INCLUDE_PATTERNS:
+ if not any(fnmatch.fnmatch(func_name, pat) for pat in _DUMP_INCLUDE_PATTERNS):
+ return False
+
+ # If exclude patterns are specified, func must not match any
+ if _DUMP_EXCLUDE_PATTERNS:
+ if any(fnmatch.fnmatch(func_name, pat) for pat in _DUMP_EXCLUDE_PATTERNS):
+ return False
+
+ return True
+
+
+def _append_to_jsonl(filepath: Path, record: Dict[str, Any]) -> None:
+ """
+ Append a JSON record as a single line to a JSONL file.
+
+ Parameters
+ ----------
+ filepath : Path
+ Path to the JSONL file
+ record : Dict[str, Any]
+ Record to append (will be serialized as single-line JSON)
+ """
+ with open(filepath, "a") as f:
+ f.write(json.dumps(record) + "\n")
+
+
+def _read_jsonl_last_record(filepath: Path) -> Optional[Dict[str, Any]]:
+ """
+ Read the last record from a JSONL file.
+
+ For metadata.jsonl, this returns the most complete state (completed if available,
+ otherwise inputs_saved).
+
+ Parameters
+ ----------
+ filepath : Path
+ Path to the JSONL file
+
+ Returns
+ -------
+ Optional[Dict[str, Any]]
+ The last record, or None if file is empty/doesn't exist
+ """
+ if not filepath.exists():
+ return None
+
+ last_line = None
+ with open(filepath, "r") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ last_line = line
+
+ if last_line:
+ return json.loads(last_line)
+ return None
+
+
+def _get_tensor_size_bytes(tensor: torch.Tensor) -> int:
+ """Calculate the size of a tensor in bytes."""
+ return tensor.element_size() * tensor.nelement()
+
+
+def _serialize_value(value: Any) -> Any:
+ """
+ Convert a non-tensor value to a JSON-serializable format for metadata.
+
+ This function is intended for serializing non-tensor arguments/values
+ that are used in API input or output metadata. Tensor arguments are not handled here.
+ """
+ try:
+ if isinstance(value, torch.dtype):
+ # Special handling for torch.dtype
+ return {
+ "type": "torch.dtype",
+ "value": str(value), # e.g., "torch.bfloat16"
+ }
+ elif isinstance(value, enum.Enum):
+ return {
+ "type": "enum",
+ "name": f"{type(value).__name__}.{value.name}",
+ "value": value.value,
+ }
+ elif isinstance(value, (int, float, str, bool, type(None))):
+ return value
+ elif isinstance(value, (list, tuple, dict)):
+ return {
+ "type": type(value).__name__,
+ "value": str(value)[:1000],
+ } # Truncate long structures
+ else:
+ return {
+ "type": type(value).__name__,
+ "repr": str(value)[:1000],
+ }
+ except Exception:
+ return {
+ "type": type(value).__name__,
+ "repr": "",
+ }
+
+
+def _extract_tensors_and_metadata(
+ args: tuple, kwargs: dict
+) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+ """
+ Extract tensors and non-tensor metadata from function arguments.
+
+ Tensors are moved to CPU but preserve their stride/contiguity information.
+
+ Returns
+ -------
+ tensors : Dict[str, torch.Tensor]
+ Dictionary of tensor arguments with keys like "arg_0", "kwarg_name"
+ All tensors are on CPU with original stride preserved.
+ metadata : Dict[str, Any]
+ Dictionary of non-tensor arguments (serializable to JSON)
+ """
+ tensors = {}
+ metadata = {}
+
+ # Process positional arguments
+ for i, arg in enumerate(args):
+ key = f"arg_{i}"
+ if isinstance(arg, torch.Tensor):
+ tensors[key] = arg.cpu()
+ else:
+ metadata[key] = _serialize_value(arg)
+
+ # Process keyword arguments
+ for key, value in kwargs.items():
+ kwarg_key = f"kwarg_{key}"
+ if isinstance(value, torch.Tensor):
+ tensors[kwarg_key] = value.cpu()
+ else:
+ metadata[kwarg_key] = _serialize_value(value)
+
+ return tensors, metadata
+
+
+def _dump_function_inputs(
+ func: Callable,
+ func_name: str,
+ args: tuple,
+ kwargs: dict,
+ self_id: Optional[int] = None,
+) -> Optional[str]:
+ """
+ Dump function inputs to disk BEFORE execution (crash-safe).
+
+ This function:
+ 1. Extracts tensors and metadata from inputs
+ 2. Creates a timestamped directory
+ 3. Saves inputs.pt and partial metadata.json
+ 4. Tracks cumulative size and count limits
+
+ Parameters
+ ----------
+ func : Callable
+ The function being called
+ func_name : str
+ Name of the function
+ args : tuple
+ Positional arguments
+ kwargs : dict
+ Keyword arguments
+ self_id : Optional[int]
+ The id() of the 'self' object if this is a method call
+
+ Returns
+ -------
+ Optional[str]
+ Path to the dump directory, or None if dump was skipped
+ """
+ global _dump_count, _dump_total_size_bytes
+
+ # Check include/exclude filters first (before any work is done)
+ if not _should_dump_function(func_name):
+ _logger.debug(
+ f"Skipping dump for {func_name} (filtered by include/exclude patterns)"
+ )
+ return None
+
+ if _dump_count >= _DUMP_MAX_COUNT:
+ _logger.warning(
+ f"Dump limit reached ({_DUMP_MAX_COUNT} dumps). Skipping dump for {func_name}. "
+ f"Increase FLASHINFER_DUMP_MAX_COUNT if needed."
+ )
+ return None
+
+ try:
+ # Get call counter for this function
+ if func_name not in _dump_call_counter:
+ _dump_call_counter[func_name] = 0
+ _dump_call_counter[func_name] += 1
+ call_seq = _dump_call_counter[func_name]
+
+ # Create dump directory structure
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[
+ :-3
+ ] # Include milliseconds
+ pid = os.getpid()
+ dump_name = f"{timestamp}_pid{pid}_{func_name}_call{call_seq:04d}"
+ dump_dir = Path(_DUMP_DIR) / dump_name
+ dump_dir.mkdir(parents=True, exist_ok=True)
+
+ # Extract tensors and metadata from inputs
+ input_tensors, input_metadata = _extract_tensors_and_metadata(args, kwargs)
+
+ # Calculate input size
+ input_size = sum(_get_tensor_size_bytes(t) for t in input_tensors.values())
+
+ # Check size limit (conservative check - only inputs for now)
+ max_size_bytes = _DUMP_MAX_SIZE_GB * 1024 * 1024 * 1024
+ if _dump_total_size_bytes + input_size > max_size_bytes:
+ _logger.warning(
+ f"Dump size limit reached ({_DUMP_MAX_SIZE_GB} GB). Skipping dump for {func_name}. "
+ f"Increase FLASHINFER_DUMP_MAX_SIZE_GB if needed."
+ )
+ # Clean up empty directory
+ dump_dir.rmdir()
+ return None
+
+ # Save input tensors
+ if input_tensors:
+ if _DUMP_SAFETENSORS:
+ # SafeTensors format: faster, no pickle, but loses stride/contiguity
+ try:
+ from safetensors.torch import save_file
+
+ # safetensors requires contiguous tensors
+ tensors_contiguous = {
+ k: v.contiguous() for k, v in input_tensors.items()
+ }
+ save_file(tensors_contiguous, str(dump_dir / "inputs.safetensors"))
+ except ImportError:
+ _logger.error(
+ "safetensors package not installed. "
+ "Install with: pip install safetensors"
+ )
+ raise
+ else:
+ # torch.save format: preserves stride/contiguity
+ torch.save(input_tensors, dump_dir / "inputs.pt")
+
+ # Create partial metadata (inputs only, outputs will be added later)
+ metadata: Dict[str, Any] = {
+ "function_name": func_name,
+ "module": func.__module__ if hasattr(func, "__module__") else "",
+ "call_sequence": call_seq,
+ "timestamp": timestamp,
+ "process_id": os.getpid(),
+ "input_metadata": input_metadata,
+ "output_metadata": {}, # Placeholder, will be updated after execution
+ "tensor_info": {
+ "input_tensor_keys": list(input_tensors.keys()),
+ "output_tensor_keys": [], # Placeholder, will be updated after execution
+ "input_size_bytes": input_size,
+ "input_size_mb": input_size / (1024 * 1024),
+ },
+ "tensor_details": {}, # Detailed shape/dtype/stride info for reconstruction
+ "tensor_format": "safetensors" if _DUMP_SAFETENSORS else "torch",
+ "function_signature": str(inspect.signature(func))
+ if hasattr(inspect, "signature")
+ else "",
+ "versions": {
+ "torch": torch.__version__,
+ "python": sys.version,
+ },
+ "execution_status": "inputs_saved", # Will be updated to "completed" after outputs
+ }
+
+ # Add self_id to metadata if it is a class method call
+ if self_id is not None:
+ metadata["self_id"] = self_id
+
+ # Add tensor details for random generation fallback
+ for key, tensor in input_tensors.items():
+ metadata["tensor_details"][key] = {
+ "shape": list(tensor.shape),
+ "dtype": str(tensor.dtype),
+ "stride": list(tensor.stride()),
+ "device": str(tensor.device),
+ }
+
+ # Try to get FlashInfer version
+ try:
+ from .version import __version__ as flashinfer_version
+
+ metadata["versions"]["flashinfer"] = flashinfer_version # type: ignore[index]
+ except Exception:
+ metadata["versions"]["flashinfer"] = "" # type: ignore[index]
+
+ # Add dump_dir to metadata for central session.jsonl reference
+ metadata["dump_dir"] = str(dump_dir)
+
+ # Save metadata to per-dump JSONL (first line: inputs_saved)
+ _append_to_jsonl(dump_dir / "metadata.jsonl", metadata)
+
+ # Append to central session.jsonl for quick scanning
+ session_jsonl_path = Path(_DUMP_DIR) / "session.jsonl"
+ _append_to_jsonl(session_jsonl_path, metadata)
+
+ # Update global tracking (only input size for now)
+ _dump_count += 1
+ _dump_total_size_bytes += input_size
+
+ _logger.debug(
+ f"Dumped inputs to: {dump_dir} "
+ f"(size: {input_size / (1024 * 1024):.2f} MB, "
+ f"total: {_dump_count}/{_DUMP_MAX_COUNT} dumps)"
+ )
+
+ return str(dump_dir)
+
+ except Exception as e:
+ _logger.error(f"Failed to dump function call {func_name}: {e}")
+ import traceback
+
+ _logger.error(traceback.format_exc())
+ return None
+
+
+def _dump_function_outputs(dump_dir: str, result: Any) -> None:
+ """
+ Add function outputs to an existing dump directory (crash-safe).
+
+ This function is called AFTER successful execution to append outputs
+ to the dump that was created before execution.
+
+ Parameters
+ ----------
+ dump_dir : str
+ Path to the dump directory created by _dump_function_inputs
+ result : Any
+ Function return value
+ """
+ global _dump_total_size_bytes
+
+ try:
+ dump_path = Path(dump_dir)
+ if not dump_path.exists():
+ _logger.error(f"Dump directory not found: {dump_dir}")
+ return
+
+ # Extract tensors and metadata from outputs
+ output_tensors = {}
+ output_metadata = {}
+ if isinstance(result, torch.Tensor):
+ output_tensors["result"] = result.cpu()
+    elif isinstance(result, (tuple, list)):
+ for i, item in enumerate(result):
+ if isinstance(item, torch.Tensor):
+ output_tensors[f"result_{i}"] = item.cpu()
+ else:
+ output_metadata[f"result_{i}"] = _serialize_value(item)
+ else:
+ output_metadata["result"] = _serialize_value(result)
+
+ # Calculate output size
+ output_size = sum(_get_tensor_size_bytes(t) for t in output_tensors.values())
+
+ # Save output tensors
+ if output_tensors:
+ if _DUMP_SAFETENSORS:
+ # SafeTensors format: faster, no pickle, but loses stride/contiguity
+ from safetensors.torch import save_file
+
+ tensors_contiguous = {
+ k: v.contiguous() for k, v in output_tensors.items()
+ }
+ save_file(tensors_contiguous, str(dump_path / "outputs.safetensors"))
+ else:
+ # torch.save format: preserves stride/contiguity
+ torch.save(output_tensors, dump_path / "outputs.pt")
+
+ # Load existing metadata from JSONL (last record) and update it
+ metadata_jsonl_path = dump_path / "metadata.jsonl"
+ metadata = _read_jsonl_last_record(metadata_jsonl_path)
+
+ if metadata is not None:
+ # Update with output information
+ metadata["output_metadata"] = output_metadata
+ metadata["tensor_info"]["output_tensor_keys"] = list(output_tensors.keys())
+ metadata["tensor_info"]["output_size_bytes"] = output_size
+ metadata["tensor_info"]["output_size_mb"] = output_size / (1024 * 1024)
+ metadata["tensor_info"]["total_size_bytes"] = (
+ metadata["tensor_info"]["input_size_bytes"] + output_size
+ )
+ metadata["tensor_info"]["total_size_mb"] = metadata["tensor_info"][
+ "total_size_bytes"
+ ] / (1024 * 1024)
+ metadata["execution_status"] = "completed"
+
+ # Add output tensor details
+ if "tensor_details" not in metadata:
+ metadata["tensor_details"] = {}
+ for key, tensor in output_tensors.items():
+ metadata["tensor_details"][key] = {
+ "shape": list(tensor.shape),
+ "dtype": str(tensor.dtype),
+ "stride": list(tensor.stride()),
+ "device": str(tensor.device),
+ }
+
+ # Append completion record to per-dump JSONL
+ _append_to_jsonl(metadata_jsonl_path, metadata)
+
+ # Append completion record to central session.jsonl
+ session_jsonl_path = Path(_DUMP_DIR) / "session.jsonl"
+ _append_to_jsonl(session_jsonl_path, metadata)
+
+ # Update global size tracking
+ _dump_total_size_bytes += output_size
+
+ _logger.debug(
+ f"Dumped outputs to: {dump_dir} "
+ f"(output size: {output_size / (1024 * 1024):.2f} MB, "
+ f"total dump size: {metadata['tensor_info']['total_size_mb']:.2f} MB)"
+ )
+ else:
+ _logger.error(f"metadata.jsonl not found or empty in {dump_dir}")
+
+ except Exception as e:
+ _logger.error(f"Failed to dump outputs to {dump_dir}: {e}")
+ import traceback
+
+ _logger.error(traceback.format_exc())
+
+
+def _reconstruct_value(value: Any) -> Any:
+ """
+ Reconstruct special types from metadata format.
+
+ Handles:
+ - torch.dtype objects
+ - enum.Enum objects (future)
+ - Other serialized types
+ """
+ if isinstance(value, dict):
+ value_type = value.get("type")
+
+ if value_type == "torch.dtype":
+ # Reconstruct torch.dtype from string
+ dtype_str = value.get("value", "")
+ # Parse strings like "torch.bfloat16", "torch.float16", etc.
+            dtype_name = dtype_str.removeprefix("torch.")
+ try:
+ return getattr(torch, dtype_name)
+ except AttributeError:
+ _logger.warning(f"Could not reconstruct dtype: {dtype_str}")
+ return value
+
+ # For other dict types, return as-is
+ return value
+
+ return value
+
+
+def _resolve_function(module_name: str, function_name: str) -> Optional[Callable]:
+ """Resolve a function from module name and function name."""
+ try:
+ module = importlib.import_module(module_name)
+ # Handle nested function names (e.g. Class.method)
+ parts = function_name.split(".")
+ obj: Any = module
+ for part in parts:
+ obj = getattr(obj, part)
+ if not callable(obj):
+ return None
+ return obj
+ except Exception as e:
+ _logger.warning(
+ f"Could not resolve function {module_name}.{function_name}: {e}"
+ )
+ return None
+
+
+def _compare_results(
+ actual: Any, expected: Any, rtol: float = 1e-3, atol: float = 1e-3
+) -> bool:
+ """Recursively compare execution results."""
+ # torch.Tensor comparison
+ if isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor):
+ # Check shape
+ if actual.shape != expected.shape:
+ _logger.warning(
+ f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}"
+ )
+ return False
+ # Check dtype
+ if actual.dtype != expected.dtype:
+ _logger.warning(
+ f"Dtype mismatch: actual {actual.dtype} vs expected {expected.dtype}"
+ )
+ return False
+
+ # Check values; apply relative and absolute tolerance.
+ if not torch.allclose(actual, expected, rtol=rtol, atol=atol):
+ diff = (actual - expected).abs().max().item()
+ _logger.warning(f"Value mismatch: max diff {diff}")
+ return False
+ return True
+
+ # list/tuple comparison
+ elif isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)):
+ if len(actual) != len(expected):
+ _logger.warning(
+ f"Length mismatch: actual {len(actual)} vs expected {len(expected)}"
+ )
+ return False
+ return all(
+ _compare_results(a, e, rtol, atol)
+ for a, e in zip(actual, expected, strict=True)
+ )
+
+ # dict comparison
+ elif isinstance(actual, dict) and isinstance(expected, dict):
+ if actual.keys() != expected.keys():
+ _logger.warning(
+ f"Key mismatch: actual {actual.keys()} vs expected {expected.keys()}"
+ )
+ return False
+ return all(_compare_results(actual[k], expected[k], rtol, atol) for k in actual)
+
+ # fallback for other types (including None). Just do a naive comparison.
+ else:
+ if actual != expected:
+ _logger.warning(f"Value mismatch: actual {actual} vs expected {expected}")
+ return False
+ return True
+
+
+def replay_from_dump(
+ dump_dir: str,
+ compare_outputs: bool = False,
+ device: str = "cuda",
+ run: bool = False,
+ object_registry: Optional[Dict[Tuple[int, int], Any]] = None,
+) -> Any:
+ """
+ Replay a function call from a dumped directory.
+
+ This function:
+ 1. Loads metadata.jsonl to get function info
+ 2. Loads inputs.pt to get input tensors
+ 3. Moves tensors to specified device (default: cuda)
+ 4. Reconstructs the function call
+ 5. Optionally executes the function (if run=True)
+ 6. Optionally compares with saved outputs
+
+ Parameters
+ ----------
+ dump_dir : str
+ Path to the dump directory
+ compare_outputs : bool
+ If True, load and compare with saved outputs
+ device : str
+ Target device for tensors. Options:
+ - "cuda" (default): Load to cuda:0
+ - "cpu": Load to CPU
+ - "cuda:N": Load to specific CUDA device
+ run : bool
+ If True, try to resolve and execute the function
+ object_registry : Optional[Dict[Tuple[int, int], Any]]
+ Registry of stateful objects mapped by (process_id, self_id) tuple.
+ This composite key ensures objects from different processes don't collide
+ in multi-GPU environments where different processes may have objects
+ at the same memory address.
+
+ Returns
+ -------
+ result : dict
+ Dictionary containing:
+ - 'args': Positional arguments (tensors on specified device)
+ - 'kwargs': Keyword arguments (tensors on specified device)
+ - 'metadata': Full metadata
+ - 'execution_result': Result of execution (if run=True)
+ - 'comparison_match': Boolean indicating if result matched expected (if run=True and compare_outputs=True)
+ If compare_outputs=True, also includes:
+ - 'expected_tensors': Expected output tensors
+ - 'expected_metadata': Expected output metadata
+ """
+ dump_path = Path(dump_dir)
+ if not dump_path.exists():
+ raise FileNotFoundError(f"Dump directory not found: {dump_dir}")
+
+ # Load metadata from JSONL (last record has most complete state)
+ metadata_jsonl_path = dump_path / "metadata.jsonl"
+ if not metadata_jsonl_path.exists():
+ raise FileNotFoundError(f"metadata.jsonl not found in {dump_dir}")
+
+ metadata = _read_jsonl_last_record(metadata_jsonl_path)
+ if metadata is None:
+ raise ValueError(f"metadata.jsonl is empty in {dump_dir}")
+
+ func_name = metadata["function_name"]
+
+ # Load input tensors - auto-detect format (torch.save or safetensors)
+ inputs_pt_path = dump_path / "inputs.pt"
+ inputs_safetensors_path = dump_path / "inputs.safetensors"
+
+ if inputs_pt_path.exists():
+ input_tensors = torch.load(str(inputs_pt_path), map_location="cpu")
+ elif inputs_safetensors_path.exists():
+ try:
+ from safetensors.torch import load_file
+
+ input_tensors = load_file(str(inputs_safetensors_path), device="cpu")
+ except ImportError:
+ raise ImportError(
+ "Dump was saved with safetensors but package not installed. "
+ "Install with: pip install safetensors"
+ ) from None
+ else:
+ raise FileNotFoundError(
+ f"Neither inputs.pt nor inputs.safetensors found in {dump_dir}"
+ )
+
+ # Move tensors to specified device
+ for key, tensor in input_tensors.items():
+ input_tensors[key] = tensor.to(device)
+
+ # Reconstruct args and kwargs
+ args = []
+ kwargs = {}
+ input_metadata = metadata.get("input_metadata", {})
+
+ # Get max arg index from both tensors and metadata
+ max_arg_idx = -1
+
+ for key in input_tensors.keys():
+ if key.startswith("arg_"):
+ idx = int(key.split("_")[1])
+ max_arg_idx = max(max_arg_idx, idx)
+
+ for key in input_metadata.keys():
+ if key.startswith("arg_"):
+ idx = int(key.split("_")[1])
+ max_arg_idx = max(max_arg_idx, idx)
+
+ # Reconstruct positional args in order so that we can replay
+ # the function call exactly as it was logged.
+ for i in range(max_arg_idx + 1):
+ key = f"arg_{i}"
+ if key in input_tensors:
+ args.append(input_tensors[key])
+ elif key in input_metadata:
+ args.append(_reconstruct_value(input_metadata[key]))
+ else:
+ # Should not happen if dump is consistent, but safeguard
+ _logger.warning(f"Missing argument {i} in dump.")
+ args.append(None)
+
+ # Add keyword arguments. Here the ordering is not important.
+ for key in input_tensors.keys():
+ if key.startswith("kwarg_"):
+            kwarg_name = key.removeprefix("kwarg_")
+ kwargs[kwarg_name] = input_tensors[key]
+
+ for key in input_metadata.keys():
+ if key.startswith("kwarg_"):
+            kwarg_name = key.removeprefix("kwarg_")
+ if kwarg_name not in kwargs: # Don't override tensor kwargs
+ kwargs[kwarg_name] = _reconstruct_value(input_metadata[key])
+
+ _logger.info(f"Replaying {func_name} from {dump_dir}")
+ _logger.info(f" Args: {len(args)}, Kwargs: {list(kwargs.keys())}")
+
+ result_dict: Dict[str, Any] = {"args": args, "kwargs": kwargs, "metadata": metadata}
+
+ # Load expected outputs if needed - auto-detect format
+ expected_outputs = {}
+ output_metadata = {}
+ if compare_outputs:
+ outputs_pt_path = dump_path / "outputs.pt"
+ outputs_safetensors_path = dump_path / "outputs.safetensors"
+
+ if outputs_pt_path.exists():
+ expected_outputs = torch.load(str(outputs_pt_path), map_location="cpu")
+ elif outputs_safetensors_path.exists():
+ try:
+ from safetensors.torch import load_file
+
+ expected_outputs = load_file(
+ str(outputs_safetensors_path), device="cpu"
+ )
+ except ImportError:
+ raise ImportError(
+ "Dump was saved with safetensors but package not installed. "
+ "Install with: pip install safetensors"
+ ) from None
+
+ # Move output tensors to specified device
+ for key, tensor in expected_outputs.items():
+ expected_outputs[key] = tensor.to(device)
+
+ output_metadata = metadata.get("output_metadata", {})
+ result_dict["expected_tensors"] = expected_outputs
+ result_dict["expected_metadata"] = output_metadata
+
+ if run:
+ module_name = metadata.get("module")
+ self_id = metadata.get("self_id")
+ process_id = metadata.get("process_id")
+
+ func = None
+ obj = None
+
+ # Stateful replay logic for class methods calls.
+ # Necessary for wrapped classes like BatchDecodeWithPagedKVCacheWrapper.
+ # Use (process_id, self_id) as composite key to avoid collisions across processes.
+ # In multi-GPU environments, different processes may have objects with the same
+ # memory address (self_id), so we need to scope by process_id.
+ if self_id is not None:
+ registry_key = (process_id, self_id)
+ if func_name.endswith(".__init__"):
+ # This is a constructor call
+ # Resolution: Get the class and instantiate it
+ class_name = func_name.split(".")[
+ -2
+ ] # e.g. "Wrapper.__init__" -> "Wrapper"
+ cls_obj = _resolve_function(module_name, class_name)
+ if cls_obj and callable(cls_obj):
+ # Instantiate: obj = Class(*args[1:], **kwargs)
+ # Note: args[0] is 'self' placeholder in the dump for __init__, skip it
+ real_args = args[1:] if len(args) > 0 else []
+ try:
+ _logger.info(
+ f"Instantiating {class_name} (PID: {process_id}, ID: {self_id})..."
+ )
+ # We need to handle the case where __init__ is called.
+ # The safest way is to just call the class constructor.
+ # We assume the logged args match the constructor args.
+ obj = cls_obj(*real_args, **kwargs)
+ if object_registry is not None:
+ object_registry[registry_key] = obj
+ # __init__ returns None, but effectively we returned the object
+ execution_result = None
+ result_dict["execution_result"] = execution_result
+
+ # Since we successfully "ran" (instantiated), we can mark it done
+ # But there is no output to compare for __init__ usually (returns None)
+ if compare_outputs:
+ result_dict["comparison_match"] = (
+ True # Trivial pass for __init__
+ )
+ return result_dict
+ except Exception as e:
+ _logger.error(f"Failed to instantiate {class_name}: {e}")
+ result_dict["execution_error"] = str(e)
+ return result_dict
+ else:
+ # Instance method call
+ if object_registry is not None and registry_key in object_registry:
+ obj = object_registry[registry_key]
+ method_name = func_name.split(".")[-1]
+ if hasattr(obj, method_name):
+ func = getattr(obj, method_name)
+ # args[0] is 'self' placeholder, skip it
+ args = args[1:] if len(args) > 0 else []
+ else:
+ _logger.warning(f"Object {obj} has no method {method_name}")
+ else:
+ _logger.warning(
+ f"Object (PID: {process_id}, ID: {self_id}) not found in registry."
+ )
+
+ if func is None:
+ func = _resolve_function(module_name, func_name)
+
+ if func:
+ try:
+ _logger.info(f"Executing {module_name}.{func_name}...")
+ execution_result = func(*args, **kwargs)
+ result_dict["execution_result"] = execution_result
+
+ if compare_outputs:
+ # Flatten execution result to dict for comparison
+ actual_outputs = {}
+ if isinstance(execution_result, torch.Tensor):
+ actual_outputs["result"] = execution_result
+ elif isinstance(execution_result, (tuple, list)):
+ for i, item in enumerate(execution_result):
+ if isinstance(item, torch.Tensor):
+ actual_outputs[f"result_{i}"] = item
+ elif isinstance(execution_result, dict):
+ # If result is already a dict of tensors? Unlikely for FlashInfer but possible
+ actual_outputs = execution_result
+
+ # Compare tensors
+ match = True
+ if expected_outputs:
+ match = _compare_results(actual_outputs, expected_outputs)
+
+ result_dict["comparison_match"] = match
+ if match:
+ _logger.info("Replay comparison passed!")
+ else:
+ _logger.warning("Replay comparison FAILED.")
+
+ except Exception as e:
+ _logger.error(f"Execution failed: {e}")
+ import traceback
+
+ _logger.error(traceback.format_exc())
+ result_dict["execution_error"] = str(e)
+ else:
+ _logger.warning(
+ f"Skipping execution: could not resolve {module_name}.{func_name}"
+ )
+ elif not compare_outputs:
+ _logger.warning(
+ "Automatic function resolution disabled. "
+ "Pass run=True to execute, or manually call function."
+ )
+
+ return result_dict
+
+
+def replay_sequence(root_dir: str, device: str = "cuda") -> list:
+ """
+ Replay a sequence of API calls from a root dump directory.
+
+ This function iterates through all dump directories in the root directory,
+ sorted by timestamp/sequence number, and replays them in order.
+
+ Parameters
+ ----------
+ root_dir : str
+ Path to the root directory containing dump subdirectories
+ device : str
+ Target device for execution (default: "cuda")
+
+ Returns
+ -------
+ list
+ List of results from replay_from_dump calls
+ """
+ root_path = Path(root_dir)
+ if not root_path.exists():
+ raise FileNotFoundError(f"Root dump directory not found: {root_dir}")
+
+ # Find all subdirectories that look like dumps
+ # Pattern: YYYYMMDD_HHMMSS_milliseconds_pid_funcname_callXXXX
+ dump_dirs = []
+ for item in root_path.iterdir():
+ if item.is_dir() and (item / "metadata.jsonl").exists():
+ dump_dirs.append(item)
+
+ # Sort by directory name (which starts with timestamp)
+ dump_dirs.sort(key=lambda x: x.name)
+
+ results = []
+ total = len(dump_dirs)
+ _logger.info(f"Found {total} dumps to replay from {root_dir}")
+
+ # Registry for stateful objects (mapped by (process_id, self_id) tuple)
+ # This composite key prevents collisions in multi-GPU/multi-process environments
+ object_registry: Dict[Tuple[int, int], Any] = {}
+
+ for i, dump_dir in enumerate(dump_dirs):
+ _logger.info(f"[{i + 1}/{total}] Replaying {dump_dir.name}...")
+ try:
+ # We assume that for sequence replay, we want to EXECUTE the calls
+ # and assume outputs are not necessarily present or we just want to verify it runs.
+ # If outputs are present, we can compare.
+ res = replay_from_dump(
+ str(dump_dir),
+ compare_outputs=True,
+ device=device,
+ run=True,
+ object_registry=object_registry,
+ )
+ # Add dump_dir to the result for CLI reporting
+ res["dump_dir"] = str(dump_dir)
+ results.append(res)
+ except Exception as e:
+ # Let's record error and continue.
+ _logger.error(f"Failed to replay {dump_dir.name}: {e}")
+ results.append({"error": str(e), "dump_dir": str(dump_dir)})
+
+ return results
+
+
def _log_system_info():
"""Log system information once at module initialization."""
if _API_LOG_LEVEL == 0:
@@ -160,6 +1122,7 @@ def _log_system_info():
# Log system information once at module load time (if logging is enabled)
_log_system_info()
+_warn_dump()
def _format_value(value: Any, level: int, indent: int = 0) -> str:
@@ -174,7 +1137,6 @@ def _format_value(value: Any, level: int, indent: int = 0) -> str:
The logging level (1, 2, or 3)
indent : int
The indentation level for nested structures
-
Returns
-------
str
@@ -349,7 +1311,6 @@ def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict:
Positional arguments that were provided
kwargs : dict
Keyword arguments that were provided
-
Returns
-------
dict
@@ -358,7 +1319,6 @@ def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict:
try:
sig = inspect.signature(func)
default_params = {}
-
# Determine which parameters were NOT provided
for i, (param_name, param) in enumerate(sig.parameters.items()):
# Skip if parameter has no default
@@ -416,7 +1376,6 @@ def _log_function_inputs(
for i, arg in enumerate(args):
lines.append(f" arg[{i}]:")
lines.append(_format_value(arg, level, indent=2))
-
# Keyword arguments
if kwargs:
lines.append("Keyword input arguments:")
@@ -425,7 +1384,6 @@ def _log_function_inputs(
lines.append(_format_value(value, level, indent=2))
else:
lines.append("(No explicit arguments)")
-
# Log default parameters that were not explicitly provided
default_params = _get_default_params(func, args, kwargs)
if default_params:
@@ -433,14 +1391,12 @@ def _log_function_inputs(
for param_name, default_value in default_params.items():
lines.append(f" {param_name}= [DEFAULT]")
lines.append(_format_value(default_value, level, indent=2))
-
_logger.debug("\n".join(lines))
def _log_function_outputs(func_name: str, result: Any, level: int) -> None:
"""
Log function outputs AFTER successful execution.
-
Parameters
----------
func_name : str
@@ -465,6 +1421,9 @@ def flashinfer_api(func: Callable = None) -> Callable:
"""
Decorator to FlashInfer's APIs.
+ .. warning::
+ This API logging feature is experimental and may change in future versions.
+
Currently logs input and output values of the function using Python's logging library.
This decorator integrates with Python's standard logging infrastructure while
maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0).
@@ -478,6 +1437,7 @@ def flashinfer_api(func: Callable = None) -> Callable:
- 1: Log function name only (logged BEFORE execution - crash-safe)
- 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe)
- 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe)
+ - 10: Level 5 logging + dump metadata and input/output tensors to disk for reproducibility (preserves stride/contiguity)
FLASHINFER_LOGDEST : str (default: "stdout")
- "stdout": Log to standard output
@@ -485,6 +1445,26 @@ def flashinfer_api(func: Callable = None) -> Callable:
- : Log to specified file path
- Use %i in path for process ID substitution (e.g., "log_%i.txt" -> "log_12345.txt")
+ Level 10 Tensor Dumping (additional variables):
+ FLASHINFER_DUMP_DIR : str (default: "flashinfer_dumps")
+ - Directory where tensor dumps are saved
+
+ FLASHINFER_DUMP_MAX_SIZE_GB : float (default: 20)
+ - Maximum total size of dumps in GB
+
+ FLASHINFER_DUMP_MAX_COUNT : int (default: 1000)
+ - Maximum number of function call dumps
+
+ FLASHINFER_DUMP_SAFETENSORS : int (default: 0)
+ - 0: Use torch.save format (preserves stride/contiguity)
+ - 1: Use safetensors format (no pickle, but loses stride info)
+
+ FLASHINFER_DUMP_INCLUDE : str (default: "")
+ - Comma-separated list of patterns to include for dumping (fnmatch-style)
+
+ FLASHINFER_DUMP_EXCLUDE : str (default: "")
+ - Comma-separated list of patterns to exclude for dumping (fnmatch-style)
+
Examples
--------
Basic usage:
@@ -522,6 +1502,7 @@ def decorator(f: Callable) -> Callable:
def wrapper(*args, **kwargs):
# Determine function name (with class name if applicable)
func_name = f.__name__
+ self_id = None
if args and hasattr(args[0], "__class__"):
try:
class_name = args[0].__class__.__name__
@@ -529,9 +1510,22 @@ def wrapper(*args, **kwargs):
"BatchMLAPagedAttentionWrapper"
]:
func_name = f"{class_name}.{func_name}"
+ self_id = id(args[0])
except Exception:
pass
+ # Level 10: Dump inputs BEFORE execution (crash-safe)
+ dump_dir = None
+ if _API_LOG_LEVEL >= 10:
+ try:
+ dump_dir = _dump_function_inputs(
+ f, func_name, args, kwargs, self_id=self_id
+ )
+ if dump_dir:
+ _logger.debug(f"Inputs dumped to: {dump_dir}")
+ except Exception as e:
+ _logger.error(f"[DUMP ERROR (inputs) in {func_name}]: {e}")
+
# Log BEFORE execution (crash-safe for all levels!)
try:
if _API_LOG_LEVEL == 1:
@@ -541,7 +1535,9 @@ def wrapper(*args, **kwargs):
)
elif _API_LOG_LEVEL >= 3:
# Level 3+: Log full inputs before execution (crash-safe)
- _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL)
+ # For level 10, we use level 5 logging (includes statistics)
+ effective_level = min(_API_LOG_LEVEL, 5) # Cap at 5 for logging
+ _log_function_inputs(f, func_name, args, kwargs, effective_level)
except Exception as e:
_logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}")
@@ -552,10 +1548,19 @@ def wrapper(*args, **kwargs):
try:
if _API_LOG_LEVEL >= 3:
# Level 3+: Log outputs (inputs were already logged above)
- _log_function_outputs(func_name, result, _API_LOG_LEVEL)
+ effective_level = min(_API_LOG_LEVEL, 5)
+ _log_function_outputs(func_name, result, effective_level)
except Exception as e:
_logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}")
+ # Level 10: Dump outputs AFTER successful execution (crash-safe)
+ if _API_LOG_LEVEL >= 10 and dump_dir:
+ try:
+ _dump_function_outputs(dump_dir, result)
+                    _logger.debug(f"Outputs dumped to: {dump_dir}")
+ except Exception as e:
+ _logger.error(f"[DUMP ERROR (outputs) in {func_name}]: {e}")
+
return result
return wrapper
diff --git a/tests/utils/test_logging_replay.py b/tests/utils/test_logging_replay.py
new file mode 100644
index 0000000000..11cd83fb80
--- /dev/null
+++ b/tests/utils/test_logging_replay.py
@@ -0,0 +1,1067 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Tests for Level 10 API logging with actual FlashInfer APIs.
+
+This test suite verifies that Level 10 logging (tensor dumping) works correctly
+with all decorated FlashInfer APIs. For each API, we:
+1. Run the API with Level 10 logging enabled
+2. Verify a dump was created
+3. Load the dump using replay_from_dump
+4. Run the API again with replayed inputs
+5. Verify: original_output ≈ dumped_output ≈ replayed_output
+"""
+
+import os
+import sys
+
+import pytest
+import torch
+
+from flashinfer.utils import get_compute_capability
+from tests.utils_fp8 import to_float8
+
+
+def _clean_flashinfer_modules():
+ """Remove flashinfer modules from sys.modules."""
+ modules_to_delete = [k for k in sys.modules.keys() if k.startswith("flashinfer")]
+ for module in modules_to_delete:
+ del sys.modules[module]
+
+
@pytest.fixture
def level10_environment(tmp_path):
    """Configure Level 10 (tensor-dump) logging for a single test.

    Points ``FLASHINFER_DUMP_DIR`` at a per-test temporary directory, clears
    any dump filters, forces a flashinfer reimport so the new environment is
    picked up, yields the dump directory, then restores the original
    environment and reimports again.

    Yields:
        pathlib.Path: the directory dumps will be written into.
    """
    # Store original environment.  NOTE: this includes the dump-limit
    # variables set below (MAX_COUNT / MAX_SIZE_GB) — the original code set
    # them without saving them, so they leaked into later tests.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
        "FLASHINFER_DUMP_SAFETENSORS": os.environ.get("FLASHINFER_DUMP_SAFETENSORS"),
        "FLASHINFER_DUMP_MAX_COUNT": os.environ.get("FLASHINFER_DUMP_MAX_COUNT"),
        "FLASHINFER_DUMP_MAX_SIZE_GB": os.environ.get("FLASHINFER_DUMP_MAX_SIZE_GB"),
    }

    # Set up test environment
    dump_dir = tmp_path / "test_dumps"
    os.environ["FLASHINFER_LOGLEVEL"] = "10"
    os.environ["FLASHINFER_LOGDEST"] = "stdout"
    os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
    os.environ["FLASHINFER_DUMP_MAX_COUNT"] = "1000"
    os.environ["FLASHINFER_DUMP_MAX_SIZE_GB"] = "10"
    # Clear any existing filters and safetensors mode
    for key in (
        "FLASHINFER_DUMP_INCLUDE",
        "FLASHINFER_DUMP_EXCLUDE",
        "FLASHINFER_DUMP_SAFETENSORS",
    ):
        os.environ.pop(key, None)

    # Force reimport to pick up new environment variables
    _clean_flashinfer_modules()

    yield dump_dir

    # Restore original environment
    for key, value in original_env.items():
        if value is not None:
            os.environ[key] = value
        elif key in os.environ:
            del os.environ[key]

    # Force reimport
    _clean_flashinfer_modules()
+
+
def verify_and_replay_dump(
    dump_dir, original_output, func_to_replay, expected_dumps=1, dump_idx=0
):
    """Replay the dumps under *dump_dir* and compare against *original_output*.

    Args:
        dump_dir: directory containing Level 10 dump subdirectories.
        original_output: tensor returned by the original (dumped) call.
        func_to_replay: if truthy, only dumps whose recorded function name
            matches ``func_to_replay.__name__`` are considered.
        expected_dumps: minimum number of matching dumps required.
        dump_idx: which matching dump to check (-1 selects the last one).
    """
    from flashinfer.api_logging import replay_sequence

    # Replay every dump in the session.
    replayed = replay_sequence(str(dump_dir), device="cuda")

    # Narrow down to the function under test, when one was given.
    if func_to_replay:
        wanted = func_to_replay.__name__
        kept = []
        for entry in replayed:
            if entry.get("metadata", {}).get("function_name") == wanted:
                kept.append(entry)
        replayed = kept

    assert len(replayed) >= expected_dumps, (
        f"Expected at least {expected_dumps} dumps, found {len(replayed)}"
    )

    # Pick the requested record (usually the last one if multiple).
    record = replayed[-1] if dump_idx == -1 else replayed[dump_idx]

    # The replay machinery's own dumped-vs-replayed comparison must pass.
    assert record["comparison_match"] is True

    # And the replayed result must match the in-memory original.
    assert torch.allclose(
        original_output, record["execution_result"], atol=1e-5, rtol=1e-3
    )
+
+
def test_replay_sequence(level10_environment):
    """Two decode calls should yield two replayable dumps that both match."""
    from flashinfer import single_decode_with_kv_cache
    from flashinfer.api_logging import replay_sequence

    kv_len = 128
    num_qo_heads, num_kv_heads = 32, 8
    head_dim = 128

    def _decode_once():
        # Fresh random inputs per call, same shapes as the original test.
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q, k, v)

    # Generate two dumped calls.
    _decode_once()
    _decode_once()

    # Replay the whole session and check every record.
    replayed = replay_sequence(str(level10_environment), device="cuda")

    assert len(replayed) == 2
    for entry in replayed:
        assert entry["comparison_match"] is True
        assert "execution_result" in entry
+
+
def test_mm_fp8_replay(level10_environment):
    """Test Level 10 logging with mm_fp8 API.

    Quantizes both operands to FP8, runs ``mm_fp8`` once with dumping
    enabled, then replays the dump and checks it against the in-memory
    result via ``verify_and_replay_dump``.
    """
    # This mm_fp8 path requires Blackwell (SM 10.x) — skip elsewhere.
    compute_capability = get_compute_capability(torch.device("cuda"))
    if compute_capability[0] not in [10]:
        pytest.skip("mm_fp8 is only supported on Blackwell GPUs.")

    from flashinfer import mm_fp8, autotune, prepare_low_latency_gemm_weights

    # Test configuration
    m, n, k = 4, 2560, 8192
    input_dtype = torch.float8_e4m3fn
    mat2_dtype = torch.float8_e4m3fn

    # Create inputs (seeded so the run is reproducible)
    torch.manual_seed(42)
    input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)

    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)

    global_scale = input_inv_s * mat2_inv_s

    # Weights must be pre-shuffled for the low-latency GEMM path.
    _cache_permute_indices = {}
    prepared_weights = prepare_low_latency_gemm_weights(
        mat2_fp8, _cache_permute_indices
    )

    # Run API with Level 10 logging (without pre-allocated output)
    with autotune():
        original_output = mm_fp8(
            input_fp8,
            prepared_weights,
            global_scale,
        )

    verify_and_replay_dump(level10_environment, original_output, mm_fp8)
+
+
def test_bmm_fp8_replay(level10_environment):
    """Test Level 10 logging with bmm_fp8 API.

    Runs a batched FP8 GEMM once with dumping enabled, then replays the
    dump and checks it against the in-memory result.
    """

    from flashinfer import bmm_fp8, autotune

    # Test configuration
    b, m, n, k = 1, 48, 80, 64
    input_dtype = torch.float8_e4m3fn
    mat2_dtype = torch.float8_e4m3fn
    res_dtype = torch.bfloat16
    backend = "cudnn"

    # Create inputs (seeded so the run is reproducible)
    torch.manual_seed(42)
    input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)

    # mat2 is built as [b, n, k] then transposed so bmm_fp8 sees [b, k, n].
    mat2 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)

    # Run API with Level 10 logging
    with autotune():
        original_output = bmm_fp8(
            input_fp8,
            mat2_fp8,
            input_inv_s,
            mat2_inv_s,
            res_dtype,
            backend=backend,
        )

    verify_and_replay_dump(level10_environment, original_output, bmm_fp8)
+
+
def test_mm_fp4_replay(level10_environment):
    """Test Level 10 logging with mm_fp4 API.

    Quantizes both operands to NVFP4, runs ``mm_fp4`` once with dumping
    enabled, then replays the dump and checks it against the in-memory
    result.
    """
    compute_capability = get_compute_capability(torch.device("cuda"))
    compute_capability_number = compute_capability[0] * 10 + compute_capability[1]

    from flashinfer import mm_fp4, autotune, nvfp4_quantize, SfLayout

    backend = "cudnn"
    if not mm_fp4.is_backend_supported(backend, compute_capability_number):
        pytest.skip(
            f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}."
        )

    # Test configuration
    m, n, k = 48, 128, 128
    res_dtype = torch.bfloat16
    use_128x4_sf_layout = True

    # Create inputs (seeded so the run is reproducible)
    torch.manual_seed(42)
    input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)

    a_sf_layout = SfLayout.layout_128x4 if use_128x4_sf_layout else SfLayout.layout_8x4
    # Per-tensor global scale factors.  448 * 6 appears to be the maximum
    # representable FP8-scale x FP4-value product — TODO confirm against the
    # NVFP4 quantization documentation.
    global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max()
    global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max()

    input_fp4, input_inv_s = nvfp4_quantize(
        input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False
    )
    mat2_fp4, mat2_inv_s = nvfp4_quantize(
        mat2, global_sf_mat2, sfLayout=SfLayout.layout_128x4, do_shuffle=False
    )

    # alpha undoes both global scale factors in the GEMM epilogue.
    alpha = 1.0 / (global_sf_input * global_sf_mat2)
    block_size = 16

    # Run API with Level 10 logging (without pre-allocated output)
    with autotune():
        original_output = mm_fp4(
            input_fp4,
            mat2_fp4.T,
            input_inv_s,
            mat2_inv_s.T,
            alpha,
            res_dtype,
            block_size=block_size,
            use_8x4_sf_layout=not use_128x4_sf_layout,
            backend=backend,
            use_nvfp4=True,
            skip_check=False,
        )

    verify_and_replay_dump(level10_environment, original_output, mm_fp4)
+
+
def test_single_prefill_with_kv_cache_replay(level10_environment):
    """Level 10 dump + replay round-trip for single_prefill_with_kv_cache."""
    from flashinfer import single_prefill_with_kv_cache

    # Test configuration
    qo_len, kv_len = 127, 501
    num_qo_heads, num_kv_heads = 4, 1
    head_dim = 128
    causal = False

    torch.manual_seed(42)

    def _rand_f16(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.float16)

    q = _rand_f16(qo_len, num_qo_heads, head_dim)
    k = _rand_f16(kv_len, num_kv_heads, head_dim)
    v = _rand_f16(kv_len, num_kv_heads, head_dim)

    # Run once under Level 10 logging; keep a detached copy of the result.
    original_output = single_prefill_with_kv_cache(q, k, v, causal=causal).clone()

    verify_and_replay_dump(
        level10_environment, original_output, single_prefill_with_kv_cache
    )
+
+
def test_single_decode_with_kv_cache_replay(level10_environment):
    """Level 10 dump + replay round-trip for single_decode_with_kv_cache."""
    from flashinfer import single_decode_with_kv_cache

    # Test configuration
    kv_len = 1024
    num_qo_heads, num_kv_heads = 32, 8
    head_dim = 128

    torch.manual_seed(42)
    query = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
    key_cache = torch.randn(
        kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
    )
    value_cache = torch.randn(
        kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
    )

    # Run once under Level 10 logging; keep a detached copy of the result.
    original_output = single_decode_with_kv_cache(
        query, key_cache, value_cache
    ).clone()

    verify_and_replay_dump(
        level10_environment, original_output, single_decode_with_kv_cache
    )
+
+
def test_cli_replay(level10_environment):
    """Test the CLI replay command.

    Generates one dump via a decode call, then invokes
    ``flashinfer replay --dir <dump_dir>`` through click's test runner and
    checks the human-readable summary output.
    """
    from click.testing import CliRunner
    from flashinfer.__main__ import cli
    from flashinfer import single_decode_with_kv_cache

    # Create some data
    kv_len = 128
    num_qo_heads, num_kv_heads = 32, 8
    head_dim = 128
    q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
    single_decode_with_kv_cache(q, k, v)

    runner = CliRunner()
    result = runner.invoke(cli, ["replay", "--dir", str(level10_environment)])

    # Exit code 0 means the replay command completed without raising.
    assert result.exit_code == 0
    assert "Replaying session from" in result.output
    assert "Passed" in result.output
    assert "Summary: 1 passed" in result.output
+
+
+# =============================================================================
+# Tests for FLASHINFER_DUMP_INCLUDE / FLASHINFER_DUMP_EXCLUDE filtering
+# =============================================================================
+
+
def test_dump_include_filter(tmp_path):
    """Test that FLASHINFER_DUMP_INCLUDE only dumps matching functions.

    Sets the include filter to ``*decode*``, runs one decode and one prefill
    call, and verifies that exactly one dump directory (the decode one) was
    created.
    """
    # Store original environment so it can be restored afterwards.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
    }

    try:
        # Set up environment with include filter for decode only
        dump_dir = tmp_path / "test_dumps_include"
        os.environ["FLASHINFER_LOGLEVEL"] = "10"
        os.environ["FLASHINFER_LOGDEST"] = "stdout"
        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
        os.environ["FLASHINFER_DUMP_INCLUDE"] = "*decode*"
        if "FLASHINFER_DUMP_EXCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_EXCLUDE"]

        # Force reimport so the logging module re-reads the environment
        _clean_flashinfer_modules()

        from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache

        # Test configuration
        kv_len = 128
        num_qo_heads, num_kv_heads = 32, 8
        head_dim = 128

        # Call decode (should be dumped)
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q, k, v)

        # Call prefill (should NOT be dumped due to include filter)
        qo_len = 64
        q_prefill = torch.randn(
            qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16
        )
        k_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill)

        # Check that only decode was dumped (filter for directories only, exclude session.jsonl)
        dump_subdirs = (
            [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else []
        )
        assert len(dump_subdirs) == 1, f"Expected 1 dump, found {len(dump_subdirs)}"

        # Verify the dump is for decode
        dump_name = dump_subdirs[0].name
        assert "decode" in dump_name.lower(), f"Expected decode dump, got {dump_name}"

    finally:
        # Restore environment
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
        _clean_flashinfer_modules()
+
+
def test_dump_exclude_filter(tmp_path):
    """Test that FLASHINFER_DUMP_EXCLUDE skips matching functions.

    Sets the exclude filter to ``*prefill*``, runs one decode and one prefill
    call, and verifies that exactly one dump directory (the decode one) was
    created.
    """
    # Store original environment so it can be restored afterwards.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
    }

    try:
        # Set up environment with exclude filter for prefill
        dump_dir = tmp_path / "test_dumps_exclude"
        os.environ["FLASHINFER_LOGLEVEL"] = "10"
        os.environ["FLASHINFER_LOGDEST"] = "stdout"
        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
        os.environ["FLASHINFER_DUMP_EXCLUDE"] = "*prefill*"
        if "FLASHINFER_DUMP_INCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_INCLUDE"]

        # Force reimport so the logging module re-reads the environment
        _clean_flashinfer_modules()

        from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache

        # Test configuration
        kv_len = 128
        num_qo_heads, num_kv_heads = 32, 8
        head_dim = 128

        # Call decode (should be dumped)
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q, k, v)

        # Call prefill (should NOT be dumped due to exclude filter)
        qo_len = 64
        q_prefill = torch.randn(
            qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16
        )
        k_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill)

        # Check that only decode was dumped (filter for directories only, exclude session.jsonl)
        dump_subdirs = (
            [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else []
        )
        assert len(dump_subdirs) == 1, f"Expected 1 dump, found {len(dump_subdirs)}"

        # Verify the dump is for decode
        dump_name = dump_subdirs[0].name
        assert "decode" in dump_name.lower(), f"Expected decode dump, got {dump_name}"

    finally:
        # Restore environment
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
        _clean_flashinfer_modules()
+
+
def test_dump_include_and_exclude_combined(tmp_path):
    """Test that FLASHINFER_DUMP_INCLUDE and FLASHINFER_DUMP_EXCLUDE work together.

    Include ``single_*`` but exclude ``*prefill*``: decode matches only the
    include pattern and is dumped; prefill matches both and must be skipped
    (exclude wins over include).
    """
    # Store original environment so it can be restored afterwards.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
    }

    try:
        # Set up environment: include all single_* APIs but exclude prefill
        dump_dir = tmp_path / "test_dumps_combined"
        os.environ["FLASHINFER_LOGLEVEL"] = "10"
        os.environ["FLASHINFER_LOGDEST"] = "stdout"
        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
        os.environ["FLASHINFER_DUMP_INCLUDE"] = "single_*"
        os.environ["FLASHINFER_DUMP_EXCLUDE"] = "*prefill*"

        # Force reimport so the logging module re-reads the environment
        _clean_flashinfer_modules()

        from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache

        # Test configuration
        kv_len = 128
        num_qo_heads, num_kv_heads = 32, 8
        head_dim = 128

        # Call decode (matches include, not excluded -> should be dumped)
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q, k, v)

        # Call prefill (matches include BUT also matches exclude -> NOT dumped)
        qo_len = 64
        q_prefill = torch.randn(
            qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16
        )
        k_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill)

        # Check that only decode was dumped (filter for directories only, exclude session.jsonl)
        dump_subdirs = (
            [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else []
        )
        assert len(dump_subdirs) == 1, f"Expected 1 dump, found {len(dump_subdirs)}"

        # Verify the dump is for decode
        dump_name = dump_subdirs[0].name
        assert "decode" in dump_name.lower(), f"Expected decode dump, got {dump_name}"

    finally:
        # Restore environment
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
        _clean_flashinfer_modules()
+
+
def test_dump_include_no_match(tmp_path):
    """Test that no dumps are created when include filter matches nothing.

    With an include pattern that matches no API name, a decode call must not
    produce any dump subdirectory.
    """
    # Store original environment so it can be restored afterwards.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
    }

    try:
        # Set up environment with include filter that matches nothing
        dump_dir = tmp_path / "test_dumps_nomatch"
        os.environ["FLASHINFER_LOGLEVEL"] = "10"
        os.environ["FLASHINFER_LOGDEST"] = "stdout"
        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
        os.environ["FLASHINFER_DUMP_INCLUDE"] = "nonexistent_function_xyz"
        if "FLASHINFER_DUMP_EXCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_EXCLUDE"]

        # Force reimport so the logging module re-reads the environment
        _clean_flashinfer_modules()

        from flashinfer import single_decode_with_kv_cache

        # Test configuration
        kv_len = 128
        num_qo_heads, num_kv_heads = 32, 8
        head_dim = 128

        # Call decode (should NOT be dumped - doesn't match include filter)
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q, k, v)

        # Check that no dumps were created (filter for directories only, exclude session.jsonl)
        dump_subdirs = (
            [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else []
        )
        assert len(dump_subdirs) == 0, f"Expected 0 dumps, found {len(dump_subdirs)}"

    finally:
        # Restore environment
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
        _clean_flashinfer_modules()
+
+
def test_dump_multiple_include_patterns(tmp_path):
    """Test that multiple comma-separated include patterns work.

    With ``*decode*, *prefill*`` (note the space after the comma — patterns
    are expected to be trimmed), one decode and one prefill call must each
    produce a dump.
    """
    # Store original environment so it can be restored afterwards.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
    }

    try:
        # Set up environment with multiple include patterns
        dump_dir = tmp_path / "test_dumps_multi"
        os.environ["FLASHINFER_LOGLEVEL"] = "10"
        os.environ["FLASHINFER_LOGDEST"] = "stdout"
        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
        os.environ["FLASHINFER_DUMP_INCLUDE"] = "*decode*, *prefill*"
        if "FLASHINFER_DUMP_EXCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_EXCLUDE"]

        # Force reimport so the logging module re-reads the environment
        _clean_flashinfer_modules()

        from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache

        # Test configuration
        kv_len = 128
        num_qo_heads, num_kv_heads = 32, 8
        head_dim = 128

        # Call decode (should be dumped)
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q, k, v)

        # Call prefill (should also be dumped)
        qo_len = 64
        q_prefill = torch.randn(
            qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16
        )
        k_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v_prefill = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill)

        # Check that both were dumped (filter for directories only, exclude session.jsonl)
        dump_subdirs = (
            [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else []
        )
        assert len(dump_subdirs) == 2, f"Expected 2 dumps, found {len(dump_subdirs)}"

        # Verify we have one decode and one prefill
        dump_names = [d.name.lower() for d in dump_subdirs]
        has_decode = any("decode" in name for name in dump_names)
        has_prefill = any("prefill" in name for name in dump_names)
        assert has_decode, "Expected decode dump not found"
        assert has_prefill, "Expected prefill dump not found"

    finally:
        # Restore environment
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
        _clean_flashinfer_modules()
+
+
def test_jsonl_format(tmp_path):
    """Test that JSONL format is used correctly for metadata files.

    One dumped call must produce a per-dump ``metadata.jsonl`` (two records:
    ``inputs_saved`` then ``completed``) and a session-wide ``session.jsonl``
    mirroring those two records.
    """
    import json

    # Store original environment so it can be restored afterwards.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
    }

    try:
        # Set up environment
        dump_dir = tmp_path / "test_dumps_jsonl"
        os.environ["FLASHINFER_LOGLEVEL"] = "10"
        os.environ["FLASHINFER_LOGDEST"] = "stdout"
        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
        if "FLASHINFER_DUMP_INCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_INCLUDE"]
        if "FLASHINFER_DUMP_EXCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_EXCLUDE"]

        # Force reimport so the logging module re-reads the environment
        _clean_flashinfer_modules()

        from flashinfer import single_decode_with_kv_cache

        # Test configuration
        kv_len = 128
        num_qo_heads, num_kv_heads = 32, 8
        head_dim = 128

        # Call decode
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q, k, v)

        # Verify dump directory was created
        assert dump_dir.exists(), "Dump directory was not created"

        # Find the dump subdirectory
        dump_subdirs = [d for d in dump_dir.iterdir() if d.is_dir()]
        assert len(dump_subdirs) == 1, (
            f"Expected 1 dump subdir, found {len(dump_subdirs)}"
        )
        dump_subdir = dump_subdirs[0]

        # Verify per-dump metadata.jsonl exists (not metadata.json)
        metadata_jsonl_path = dump_subdir / "metadata.jsonl"
        metadata_json_path = dump_subdir / "metadata.json"
        assert metadata_jsonl_path.exists(), "metadata.jsonl was not created"
        assert not metadata_json_path.exists(), "metadata.json should not exist"

        # Verify per-dump metadata.jsonl has 2 lines (inputs_saved + completed)
        with open(metadata_jsonl_path, "r") as f:
            lines = [line.strip() for line in f if line.strip()]
        assert len(lines) == 2, (
            f"Expected 2 lines in metadata.jsonl, found {len(lines)}"
        )

        # Verify first line has execution_status="inputs_saved"
        first_record = json.loads(lines[0])
        assert first_record["execution_status"] == "inputs_saved"

        # Verify second line has execution_status="completed"
        second_record = json.loads(lines[1])
        assert second_record["execution_status"] == "completed"
        assert "output_metadata" in second_record
        assert second_record["tensor_info"]["output_tensor_keys"]

        # Verify central session.jsonl exists
        session_jsonl_path = dump_dir / "session.jsonl"
        assert session_jsonl_path.exists(), "session.jsonl was not created"

        # Verify session.jsonl has 2 lines (inputs_saved + completed for this call)
        with open(session_jsonl_path, "r") as f:
            session_lines = [line.strip() for line in f if line.strip()]
        assert len(session_lines) == 2, (
            f"Expected 2 lines in session.jsonl, found {len(session_lines)}"
        )

        # Verify session.jsonl records match per-dump records
        session_first = json.loads(session_lines[0])
        session_second = json.loads(session_lines[1])
        assert session_first["execution_status"] == "inputs_saved"
        assert session_second["execution_status"] == "completed"

    finally:
        # Restore environment
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
        _clean_flashinfer_modules()
+
+
def test_session_jsonl_multiple_calls(tmp_path):
    """Test that session.jsonl accumulates records from multiple API calls.

    Two dumped calls (one decode, one prefill) must append four records to
    the session log — an ``inputs_saved``/``completed`` pair per call, in
    call order.
    """
    import json

    # Store original environment so it can be restored afterwards.
    original_env = {
        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
    }

    try:
        # Set up environment
        dump_dir = tmp_path / "test_dumps_session"
        os.environ["FLASHINFER_LOGLEVEL"] = "10"
        os.environ["FLASHINFER_LOGDEST"] = "stdout"
        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
        if "FLASHINFER_DUMP_INCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_INCLUDE"]
        if "FLASHINFER_DUMP_EXCLUDE" in os.environ:
            del os.environ["FLASHINFER_DUMP_EXCLUDE"]

        # Force reimport so the logging module re-reads the environment
        _clean_flashinfer_modules()

        from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache

        # Test configuration
        kv_len = 128
        num_qo_heads, num_kv_heads = 32, 8
        head_dim = 128

        # Call 1: decode
        q1 = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k1 = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v1 = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_decode_with_kv_cache(q1, k1, v1)

        # Call 2: prefill
        qo_len = 64
        q2 = torch.randn(
            qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16
        )
        k2 = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v2 = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        single_prefill_with_kv_cache(q2, k2, v2)

        # Verify session.jsonl has 4 lines (2 per call: inputs_saved + completed)
        session_jsonl_path = dump_dir / "session.jsonl"
        assert session_jsonl_path.exists(), "session.jsonl was not created"

        with open(session_jsonl_path, "r") as f:
            lines = [line.strip() for line in f if line.strip()]
        assert len(lines) == 4, f"Expected 4 lines in session.jsonl, found {len(lines)}"

        # Verify the structure: inputs_saved, completed, inputs_saved, completed
        records = [json.loads(line) for line in lines]
        assert records[0]["execution_status"] == "inputs_saved"
        assert records[1]["execution_status"] == "completed"
        assert records[2]["execution_status"] == "inputs_saved"
        assert records[3]["execution_status"] == "completed"

        # Verify we have both function names
        func_names = {r["function_name"] for r in records}
        assert "single_decode_with_kv_cache" in func_names
        assert "single_prefill_with_kv_cache" in func_names

    finally:
        # Restore environment
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
        _clean_flashinfer_modules()
+
+
+def test_safetensors_format(tmp_path):
+    """Test that FLASHINFER_DUMP_SAFETENSORS uses safetensors format.
+
+    With ``FLASHINFER_LOGLEVEL=10`` and ``FLASHINFER_DUMP_SAFETENSORS=1``, a
+    single decode call must:
+
+    * create one dump subdirectory containing ``inputs.safetensors`` and
+      ``outputs.safetensors`` (and no ``.pt`` files),
+    * record ``tensor_format == "safetensors"`` in ``metadata.jsonl``,
+    * remain replayable via ``replay_sequence`` with matching outputs.
+
+    Requires a CUDA device and the ``safetensors`` package.
+    """
+    pytest.importorskip("safetensors", reason="safetensors package not installed")
+
+    import json
+
+    # Store original environment so the finally-block can restore it exactly
+    # (value of None means "was unset").
+    original_env = {
+        "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"),
+        "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"),
+        "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"),
+        "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"),
+        "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"),
+        "FLASHINFER_DUMP_SAFETENSORS": os.environ.get("FLASHINFER_DUMP_SAFETENSORS"),
+    }
+
+    try:
+        # Set up environment with safetensors enabled
+        dump_dir = tmp_path / "test_dumps_safetensors"
+        os.environ["FLASHINFER_LOGLEVEL"] = "10"
+        os.environ["FLASHINFER_LOGDEST"] = "stdout"
+        os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
+        os.environ["FLASHINFER_DUMP_SAFETENSORS"] = "1"
+        # Clear any include/exclude filters inherited from the outer
+        # environment so the decode call is not filtered out of the dump.
+        if "FLASHINFER_DUMP_INCLUDE" in os.environ:
+            del os.environ["FLASHINFER_DUMP_INCLUDE"]
+        if "FLASHINFER_DUMP_EXCLUDE" in os.environ:
+            del os.environ["FLASHINFER_DUMP_EXCLUDE"]
+
+        # Force reimport so flashinfer re-reads the env vars at import time
+        _clean_flashinfer_modules()
+
+        from flashinfer import single_decode_with_kv_cache
+
+        # Test configuration
+        kv_len = 128
+        num_qo_heads, num_kv_heads = 32, 8
+        head_dim = 128
+
+        # Call decode
+        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
+        k = torch.randn(
+            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
+        )
+        v = torch.randn(
+            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
+        )
+        single_decode_with_kv_cache(q, k, v)
+
+        # Verify dump directory was created
+        assert dump_dir.exists(), "Dump directory was not created"
+
+        # Find the dump subdirectory (exactly one call => exactly one subdir)
+        dump_subdirs = [d for d in dump_dir.iterdir() if d.is_dir()]
+        assert len(dump_subdirs) == 1, (
+            f"Expected 1 dump subdir, found {len(dump_subdirs)}"
+        )
+        dump_subdir = dump_subdirs[0]
+
+        # Verify safetensors files exist (not .pt files)
+        inputs_safetensors = dump_subdir / "inputs.safetensors"
+        outputs_safetensors = dump_subdir / "outputs.safetensors"
+        inputs_pt = dump_subdir / "inputs.pt"
+        outputs_pt = dump_subdir / "outputs.pt"
+
+        assert inputs_safetensors.exists(), "inputs.safetensors was not created"
+        assert outputs_safetensors.exists(), "outputs.safetensors was not created"
+        assert not inputs_pt.exists(), "inputs.pt should not exist in safetensors mode"
+        assert not outputs_pt.exists(), (
+            "outputs.pt should not exist in safetensors mode"
+        )
+
+        # Verify metadata has tensor_format field; the last record is the
+        # final ("completed") entry for this call.
+        with open(dump_subdir / "metadata.jsonl", "r") as f:
+            lines = [line.strip() for line in f if line.strip()]
+        last_record = json.loads(lines[-1])
+        assert last_record.get("tensor_format") == "safetensors", (
+            "tensor_format should be 'safetensors'"
+        )
+
+        # Verify replay works with safetensors format; reimport first so
+        # replay starts from clean module state.
+        _clean_flashinfer_modules()
+
+        from flashinfer.api_logging import replay_sequence
+
+        results = replay_sequence(str(dump_dir), device="cuda")
+        assert len(results) == 1
+        assert results[0]["comparison_match"] is True
+
+    finally:
+        # Restore environment
+        for key, value in original_env.items():
+            if value is not None:
+                os.environ[key] = value
+            elif key in os.environ:
+                del os.environ[key]
+        _clean_flashinfer_modules()
+
+def test_safetensors_replay_auto_detection(tmp_path):
+    """Test that replay auto-detects safetensors format.
+
+    Builds a hand-crafted dump directory (``inputs.safetensors`` plus a
+    ``metadata.jsonl`` record with ``tensor_format: "safetensors"``) and
+    checks that ``replay_from_dump`` loads the tensors from it. ``run=False``
+    keeps replay from invoking the (nonexistent) recorded function, so this
+    runs on CPU with no flashinfer kernels.
+    """
+    pytest.importorskip("safetensors", reason="safetensors package not installed")
+
+    from safetensors.torch import save_file
+
+    # Create a mock dump with safetensors files
+    dump_subdir = tmp_path / "mock_dump"
+    dump_subdir.mkdir()
+
+    # Create mock input tensors (shapes are arbitrary; asserted below)
+    input_tensors = {
+        "arg_0": torch.randn(32, 128, dtype=torch.float16),
+        "arg_1": torch.randn(128, 8, 128, dtype=torch.float16),
+    }
+    save_file(input_tensors, str(dump_subdir / "inputs.safetensors"))
+
+    # Create mock metadata.jsonl
+    import json
+
+    # Minimal metadata record mirroring what level-10 logging writes;
+    # "tensor_format": "safetensors" is the field replay should auto-detect.
+    metadata = {
+        "function_name": "test_function",
+        "module": "test_module",
+        "call_sequence": 1,
+        "timestamp": "20260108_120000_000",
+        "process_id": os.getpid(),
+        "input_metadata": {},
+        "output_metadata": {},
+        "tensor_info": {
+            "input_tensor_keys": ["arg_0", "arg_1"],
+            "output_tensor_keys": [],
+            "input_size_bytes": 0,
+            "input_size_mb": 0,
+        },
+        "tensor_format": "safetensors",
+        "execution_status": "inputs_saved",
+    }
+    with open(dump_subdir / "metadata.jsonl", "w") as f:
+        f.write(json.dumps(metadata) + "\n")
+
+    # Force reimport to get clean state
+    _clean_flashinfer_modules()
+
+    from flashinfer.api_logging import replay_from_dump
+
+    # Replay should auto-detect safetensors format
+    result = replay_from_dump(
+        str(dump_subdir), compare_outputs=False, device="cpu", run=False
+    )
+
+    # Verify tensors were loaded with the shapes saved above
+    assert len(result["args"]) == 2
+    assert result["args"][0].shape == (32, 128)
+    assert result["args"][1].shape == (128, 8, 128)
+
+
+# Allow running this test module directly (python <file>) outside the pytest CLI.
+if __name__ == "__main__":
+    pytest.main([__file__, "-v", "-s"])