diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index eb96ac0579..e9693432da 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -1038,13 +1038,12 @@ def testMmFp4(args): input_fp4, input_inv_s = flashinfer.mxfp4_quantize(input) mat2_fp4, mat2_inv_s = flashinfer.mxfp4_quantize(mat2) - if "trtllm" in backends: - mat2_fp4_trtllm, mat2_inv_s_trtllm = flashinfer.nvfp4_quantize( - mat2, - global_sf_mat2, - sfLayout=flashinfer.SfLayout.layout_128x4, - do_shuffle=True, - ) + mat2_fp4_trtllm, mat2_inv_s_trtllm = flashinfer.nvfp4_quantize( + mat2, + global_sf_mat2, + sfLayout=flashinfer.SfLayout.layout_128x4, + do_shuffle=True, + ) if args.verbose >= 2: print(f"[VVERBOSE] {input_fp4.shape = }") diff --git a/docs/logging.rst b/docs/logging.rst index c3c2c83d8f..b5b866b128 100644 --- a/docs/logging.rst +++ b/docs/logging.rst @@ -12,7 +12,7 @@ Enable logging using two environment variables: .. code-block:: bash - # Set logging level (0-5) + # Set logging level (0-10) export FLASHINFER_LOGLEVEL=3 # Set log destination (default is stdout) @@ -45,6 +45,10 @@ Logging Levels - Statistics - Level 3 + tensor statistics (min, max, mean, NaN/Inf counts) - Numerical analysis + * - **10** + - Flight Recorder - Full Input/Output Dumps + - Level 5 + dumps all input/output tensors to ``.pt`` (or ``.safetensors``) files + - Full Reproducibility / Debugging Environment Variables --------------------- @@ -63,12 +67,156 @@ Main Configuration * - ``FLASHINFER_LOGLEVEL`` - int - 0 - - Logging level (0, 1, 3, 5) + - Logging level (0, 1, 3, 5, 10) * - ``FLASHINFER_LOGDEST`` - str - ``stdout`` - Log destination: ``stdout``, ``stderr``, or file path +Dump Configuration (Level 10) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When FLASHINFER_LOGLEVEL is set to 10, the following environment variables can be used to configure the dump behavior: + +.. 
list-table:: + :header-rows: 1 + :widths: 30 15 15 40 + + * - Variable + - Type + - Default + - Description + * - ``FLASHINFER_DUMP_DIR`` + - str + - ``flashinfer_dumps`` + - Directory to save dump files + * - ``FLASHINFER_DUMP_MAX_SIZE_GB`` + - float + - 20 + - Maximum size of dump directory in GB + * - ``FLASHINFER_DUMP_MAX_COUNT`` + - int + - 1000 + - Maximum number of API calls to dump + * - ``FLASHINFER_DUMP_INCLUDE`` + - str + - (empty) + - Comma-separated patterns to include (fnmatch-style) + * - ``FLASHINFER_DUMP_EXCLUDE`` + - str + - (empty) + - Comma-separated patterns to exclude (fnmatch-style) + * - ``FLASHINFER_DUMP_SAFETENSORS`` + - int + - 0 + - Set to 1 to use safetensors format (no pickle, but loses stride info) + +SafeTensors Format (Optional) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, tensors are saved using ``torch.save()`` which preserves tensor stride and contiguity information. +For faster, pickle-free serialization, you can enable safetensors format: + +.. code-block:: bash + + export FLASHINFER_DUMP_SAFETENSORS=1 + +.. warning:: + SafeTensors does NOT preserve tensor strides or non-contiguity. + All tensors are saved as contiguous. Use the default ``torch.save`` format + if stride preservation is important for your debugging. + +**Comparison**: + +.. list-table:: + :header-rows: 1 + :widths: 25 35 40 + + * - Aspect + - torch.save (default) + - safetensors + * - Speed + - Standard + - Faster + * - Safety + - Uses pickle + - No pickle (safer) + * - Stride preservation + - ✅ Yes + - ❌ No (contiguous only) + * - File extension + - ``.pt`` + - ``.safetensors`` + * - Dependency + - ``torch`` + - Requires ``pip install safetensors`` + +**Replay is format-agnostic**: The replay command automatically detects the format based on file extension. + +Dump Filtering (Include/Exclude) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Use ``FLASHINFER_DUMP_INCLUDE`` and ``FLASHINFER_DUMP_EXCLUDE`` to control which API calls are dumped. 
+This is especially useful when running end-to-end inference with many API calls but you only care about specific ones. + +**Pattern Syntax** (fnmatch-style): + +- ``*`` matches any number of characters +- ``?`` matches a single character +- Matching is case-sensitive +- For class methods, the function name is formatted as ``ClassName.method_name`` + +**Filter Logic**: + +1. If ``FLASHINFER_DUMP_INCLUDE`` is set, only APIs matching at least one pattern are dumped +2. If ``FLASHINFER_DUMP_EXCLUDE`` is set, APIs matching any pattern are skipped +3. Both can be combined: include filter is applied first, then exclude filter + +**Examples**: + +.. code-block:: bash + + # Only dump decode-related APIs + export FLASHINFER_DUMP_INCLUDE="*decode*" + + # Dump everything except __init__ and plan methods + export FLASHINFER_DUMP_EXCLUDE="*.__init__,*.plan" + + # Only dump run() methods from wrapper classes + export FLASHINFER_DUMP_INCLUDE="*Wrapper.run" + + # Dump all single_* APIs except prefill + export FLASHINFER_DUMP_INCLUDE="single_*" + export FLASHINFER_DUMP_EXCLUDE="*prefill*" + + # Only dump a specific wrapper's run method + export FLASHINFER_DUMP_INCLUDE="BatchDecodeWithPagedKVCacheWrapper.run" + + # Dump FP8 APIs but not quantization steps + export FLASHINFER_DUMP_INCLUDE="*fp8*,*FP8*" + export FLASHINFER_DUMP_EXCLUDE="*quantize*" + +**Common Patterns**: + +.. 
list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Pattern + - Matches + * - ``*decode*`` + - ``single_decode_with_kv_cache``, ``BatchDecodeWithPagedKVCacheWrapper.run`` + * - ``*Wrapper.run`` + - ``BatchDecodeWithPagedKVCacheWrapper.run``, ``BatchPrefillWithPagedKVCacheWrapper.run`` + * - ``*.__init__`` + - All wrapper ``__init__`` methods + * - ``*.plan`` + - All wrapper ``plan`` methods + * - ``mm_fp8`` + - Exact match for ``mm_fp8`` function + * - ``single_*`` + - ``single_decode_with_kv_cache``, ``single_prefill_with_kv_cache`` + Process ID Substitution ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -116,3 +264,415 @@ Level 0 has zero overhead ^^^^^^^^^^^^^^^^^^^^^^^^^^^ At Level 0, the decorator returns the original function unchanged. No wrapper, no checks, no overhead. + +Flight Recorder & Replay +------------------------ + +FlashInfer includes a "Flight Recorder" mode (Level 10) that captures inputs/outputs for reproducibility. + +Dump Directory Structure +^^^^^^^^^^^^^^^^^^^^^^^^ + +When Level 10 logging is enabled, FlashInfer creates the following structure: + +.. code-block:: text + + FLASHINFER_DUMP_DIR/ + ├── session.jsonl # Central log: one line per event (quick scanning) + ├── 20250108_143216_802_pid12345_mm_fp8_call0001/ + │ ├── metadata.jsonl # Per-dump metadata (JSONL format) + │ ├── inputs.pt # Input tensors (or .safetensors if enabled) + │ └── outputs.pt # Output tensors (or .safetensors if enabled) + ├── 20250108_143216_868_pid12345_single_decode_call0001/ + │ ├── metadata.jsonl + │ ├── inputs.pt # (or .safetensors) + │ └── outputs.pt # (or .safetensors) + └── ... + +**JSONL Format**: Both ``session.jsonl`` and ``metadata.jsonl`` use `JSON Lines `_ format +(one JSON object per line). 
This enables: + +- **Crash-safe logging**: Each API call appends two lines (inputs_saved, then completed) +- **Quick scanning**: Use ``session.jsonl`` to browse all recorded calls without reading subdirectories +- **Streaming reads**: Process records line-by-line for large sessions + +**Per-dump metadata.jsonl**: + +- Line 1: Written **before** execution (``execution_status: "inputs_saved"``) +- Line 2: Appended **after** successful execution (``execution_status: "completed"``) + +If a crash occurs, only line 1 will be present, preserving the inputs for debugging. + +**Central session.jsonl**: + +One-stop log of all API calls. Use standard tools to filter and analyze: + +.. code-block:: bash + + # Enable Flight Recorder (Metadata + Tensors) + export FLASHINFER_LOGLEVEL=10 + export FLASHINFER_DUMP_DIR=./my_dumps + + # Run your application + python3 benchmarks/flashinfer_benchmark.py --routine mm_fp4 --m 4 --n 1024 --k 7168 --out_dtype bfloat16 --backends cudnn --use_128x4_sf_layout --use_nvfp4 --refcheck -vv --generate_repro_command --use_cupti --no_cuda_graph --num_iters 5 + ... output redacted ... + + # Replay recorded calls + export FLASHINFER_LOGLEVEL=0 # 1 for more detailed replay results. 
+ flashinfer replay --dir ./my_dumps + # or + python -m flashinfer replay --dir ./my_dumps + + [1] nvfp4_quantize (20251204_143216_802_pid12345_nvfp4_quantize_call0001): ✅ Passed + [2] fp4_quantize (20251204_143216_868_pid12345_fp4_quantize_call0001): ✅ Passed + [3] nvfp4_quantize (20251204_143216_949_pid12345_nvfp4_quantize_call0002): ✅ Passed + [4] fp4_quantize (20251204_143217_003_pid12345_fp4_quantize_call0002): ✅ Passed + [5] mm_fp4 (20251204_143217_178_pid12345_mm_fp4_call0001): ✅ Passed + [6] mm_fp4 (20251204_143217_346_pid12345_mm_fp4_call0002): ✅ Passed + [7] mm_fp4 (20251204_143217_427_pid12345_mm_fp4_call0003): ✅ Passed + [8] mm_fp4 (20251204_143217_475_pid12345_mm_fp4_call0004): ✅ Passed + [9] mm_fp4 (20251204_143217_510_pid12345_mm_fp4_call0005): ✅ Passed + [10] mm_fp4 (20251204_143217_551_pid12345_mm_fp4_call0006): ✅ Passed + [11] mm_fp4 (20251204_143217_591_pid12345_mm_fp4_call0007): ✅ Passed + [12] mm_fp4 (20251204_143217_631_pid12345_mm_fp4_call0008): ✅ Passed + [13] mm_fp4 (20251204_143217_672_pid12345_mm_fp4_call0009): ✅ Passed + [14] mm_fp4 (20251204_143217_708_pid12345_mm_fp4_call0010): ✅ Passed + [15] mm_fp4 (20251204_143217_769_pid12345_mm_fp4_call0011): ✅ Passed + [16] mm_fp4 (20251204_143217_812_pid12345_mm_fp4_call0012): ✅ Passed + [17] mm_fp4 (20251204_143217_852_pid12345_mm_fp4_call0013): ✅ Passed + [18] mm_fp4 (20251204_143217_904_pid12345_mm_fp4_call0014): ✅ Passed + [19] mm_fp4 (20251204_143218_153_pid12345_mm_fp4_call0015): ✅ Passed + [20] mm_fp4 (20251204_143218_390_pid12345_mm_fp4_call0016): ✅ Passed + [21] mm_fp4 (20251204_143218_627_pid12345_mm_fp4_call0017): ✅ Passed + [22] mm_fp4 (20251204_143218_862_pid12345_mm_fp4_call0018): ✅ Passed + + Summary: 22 passed, 0 failed/mismatch + +Python-Based Replay Examples +---------------------------- + +The following examples demonstrate how to use Level 10 logging to dump and replay API calls programmatically using Python. 
+ +Example 1: ``bmm_fp8`` - Simple Function Call +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**Producer Script** (``bmm_fp8_producer.py``): + +This script initializes tensors, calls ``bmm_fp8``, and dumps the inputs/outputs to disk. + +.. code-block:: python + + """ + Producer script: Run bmm_fp8 with Level 10 logging to dump tensors. + + Usage: + FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./bmm_fp8_dumps python bmm_fp8_producer.py + """ + import torch + from flashinfer import bmm_fp8 + + def to_float8(x, dtype=torch.float8_e4m3fn): + """Convert tensor to FP8 with per-tensor scaling.""" + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + # Parameters + b, m, n, k = 4, 64, 128, 256 + input_dtype = torch.float8_e4m3fn + mat2_dtype = torch.float8_e4m3fn + res_dtype = torch.bfloat16 + + # Create input tensors + input_bf16 = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input_bf16, dtype=input_dtype) + + # mat2: row major -> column major (transposed) + mat2_bf16 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1) + mat2_fp8, mat2_inv_s = to_float8(mat2_bf16, dtype=mat2_dtype) + + # Pre-allocate output + res = torch.empty([b, m, n], device="cuda", dtype=res_dtype) + + # Call bmm_fp8 - this will be logged/dumped at Level 10 + bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res, backend="cublas") + + # Print a small portion of the output for verification + print("Output shape:", res.shape) + print("Output[0, :3, :3]:") + print(res[0, :3, :3]) + +**Reproducer Script** (``bmm_fp8_reproducer.py``): + +This script loads the dumped tensors and replays the ``bmm_fp8`` call. + +.. 
code-block:: python + + """ + Reproducer script: Load dumped tensors and replay bmm_fp8. + + Usage: + python bmm_fp8_reproducer.py + """ + import torch + from pathlib import Path + from flashinfer import bmm_fp8 + from flashinfer.api_logging import replay_from_dump + + DUMP_DIR = "./bmm_fp8_dumps" + + # Find the bmm_fp8 dump directory (should be the only one or the latest) + dump_path = Path(DUMP_DIR) + bmm_dumps = sorted([d for d in dump_path.iterdir() if d.is_dir() and "bmm_fp8" in d.name]) + latest_dump = bmm_dumps[-1] # Use the latest dump + print(f"Loading dump from: {latest_dump}") + + # Use replay_from_dump to load inputs and optionally execute + result = replay_from_dump( + str(latest_dump), + compare_outputs=True, # Load expected outputs for comparison + device="cuda", + run=False, # We'll call the function manually below + ) + + # Extract the loaded arguments - args contains all positional args including the output tensor + args = result["args"] + kwargs = result["kwargs"] + expected_tensors = result.get("expected_tensors", {}) + + # Replay the call - args already contains (input, mat2, input_inv_s, mat2_inv_s, dtype, out) + res = bmm_fp8(*args, **kwargs) + + # Print the same portion for comparison + print("Replayed output shape:", res.shape) + print("Replayed output[0, :3, :3]:") + print(res[0, :3, :3]) + + # Compare with expected output if available + if "result" in expected_tensors: + expected = expected_tensors["result"] + if torch.allclose(res, expected, rtol=1e-3, atol=1e-3): + print("\n✅ Output matches expected result!") + else: + diff = (res - expected).abs().max().item() + print(f"\n❌ Output mismatch! Max diff: {diff}") + +Example 2: ``BatchDecodeWithPagedKVCacheWrapper`` - Stateful Wrapper Class +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**Producer Script** (``batch_decode_producer.py``): + +This script demonstrates logging with a stateful wrapper class that requires ``__init__``, ``plan``, and ``run`` calls. 
+ +.. code-block:: python + + """ + Producer script: Run BatchDecodeWithPagedKVCacheWrapper with Level 10 logging. + + Usage: + FLASHINFER_LOGLEVEL=10 FLASHINFER_DUMP_DIR=./batch_decode_dumps python batch_decode_producer.py + """ + import torch + import flashinfer + + # Parameters + batch_size = 4 + kv_len = 512 + page_size = 16 + num_kv_heads = 4 + num_qo_heads = 32 + head_dim = 128 + kv_layout = "NHD" + + # Create query tensor + q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + + # Create paged KV cache + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] # NHD layout + kv_data = torch.randn(*kv_shape, device="cuda", dtype=torch.float16) + + # Create index tensors + kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages, device="cuda", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda" + ) + + # Create workspace and wrapper - __init__ will be logged + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda") + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) + + # Plan - will be logged + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + data_type=torch.float16, + q_data_type=torch.float16, + ) + + # Run - will be logged + output, lse = wrapper.run(q, kv_data, return_lse=True) + + # Print a small portion of the output + print("Output shape:", output.shape) + print("Output[0, :3, :3]:") + print(output[0, :3, :3]) + print("\nLSE shape:", lse.shape) + print("LSE[0, :5]:", lse[0, :5]) + +**Reproducer Script** (``batch_decode_reproducer.py``): + +This script demonstrates replaying a sequence of stateful 
API calls. + +.. code-block:: python + + """ + Reproducer script: Replay BatchDecodeWithPagedKVCacheWrapper calls. + + Usage: + python batch_decode_reproducer.py + """ + import torch + from pathlib import Path + from flashinfer.api_logging import replay_sequence + + DUMP_DIR = "./batch_decode_dumps" + + # replay_sequence handles stateful objects automatically via object_registry + # It will: + # 1. Replay __init__ to create the wrapper instance + # 2. Replay plan() on the same instance + # 3. Replay run() on the same instance and compare outputs + results = replay_sequence(DUMP_DIR, device="cuda") + + # Print summary + passed = 0 + failed = 0 + for i, res in enumerate(results): + func_name = res.get("metadata", {}).get("function_name", "unknown") + dump_dir = Path(res.get("dump_dir", "")).name + + if "error" in res: + print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Error: {res['error']}") + failed += 1 + elif res.get("comparison_match", True): + print(f"[{i+1}] {func_name} ({dump_dir}): ✅ Passed") + passed += 1 + else: + print(f"[{i+1}] {func_name} ({dump_dir}): ❌ Mismatch") + failed += 1 + + print(f"\nSummary: {passed} passed, {failed} failed") + + # For manual inspection, you can also access individual results + # Find the 'run' call result (usually the last non-init, non-plan call) + for res in results: + func_name = res.get("metadata", {}).get("function_name", "") + if "run" in func_name and "execution_result" in res: + output = res["execution_result"] + if isinstance(output, tuple): + output_tensor, lse = output + print("\nReplayed output[0, :3, :3]:") + print(output_tensor[0, :3, :3]) + print("Replayed LSE[0, :5]:", lse[0, :5]) + break + +Manual Replay Without ``replay_from_dump`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For more control, you can manually load the dumped tensors: + +.. note:: + This example assumes the default ``torch.save`` format (``.pt`` files). 
+ If dumps were created with ``FLASHINFER_DUMP_SAFETENSORS=1``, use + ``safetensors.torch.load_file()`` instead of ``torch.load()``. + +.. code-block:: python + + """ + Manual replay: Load tensors directly from .pt files. + """ + import json + import torch + from pathlib import Path + from flashinfer import bmm_fp8 + + # Path is an example, replace with the actual path. + dump_dir = Path("./bmm_fp8_dumps/20250108_103217_012_pid12345_bmm_fp8_call0001") + + # Load metadata from JSONL (read last line for most complete state) + with open(dump_dir / "metadata.jsonl") as f: + lines = [line.strip() for line in f if line.strip()] + metadata = json.loads(lines[-1]) # Last line has completed state + + print(f"Function: {metadata['function_name']}") + print(f"Module: {metadata['module']}") + print(f"Status: {metadata['execution_status']}") + print(f"Input tensors: {metadata['tensor_info']['input_tensor_keys']}") + + # Load input tensors + inputs = torch.load(dump_dir / "inputs.pt", map_location="cuda") + + # Load expected outputs (if execution completed successfully) + outputs_path = dump_dir / "outputs.pt" + if outputs_path.exists(): + expected = torch.load(outputs_path, map_location="cuda") + print(f"Output tensors: {list(expected.keys())}") + + # Tensors are ready to use - reconstruct the call as needed + for key, tensor in inputs.items(): + print(f" {key}: shape={tensor.shape}, dtype={tensor.dtype}") + +Scanning Session History +^^^^^^^^^^^^^^^^^^^^^^^^ + +Use the central ``session.jsonl`` to quickly scan all recorded API calls: + +.. code-block:: python + + """ + Scan session.jsonl for quick overview of recorded calls. 
+ """ + import json + from pathlib import Path + from collections import Counter + + dump_root = Path("./my_dumps") + session_file = dump_root / "session.jsonl" + + # Read all records + records = [] + with open(session_file) as f: + for line in f: + if line.strip(): + records.append(json.loads(line)) + + # Filter to completed calls only + completed = [r for r in records if r["execution_status"] == "completed"] + print(f"Total completed calls: {len(completed)}") + + # Count by function name + func_counts = Counter(r["function_name"] for r in completed) + print("\nCalls by function:") + for func, count in func_counts.most_common(): + print(f" {func}: {count}") + + # Find calls that didn't complete (potential crashes) + inputs_only = [r for r in records if r["execution_status"] == "inputs_saved"] + # Group by dump_dir to find incomplete calls + completed_dirs = {r["dump_dir"] for r in completed} + incomplete = [r for r in inputs_only if r["dump_dir"] not in completed_dirs] + if incomplete: + print(f"\n⚠️ Found {len(incomplete)} incomplete calls (potential crashes):") + for r in incomplete: + print(f" - {r['function_name']} at {r['dump_dir']}") diff --git a/flashinfer/__main__.py b/flashinfer/__main__.py index 4e822da19f..eed6671359 100644 --- a/flashinfer/__main__.py +++ b/flashinfer/__main__.py @@ -383,5 +383,78 @@ def export_compile_commands_cmd(path, output): click.secho(f"❌ Failed to write compile commands: {e}", fg="red") +@cli.command("replay") +@click.option( + "--dir", + "dump_dir", + required=True, + help="Directory containing dump files (or root directory of session)", +) +def replay_cmd(dump_dir): + """Replay API calls from dump directory""" + from .api_logging import replay_sequence, replay_from_dump + + device = "cuda" + + if not os.path.exists(dump_dir): + click.secho(f"❌ Directory not found: {dump_dir}", fg="red") + return + + # Check if this is a single dump or a session / sequence root + is_single_dump = os.path.exists(os.path.join(dump_dir, 
"metadata.jsonl")) + + try: + if is_single_dump: + click.secho(f"Replaying single dump from {dump_dir}...", fg="cyan") + result = replay_from_dump( + dump_dir, compare_outputs=True, device=device, run=True + ) + if result.get("comparison_match"): + click.secho("✅ Replay passed (outputs matched)", fg="green") + elif result.get("execution_error"): + click.secho( + f"❌ Execution failed: {result['execution_error']}", fg="red" + ) + else: + click.secho("⚠️ Replay finished but outputs did not match", fg="yellow") + else: + # Session / sequence replay + click.secho(f"Replaying session from {dump_dir}...", fg="cyan") + results = replay_sequence(dump_dir, device=device) + + passed = 0 + failed = 0 + + for i, res in enumerate(results): + dump_name = ( + os.path.basename(res.get("dump_dir", "")) + if "dump_dir" in res + else f"call_{i + 1}" + ) + # If replay_from_dump returned successfully, metadata might have the name + if "metadata" in res and "function_name" in res["metadata"]: + func_name = res["metadata"]["function_name"] + dump_name = f"{func_name} ({dump_name})" + + if "error" in res: + click.secho( + f"[{i + 1}] {dump_name}: ❌ Error: {res['error']}", fg="red" + ) + failed += 1 + elif res.get("comparison_match"): + click.secho(f"[{i + 1}] {dump_name}: ✅ Passed", fg="green") + passed += 1 + else: + click.secho(f"[{i + 1}] {dump_name}: ⚠️ Mismatch", fg="yellow") + failed += 1 + + click.secho( + f"\nSummary: {passed} passed, {failed} failed/mismatch", fg="white" + ) + + except Exception as e: + click.secho(f"❌ Replay failed: {e}", fg="red") + + if __name__ == "__main__": cli() diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index 734d6bae28..bfd4571531 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -15,13 +15,18 @@ """ import enum +import fnmatch import functools import inspect +import json import logging import os import sys -from typing import Any, Callable +from datetime import datetime +from pathlib import Path +from 
typing import Any, Callable, Dict, Tuple, Optional import contextlib +import importlib import torch @@ -42,6 +47,27 @@ def _substitute_process_id(path: str) -> str: _API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) _API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout")) +# Configuration for Level 10 tensor dumping +_DUMP_DIR = os.environ.get("FLASHINFER_DUMP_DIR", "flashinfer_dumps") +_DUMP_MAX_SIZE_GB = float(os.environ.get("FLASHINFER_DUMP_MAX_SIZE_GB", "20")) +_DUMP_MAX_COUNT = int(os.environ.get("FLASHINFER_DUMP_MAX_COUNT", "1000")) + +# Dump filtering: include/exclude patterns (fnmatch-style, comma-separated) +# Examples: "*decode*,*prefill*" or "BatchDecodeWrapper.run,mm_fp8" +_DUMP_INCLUDE = os.environ.get("FLASHINFER_DUMP_INCLUDE", "") +_DUMP_EXCLUDE = os.environ.get("FLASHINFER_DUMP_EXCLUDE", "") +_DUMP_INCLUDE_PATTERNS = [p.strip() for p in _DUMP_INCLUDE.split(",") if p.strip()] +_DUMP_EXCLUDE_PATTERNS = [p.strip() for p in _DUMP_EXCLUDE.split(",") if p.strip()] + +# SafeTensors format option (default: use torch.save which preserves stride/contiguity) +_DUMP_SAFETENSORS = os.environ.get("FLASHINFER_DUMP_SAFETENSORS", "0") == "1" + +# Global tracking for dump limits (reset per process) +_dump_count = 0 +_dump_total_size_bytes = 0 +_dump_call_counter = {} # Track call count per function +_session_jsonl_initialized = False # Track if session.jsonl header was written + # Create logger using Python's logging library _logger = logging.getLogger("flashinfer.api") @@ -82,11 +108,947 @@ def _setup_logger(): def _get_timestamp() -> str: """Get current timestamp in the format [YYYY-MM-DD HH:MM:SS].""" - from datetime import datetime - return datetime.now().strftime("[%Y-%m-%d %H:%M:%S]") +def _warn_dump(): + """Warn users about security implications of Level 10 logging.""" + if _API_LOG_LEVEL >= 10: + print("=" * 80) + print( + "WARNING: FlashInfer API Logging is set to Level 10 (Tensor Dumping).\n" + "This will dump ALL 
input and outputs including tensors for FlashInfer APIs to disk in\n" + "the configured dump directory. Ensure that you are NOT processing sensitive data\n" + "or that the dump directory is secure. To disable dumping, unset FLASHINFER_LOGLEVEL or\n" + "set it to below 10. For more information, see https://docs.flashinfer.ai/logging.html" + ) + print(f"Current dump directory is: {_DUMP_DIR}") + if _DUMP_SAFETENSORS: + print( + "⚠️ SAFETENSORS mode enabled: tensor stride/non-contiguity will NOT be preserved.\n" + " Tensors will be saved as contiguous. Use torch.save (default) to preserve strides." + ) + if _DUMP_INCLUDE_PATTERNS: + print(f"Include filter: {_DUMP_INCLUDE_PATTERNS}") + if _DUMP_EXCLUDE_PATTERNS: + print(f"Exclude filter: {_DUMP_EXCLUDE_PATTERNS}") + print("=" * 80) + + +def _should_dump_function(func_name: str) -> bool: + """ + Check if a function should be dumped based on include/exclude filters. + + Uses fnmatch-style patterns (wildcards: * for any chars, ? for single char). + Matching is case-sensitive. + + Parameters + ---------- + func_name : str + The function name to check. For class methods, this is formatted as + "ClassName.method_name" (e.g., "BatchDecodeWrapper.run"). + + Returns + ------- + bool + True if the function should be dumped, False otherwise. + + Filter Logic + ------------ + 1. If FLASHINFER_DUMP_INCLUDE is set: + - Function must match at least one include pattern + - If it doesn't match any, return False (skip dump) + 2. If FLASHINFER_DUMP_EXCLUDE is set: + - If function matches any exclude pattern, return False (skip dump) + 3. 
Otherwise, return True (dump the function) + """ + # If include patterns are specified, func must match at least one + if _DUMP_INCLUDE_PATTERNS: + if not any(fnmatch.fnmatch(func_name, pat) for pat in _DUMP_INCLUDE_PATTERNS): + return False + + # If exclude patterns are specified, func must not match any + if _DUMP_EXCLUDE_PATTERNS: + if any(fnmatch.fnmatch(func_name, pat) for pat in _DUMP_EXCLUDE_PATTERNS): + return False + + return True + + +def _append_to_jsonl(filepath: Path, record: Dict[str, Any]) -> None: + """ + Append a JSON record as a single line to a JSONL file. + + Parameters + ---------- + filepath : Path + Path to the JSONL file + record : Dict[str, Any] + Record to append (will be serialized as single-line JSON) + """ + with open(filepath, "a") as f: + f.write(json.dumps(record) + "\n") + + +def _read_jsonl_last_record(filepath: Path) -> Optional[Dict[str, Any]]: + """ + Read the last record from a JSONL file. + + For metadata.jsonl, this returns the most complete state (completed if available, + otherwise inputs_saved). + + Parameters + ---------- + filepath : Path + Path to the JSONL file + + Returns + ------- + Optional[Dict[str, Any]] + The last record, or None if file is empty/doesn't exist + """ + if not filepath.exists(): + return None + + last_line = None + with open(filepath, "r") as f: + for line in f: + line = line.strip() + if line: + last_line = line + + if last_line: + return json.loads(last_line) + return None + + +def _get_tensor_size_bytes(tensor: torch.Tensor) -> int: + """Calculate the size of a tensor in bytes.""" + return tensor.element_size() * tensor.nelement() + + +def _serialize_value(value: Any) -> Any: + """ + Convert a non-tensor value to a JSON-serializable format for metadata. + + This function is intended for serializing non-tensor arguments/values + that are used in API input or output metadata. Tensor arguments are not handled here. 
+ """ + try: + if isinstance(value, torch.dtype): + # Special handling for torch.dtype + return { + "type": "torch.dtype", + "value": str(value), # e.g., "torch.bfloat16" + } + elif isinstance(value, enum.Enum): + return { + "type": "enum", + "name": f"{type(value).__name__}.{value.name}", + "value": value.value, + } + elif isinstance(value, (int, float, str, bool, type(None))): + return value + elif isinstance(value, (list, tuple, dict)): + return { + "type": type(value).__name__, + "value": str(value)[:1000], + } # Truncate long structures + else: + return { + "type": type(value).__name__, + "repr": str(value)[:1000], + } + except Exception: + return { + "type": type(value).__name__, + "repr": "", + } + + +def _extract_tensors_and_metadata( + args: tuple, kwargs: dict +) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """ + Extract tensors and non-tensor metadata from function arguments. + + Tensors are moved to CPU but preserve their stride/contiguity information. + + Returns + ------- + tensors : Dict[str, torch.Tensor] + Dictionary of tensor arguments with keys like "arg_0", "kwarg_name" + All tensors are on CPU with original stride preserved. + metadata : Dict[str, Any] + Dictionary of non-tensor arguments (serializable to JSON) + """ + tensors = {} + metadata = {} + + # Process positional arguments + for i, arg in enumerate(args): + key = f"arg_{i}" + if isinstance(arg, torch.Tensor): + tensors[key] = arg.cpu() + else: + metadata[key] = _serialize_value(arg) + + # Process keyword arguments + for key, value in kwargs.items(): + kwarg_key = f"kwarg_{key}" + if isinstance(value, torch.Tensor): + tensors[kwarg_key] = value.cpu() + else: + metadata[kwarg_key] = _serialize_value(value) + + return tensors, metadata + + +def _dump_function_inputs( + func: Callable, + func_name: str, + args: tuple, + kwargs: dict, + self_id: Optional[int] = None, +) -> Optional[str]: + """ + Dump function inputs to disk BEFORE execution (crash-safe). + + This function: + 1. 
Extracts tensors and metadata from inputs + 2. Creates a timestamped directory + 3. Saves inputs.pt and partial metadata.json + 4. Tracks cumulative size and count limits + + Parameters + ---------- + func : Callable + The function being called + func_name : str + Name of the function + args : tuple + Positional arguments + kwargs : dict + Keyword arguments + self_id : Optional[int] + The id() of the 'self' object if this is a method call + + Returns + ------- + Optional[str] + Path to the dump directory, or None if dump was skipped + """ + global _dump_count, _dump_total_size_bytes + + # Check include/exclude filters first (before any work is done) + if not _should_dump_function(func_name): + _logger.debug( + f"Skipping dump for {func_name} (filtered by include/exclude patterns)" + ) + return None + + if _dump_count >= _DUMP_MAX_COUNT: + _logger.warning( + f"Dump limit reached ({_DUMP_MAX_COUNT} dumps). Skipping dump for {func_name}. " + f"Increase FLASHINFER_DUMP_MAX_COUNT if needed." + ) + return None + + try: + # Get call counter for this function + if func_name not in _dump_call_counter: + _dump_call_counter[func_name] = 0 + _dump_call_counter[func_name] += 1 + call_seq = _dump_call_counter[func_name] + + # Create dump directory structure + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[ + :-3 + ] # Include milliseconds + pid = os.getpid() + dump_name = f"{timestamp}_pid{pid}_{func_name}_call{call_seq:04d}" + dump_dir = Path(_DUMP_DIR) / dump_name + dump_dir.mkdir(parents=True, exist_ok=True) + + # Extract tensors and metadata from inputs + input_tensors, input_metadata = _extract_tensors_and_metadata(args, kwargs) + + # Calculate input size + input_size = sum(_get_tensor_size_bytes(t) for t in input_tensors.values()) + + # Check size limit (conservative check - only inputs for now) + max_size_bytes = _DUMP_MAX_SIZE_GB * 1024 * 1024 * 1024 + if _dump_total_size_bytes + input_size > max_size_bytes: + _logger.warning( + f"Dump size limit reached 
({_DUMP_MAX_SIZE_GB} GB). Skipping dump for {func_name}. " + f"Increase FLASHINFER_DUMP_MAX_SIZE_GB if needed." + ) + # Clean up empty directory + dump_dir.rmdir() + return None + + # Save input tensors + if input_tensors: + if _DUMP_SAFETENSORS: + # SafeTensors format: faster, no pickle, but loses stride/contiguity + try: + from safetensors.torch import save_file + + # safetensors requires contiguous tensors + tensors_contiguous = { + k: v.contiguous() for k, v in input_tensors.items() + } + save_file(tensors_contiguous, str(dump_dir / "inputs.safetensors")) + except ImportError: + _logger.error( + "safetensors package not installed. " + "Install with: pip install safetensors" + ) + raise + else: + # torch.save format: preserves stride/contiguity + torch.save(input_tensors, dump_dir / "inputs.pt") + + # Create partial metadata (inputs only, outputs will be added later) + metadata: Dict[str, Any] = { + "function_name": func_name, + "module": func.__module__ if hasattr(func, "__module__") else "", + "call_sequence": call_seq, + "timestamp": timestamp, + "process_id": os.getpid(), + "input_metadata": input_metadata, + "output_metadata": {}, # Placeholder, will be updated after execution + "tensor_info": { + "input_tensor_keys": list(input_tensors.keys()), + "output_tensor_keys": [], # Placeholder, will be updated after execution + "input_size_bytes": input_size, + "input_size_mb": input_size / (1024 * 1024), + }, + "tensor_details": {}, # Detailed shape/dtype/stride info for reconstruction + "tensor_format": "safetensors" if _DUMP_SAFETENSORS else "torch", + "function_signature": str(inspect.signature(func)) + if hasattr(inspect, "signature") + else "", + "versions": { + "torch": torch.__version__, + "python": sys.version, + }, + "execution_status": "inputs_saved", # Will be updated to "completed" after outputs + } + + # Add self_id to metadata if it is a class method call + if self_id is not None: + metadata["self_id"] = self_id + + # Add tensor details for random 
generation fallback + for key, tensor in input_tensors.items(): + metadata["tensor_details"][key] = { + "shape": list(tensor.shape), + "dtype": str(tensor.dtype), + "stride": list(tensor.stride()), + "device": str(tensor.device), + } + + # Try to get FlashInfer version + try: + from .version import __version__ as flashinfer_version + + metadata["versions"]["flashinfer"] = flashinfer_version # type: ignore[index] + except Exception: + metadata["versions"]["flashinfer"] = "" # type: ignore[index] + + # Add dump_dir to metadata for central session.jsonl reference + metadata["dump_dir"] = str(dump_dir) + + # Save metadata to per-dump JSONL (first line: inputs_saved) + _append_to_jsonl(dump_dir / "metadata.jsonl", metadata) + + # Append to central session.jsonl for quick scanning + session_jsonl_path = Path(_DUMP_DIR) / "session.jsonl" + _append_to_jsonl(session_jsonl_path, metadata) + + # Update global tracking (only input size for now) + _dump_count += 1 + _dump_total_size_bytes += input_size + + _logger.debug( + f"Dumped inputs to: {dump_dir} " + f"(size: {input_size / (1024 * 1024):.2f} MB, " + f"total: {_dump_count}/{_DUMP_MAX_COUNT} dumps)" + ) + + return str(dump_dir) + + except Exception as e: + _logger.error(f"Failed to dump function call {func_name}: {e}") + import traceback + + _logger.error(traceback.format_exc()) + return None + + +def _dump_function_outputs(dump_dir: str, result: Any) -> None: + """ + Add function outputs to an existing dump directory (crash-safe). + + This function is called AFTER successful execution to append outputs + to the dump that was created before execution. 
def _dump_function_outputs(dump_dir: str, result: Any) -> None:
    """
    Add function outputs to an existing dump directory (crash-safe).

    Called AFTER successful execution to append outputs to the dump created
    by _dump_function_inputs and to mark its execution_status "completed".

    Parameters
    ----------
    dump_dir : str
        Path to the dump directory created by _dump_function_inputs
    result : Any
        Function return value
    """
    global _dump_total_size_bytes

    try:
        dump_path = Path(dump_dir)
        if not dump_path.exists():
            _logger.error(f"Dump directory not found: {dump_dir}")
            return

        # Split the result into tensors and JSON-serializable metadata.
        # Lists are unpacked like tuples, matching the replay-side comparison
        # in replay_from_dump which flattens both (tuple, list) results.
        output_tensors = {}
        output_metadata = {}
        if isinstance(result, torch.Tensor):
            output_tensors["result"] = result.cpu()
        elif isinstance(result, (tuple, list)):
            for i, item in enumerate(result):
                if isinstance(item, torch.Tensor):
                    output_tensors[f"result_{i}"] = item.cpu()
                else:
                    output_metadata[f"result_{i}"] = _serialize_value(item)
        else:
            output_metadata["result"] = _serialize_value(result)

        # Calculate output size
        output_size = sum(_get_tensor_size_bytes(t) for t in output_tensors.values())

        # Save output tensors in the same format as the inputs.
        if output_tensors:
            if _DUMP_SAFETENSORS:
                # SafeTensors format: faster, no pickle, but loses stride/contiguity
                try:
                    from safetensors.torch import save_file
                except ImportError:
                    _logger.error(
                        "safetensors package not installed. "
                        "Install with: pip install safetensors"
                    )
                    raise

                tensors_contiguous = {
                    k: v.contiguous() for k, v in output_tensors.items()
                }
                save_file(tensors_contiguous, str(dump_path / "outputs.safetensors"))
            else:
                # torch.save format: preserves stride/contiguity
                torch.save(output_tensors, dump_path / "outputs.pt")

        # Load the latest metadata record (last JSONL line) and extend it.
        metadata_jsonl_path = dump_path / "metadata.jsonl"
        metadata = _read_jsonl_last_record(metadata_jsonl_path)

        if metadata is None:
            _logger.error(f"metadata.jsonl not found or empty in {dump_dir}")
            return

        metadata["output_metadata"] = output_metadata
        info = metadata["tensor_info"]
        info["output_tensor_keys"] = list(output_tensors.keys())
        info["output_size_bytes"] = output_size
        info["output_size_mb"] = output_size / (1024 * 1024)
        info["total_size_bytes"] = info["input_size_bytes"] + output_size
        info["total_size_mb"] = info["total_size_bytes"] / (1024 * 1024)
        metadata["execution_status"] = "completed"

        # Record shape/dtype/stride for each output tensor.
        details = metadata.setdefault("tensor_details", {})
        for key, tensor in output_tensors.items():
            details[key] = {
                "shape": list(tensor.shape),
                "dtype": str(tensor.dtype),
                "stride": list(tensor.stride()),
                "device": str(tensor.device),
            }

        # Append completion records to both the per-dump and session logs.
        _append_to_jsonl(metadata_jsonl_path, metadata)
        _append_to_jsonl(Path(_DUMP_DIR) / "session.jsonl", metadata)

        # Update global size tracking
        _dump_total_size_bytes += output_size

        _logger.debug(
            f"Dumped outputs to: {dump_dir} "
            f"(output size: {output_size / (1024 * 1024):.2f} MB, "
            f"total dump size: {info['total_size_mb']:.2f} MB)"
        )

    except Exception as e:
        _logger.error(f"Failed to dump outputs to {dump_dir}: {e}")
        import traceback

        _logger.error(traceback.format_exc())
+ dtype_name = dtype_str.replace("torch.", "") + try: + return getattr(torch, dtype_name) + except AttributeError: + _logger.warning(f"Could not reconstruct dtype: {dtype_str}") + return value + + # For other dict types, return as-is + return value + + return value + + +def _resolve_function(module_name: str, function_name: str) -> Optional[Callable]: + """Resolve a function from module name and function name.""" + try: + module = importlib.import_module(module_name) + # Handle nested function names (e.g. Class.method) + parts = function_name.split(".") + obj: Any = module + for part in parts: + obj = getattr(obj, part) + if not callable(obj): + return None + return obj + except Exception as e: + _logger.warning( + f"Could not resolve function {module_name}.{function_name}: {e}" + ) + return None + + +def _compare_results( + actual: Any, expected: Any, rtol: float = 1e-3, atol: float = 1e-3 +) -> bool: + """Recursively compare execution results.""" + # torch.Tensor comparison + if isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor): + # Check shape + if actual.shape != expected.shape: + _logger.warning( + f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}" + ) + return False + # Check dtype + if actual.dtype != expected.dtype: + _logger.warning( + f"Dtype mismatch: actual {actual.dtype} vs expected {expected.dtype}" + ) + return False + + # Check values; apply relative and absolute tolerance. 
+ if not torch.allclose(actual, expected, rtol=rtol, atol=atol): + diff = (actual - expected).abs().max().item() + _logger.warning(f"Value mismatch: max diff {diff}") + return False + return True + + # list/tuple comparison + elif isinstance(actual, (list, tuple)) and isinstance(expected, (list, tuple)): + if len(actual) != len(expected): + _logger.warning( + f"Length mismatch: actual {len(actual)} vs expected {len(expected)}" + ) + return False + return all( + _compare_results(a, e, rtol, atol) + for a, e in zip(actual, expected, strict=True) + ) + + # dict comparison + elif isinstance(actual, dict) and isinstance(expected, dict): + if actual.keys() != expected.keys(): + _logger.warning( + f"Key mismatch: actual {actual.keys()} vs expected {expected.keys()}" + ) + return False + return all(_compare_results(actual[k], expected[k], rtol, atol) for k in actual) + + # fallback for other types (including None). Just do a naive comparison. + else: + if actual != expected: + _logger.warning(f"Value mismatch: actual {actual} vs expected {expected}") + return False + return True + + +def replay_from_dump( + dump_dir: str, + compare_outputs: bool = False, + device: str = "cuda", + run: bool = False, + object_registry: Optional[Dict[Tuple[int, int], Any]] = None, +) -> Any: + """ + Replay a function call from a dumped directory. + + This function: + 1. Loads metadata.jsonl to get function info + 2. Loads inputs.pt to get input tensors + 3. Moves tensors to specified device (default: cuda) + 4. Reconstructs the function call + 5. Optionally executes the function (if run=True) + 6. Optionally compares with saved outputs + + Parameters + ---------- + dump_dir : str + Path to the dump directory + compare_outputs : bool + If True, load and compare with saved outputs + device : str + Target device for tensors. 
def replay_from_dump(
    dump_dir: str,
    compare_outputs: bool = False,
    device: str = "cuda",
    run: bool = False,
    object_registry: Optional[Dict[Tuple[int, int], Any]] = None,
) -> Any:
    """
    Replay a function call from a dumped directory.

    Loads metadata.jsonl and the saved input tensors, moves tensors to the
    requested device, reconstructs args/kwargs, and optionally re-executes
    the call and compares with the saved outputs.

    Parameters
    ----------
    dump_dir : str
        Path to the dump directory
    compare_outputs : bool
        If True, load and compare with saved outputs
    device : str
        Target device for tensors: "cuda" (default, i.e. cuda:0), "cpu",
        or "cuda:N" for a specific CUDA device
    run : bool
        If True, try to resolve and execute the function
    object_registry : Optional[Dict[Tuple[int, int], Any]]
        Registry of stateful objects keyed by (process_id, self_id); the
        composite key avoids id() collisions across processes in multi-GPU
        environments.

    Returns
    -------
    result : dict
        Contains 'args', 'kwargs', 'metadata'; with run=True also
        'execution_result' and (with compare_outputs) 'comparison_match';
        with compare_outputs=True also 'expected_tensors' and
        'expected_metadata'.
    """
    dump_path = Path(dump_dir)
    if not dump_path.exists():
        raise FileNotFoundError(f"Dump directory not found: {dump_dir}")

    # The last JSONL record carries the most complete state.
    metadata_jsonl_path = dump_path / "metadata.jsonl"
    if not metadata_jsonl_path.exists():
        raise FileNotFoundError(f"metadata.jsonl not found in {dump_dir}")
    metadata = _read_jsonl_last_record(metadata_jsonl_path)
    if metadata is None:
        raise ValueError(f"metadata.jsonl is empty in {dump_dir}")

    func_name = metadata["function_name"]

    def _load_tensor_file(stem: str, required: bool) -> Dict[str, torch.Tensor]:
        # Auto-detect the serialization format from the file extension.
        pt_file = dump_path / f"{stem}.pt"
        st_file = dump_path / f"{stem}.safetensors"
        if pt_file.exists():
            return torch.load(str(pt_file), map_location="cpu")
        if st_file.exists():
            try:
                from safetensors.torch import load_file
            except ImportError:
                raise ImportError(
                    "Dump was saved with safetensors but package not installed. "
                    "Install with: pip install safetensors"
                ) from None
            return load_file(str(st_file), device="cpu")
        if required:
            raise FileNotFoundError(
                f"Neither {stem}.pt nor {stem}.safetensors found in {dump_dir}"
            )
        return {}

    # Load inputs and move them to the requested device.
    input_tensors = {
        k: t.to(device) for k, t in _load_tensor_file("inputs", required=True).items()
    }
    input_metadata = metadata.get("input_metadata", {})

    # Highest positional index present in either tensors or metadata.
    arg_indices = [
        int(key.split("_")[1])
        for key in list(input_tensors) + list(input_metadata)
        if key.startswith("arg_")
    ]
    max_arg_idx = max(arg_indices, default=-1)

    # Rebuild positional args in order so the call replays exactly as logged.
    args = []
    for i in range(max_arg_idx + 1):
        key = f"arg_{i}"
        if key in input_tensors:
            args.append(input_tensors[key])
        elif key in input_metadata:
            args.append(_reconstruct_value(input_metadata[key]))
        else:
            # Should not happen for a consistent dump; hold position with None.
            _logger.warning(f"Missing argument {i} in dump.")
            args.append(None)

    # Keyword arguments: tensor kwargs take precedence on name collisions.
    kwargs = {
        key.replace("kwarg_", ""): tensor
        for key, tensor in input_tensors.items()
        if key.startswith("kwarg_")
    }
    for key, value in input_metadata.items():
        if key.startswith("kwarg_"):
            kwarg_name = key.replace("kwarg_", "")
            if kwarg_name not in kwargs:
                kwargs[kwarg_name] = _reconstruct_value(value)

    _logger.info(f"Replaying {func_name} from {dump_dir}")
    _logger.info(f" Args: {len(args)}, Kwargs: {list(kwargs.keys())}")

    result_dict: Dict[str, Any] = {"args": args, "kwargs": kwargs, "metadata": metadata}

    expected_outputs = {}
    output_metadata = {}
    if compare_outputs:
        expected_outputs = {
            k: t.to(device)
            for k, t in _load_tensor_file("outputs", required=False).items()
        }
        output_metadata = metadata.get("output_metadata", {})
        result_dict["expected_tensors"] = expected_outputs
        result_dict["expected_metadata"] = output_metadata

    if run:
        module_name = metadata.get("module")
        self_id = metadata.get("self_id")
        process_id = metadata.get("process_id")

        func = None
        obj = None

        # Stateful replay for class-method calls (wrapper classes like
        # BatchDecodeWithPagedKVCacheWrapper). The (process_id, self_id)
        # composite key scopes object identity per process, since different
        # processes may reuse the same memory address for id().
        if self_id is not None:
            registry_key = (process_id, self_id)
            if func_name.endswith(".__init__"):
                # Constructor call: resolve the class and instantiate it.
                class_name = func_name.split(".")[-2]  # "Wrapper.__init__" -> "Wrapper"
                cls_obj = _resolve_function(module_name, class_name)
                if cls_obj and callable(cls_obj):
                    # args[0] is the 'self' placeholder recorded for __init__.
                    real_args = args[1:] if len(args) > 0 else []
                    try:
                        _logger.info(
                            f"Instantiating {class_name} (PID: {process_id}, ID: {self_id})..."
                        )
                        # Calling the class constructor directly is the safest
                        # way to replay __init__; the logged args are assumed
                        # to match the constructor signature.
                        obj = cls_obj(*real_args, **kwargs)
                        if object_registry is not None:
                            object_registry[registry_key] = obj
                        # __init__ returns None; the instantiation itself is
                        # the effect we wanted to reproduce.
                        execution_result = None
                        result_dict["execution_result"] = execution_result
                        if compare_outputs:
                            # Nothing to compare for __init__ — trivial pass.
                            result_dict["comparison_match"] = True
                        return result_dict
                    except Exception as e:
                        _logger.error(f"Failed to instantiate {class_name}: {e}")
                        result_dict["execution_error"] = str(e)
                        return result_dict
            else:
                # Instance-method call: look the object up in the registry.
                if object_registry is not None and registry_key in object_registry:
                    obj = object_registry[registry_key]
                    method_name = func_name.split(".")[-1]
                    if hasattr(obj, method_name):
                        func = getattr(obj, method_name)
                        # Drop the recorded 'self' placeholder.
                        args = args[1:] if len(args) > 0 else []
                    else:
                        _logger.warning(f"Object {obj} has no method {method_name}")
                else:
                    _logger.warning(
                        f"Object (PID: {process_id}, ID: {self_id}) not found in registry."
                    )

        if func is None:
            func = _resolve_function(module_name, func_name)

        if func:
            try:
                _logger.info(f"Executing {module_name}.{func_name}...")
                execution_result = func(*args, **kwargs)
                result_dict["execution_result"] = execution_result

                if compare_outputs:
                    # Flatten the result into a dict mirroring the dump keys.
                    actual_outputs = {}
                    if isinstance(execution_result, torch.Tensor):
                        actual_outputs["result"] = execution_result
                    elif isinstance(execution_result, (tuple, list)):
                        for i, item in enumerate(execution_result):
                            if isinstance(item, torch.Tensor):
                                actual_outputs[f"result_{i}"] = item
                    elif isinstance(execution_result, dict):
                        # Unlikely for FlashInfer, but handle dict results too.
                        actual_outputs = execution_result

                    match = True
                    if expected_outputs:
                        match = _compare_results(actual_outputs, expected_outputs)

                    result_dict["comparison_match"] = match
                    if match:
                        _logger.info("Replay comparison passed!")
                    else:
                        _logger.warning("Replay comparison FAILED.")

            except Exception as e:
                _logger.error(f"Execution failed: {e}")
                import traceback

                _logger.error(traceback.format_exc())
                result_dict["execution_error"] = str(e)
        else:
            _logger.warning(
                f"Skipping execution: could not resolve {module_name}.{func_name}"
            )
    elif not compare_outputs:
        _logger.warning(
            "Automatic function resolution disabled. "
            "Pass run=True to execute, or manually call function."
        )

    return result_dict
method {method_name}") + else: + _logger.warning( + f"Object (PID: {process_id}, ID: {self_id}) not found in registry." + ) + + if func is None: + func = _resolve_function(module_name, func_name) + + if func: + try: + _logger.info(f"Executing {module_name}.{func_name}...") + execution_result = func(*args, **kwargs) + result_dict["execution_result"] = execution_result + + if compare_outputs: + # Flatten execution result to dict for comparison + actual_outputs = {} + if isinstance(execution_result, torch.Tensor): + actual_outputs["result"] = execution_result + elif isinstance(execution_result, (tuple, list)): + for i, item in enumerate(execution_result): + if isinstance(item, torch.Tensor): + actual_outputs[f"result_{i}"] = item + elif isinstance(execution_result, dict): + # If result is already a dict of tensors? Unlikely for FlashInfer but possible + actual_outputs = execution_result + + # Compare tensors + match = True + if expected_outputs: + match = _compare_results(actual_outputs, expected_outputs) + + result_dict["comparison_match"] = match + if match: + _logger.info("Replay comparison passed!") + else: + _logger.warning("Replay comparison FAILED.") + + except Exception as e: + _logger.error(f"Execution failed: {e}") + import traceback + + _logger.error(traceback.format_exc()) + result_dict["execution_error"] = str(e) + else: + _logger.warning( + f"Skipping execution: could not resolve {module_name}.{func_name}" + ) + elif not compare_outputs: + _logger.warning( + "Automatic function resolution disabled. " + "Pass run=True to execute, or manually call function." + ) + + return result_dict + + +def replay_sequence(root_dir: str, device: str = "cuda") -> list: + """ + Replay a sequence of API calls from a root dump directory. + + This function iterates through all dump directories in the root directory, + sorted by timestamp/sequence number, and replays them in order. 
+ + Parameters + ---------- + root_dir : str + Path to the root directory containing dump subdirectories + device : str + Target device for execution (default: "cuda") + + Returns + ------- + list + List of results from replay_from_dump calls + """ + root_path = Path(root_dir) + if not root_path.exists(): + raise FileNotFoundError(f"Root dump directory not found: {root_dir}") + + # Find all subdirectories that look like dumps + # Pattern: YYYYMMDD_HHMMSS_milliseconds_pid_funcname_callXXXX + dump_dirs = [] + for item in root_path.iterdir(): + if item.is_dir() and (item / "metadata.jsonl").exists(): + dump_dirs.append(item) + + # Sort by directory name (which starts with timestamp) + dump_dirs.sort(key=lambda x: x.name) + + results = [] + total = len(dump_dirs) + _logger.info(f"Found {total} dumps to replay from {root_dir}") + + # Registry for stateful objects (mapped by (process_id, self_id) tuple) + # This composite key prevents collisions in multi-GPU/multi-process environments + object_registry: Dict[Tuple[int, int], Any] = {} + + for i, dump_dir in enumerate(dump_dirs): + _logger.info(f"[{i + 1}/{total}] Replaying {dump_dir.name}...") + try: + # We assume that for sequence replay, we want to EXECUTE the calls + # and assume outputs are not necessarily present or we just want to verify it runs. + # If outputs are present, we can compare. + res = replay_from_dump( + str(dump_dir), + compare_outputs=True, + device=device, + run=True, + object_registry=object_registry, + ) + # Add dump_dir to the result for CLI reporting + res["dump_dir"] = str(dump_dir) + results.append(res) + except Exception as e: + # Let's record error and continue. 
+ _logger.error(f"Failed to replay {dump_dir.name}: {e}") + results.append({"error": str(e), "dump_dir": str(dump_dir)}) + + return results + + def _log_system_info(): """Log system information once at module initialization.""" if _API_LOG_LEVEL == 0: @@ -160,6 +1122,7 @@ def _log_system_info(): # Log system information once at module load time (if logging is enabled) _log_system_info() +_warn_dump() def _format_value(value: Any, level: int, indent: int = 0) -> str: @@ -174,7 +1137,6 @@ def _format_value(value: Any, level: int, indent: int = 0) -> str: The logging level (1, 2, or 3) indent : int The indentation level for nested structures - Returns ------- str @@ -349,7 +1311,6 @@ def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: Positional arguments that were provided kwargs : dict Keyword arguments that were provided - Returns ------- dict @@ -358,7 +1319,6 @@ def _get_default_params(func: Callable, args: tuple, kwargs: dict) -> dict: try: sig = inspect.signature(func) default_params = {} - # Determine which parameters were NOT provided for i, (param_name, param) in enumerate(sig.parameters.items()): # Skip if parameter has no default @@ -416,7 +1376,6 @@ def _log_function_inputs( for i, arg in enumerate(args): lines.append(f" arg[{i}]:") lines.append(_format_value(arg, level, indent=2)) - # Keyword arguments if kwargs: lines.append("Keyword input arguments:") @@ -425,7 +1384,6 @@ def _log_function_inputs( lines.append(_format_value(value, level, indent=2)) else: lines.append("(No explicit arguments)") - # Log default parameters that were not explicitly provided default_params = _get_default_params(func, args, kwargs) if default_params: @@ -433,14 +1391,12 @@ def _log_function_inputs( for param_name, default_value in default_params.items(): lines.append(f" {param_name}= [DEFAULT]") lines.append(_format_value(default_value, level, indent=2)) - _logger.debug("\n".join(lines)) def _log_function_outputs(func_name: str, result: Any, level: 
int) -> None: """ Log function outputs AFTER successful execution. - Parameters ---------- func_name : str @@ -465,6 +1421,9 @@ def flashinfer_api(func: Callable = None) -> Callable: """ Decorator to FlashInfer's APIs. + .. warning:: + This API logging feature is experimental and may change in future versions. + Currently logs input and output values of the function using Python's logging library. This decorator integrates with Python's standard logging infrastructure while maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0). @@ -478,6 +1437,7 @@ def flashinfer_api(func: Callable = None) -> Callable: - 1: Log function name only (logged BEFORE execution - crash-safe) - 3: Log function name + inputs/outputs with metadata (inputs logged BEFORE execution - crash-safe) - 5: Log function name + inputs/outputs with metadata + tensor statistics (inputs logged BEFORE execution - crash-safe) + - 10: Level 5 logging + dump metadata and input/output tensors to disk for reproducibility (preserves stride/contiguity) FLASHINFER_LOGDEST : str (default: "stdout") - "stdout": Log to standard output @@ -485,6 +1445,26 @@ def flashinfer_api(func: Callable = None) -> Callable: - : Log to specified file path - Use %i in path for process ID substitution (e.g., "log_%i.txt" -> "log_12345.txt") + Level 10 Tensor Dumping (additional variables): + FLASHINFER_DUMP_DIR : str (default: "flashinfer_dumps") + - Directory where tensor dumps are saved + + FLASHINFER_DUMP_MAX_SIZE_GB : float (default: 20) + - Maximum total size of dumps in GB + + FLASHINFER_DUMP_MAX_COUNT : int (default: 1000) + - Maximum number of function call dumps + + FLASHINFER_DUMP_SAFETENSORS : int (default: 0) + - 0: Use torch.save format (preserves stride/contiguity) + - 1: Use safetensors format (no pickle, but loses stride info) + + FLASHINFER_DUMP_INCLUDE : str (default: "") + - Comma-separated list of patterns to include for dumping (fnmatch-style) + + FLASHINFER_DUMP_EXCLUDE : str (default: "") + - 
Comma-separated list of patterns to exclude for dumping (fnmatch-style) + Examples -------- Basic usage: @@ -522,6 +1502,7 @@ def decorator(f: Callable) -> Callable: def wrapper(*args, **kwargs): # Determine function name (with class name if applicable) func_name = f.__name__ + self_id = None if args and hasattr(args[0], "__class__"): try: class_name = args[0].__class__.__name__ @@ -529,9 +1510,22 @@ def wrapper(*args, **kwargs): "BatchMLAPagedAttentionWrapper" ]: func_name = f"{class_name}.{func_name}" + self_id = id(args[0]) except Exception: pass + # Level 10: Dump inputs BEFORE execution (crash-safe) + dump_dir = None + if _API_LOG_LEVEL >= 10: + try: + dump_dir = _dump_function_inputs( + f, func_name, args, kwargs, self_id=self_id + ) + if dump_dir: + _logger.debug(f"Inputs dumped to: {dump_dir}") + except Exception as e: + _logger.error(f"[DUMP ERROR (inputs) in {func_name}]: {e}") + # Log BEFORE execution (crash-safe for all levels!) try: if _API_LOG_LEVEL == 1: @@ -541,7 +1535,9 @@ def wrapper(*args, **kwargs): ) elif _API_LOG_LEVEL >= 3: # Level 3+: Log full inputs before execution (crash-safe) - _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL) + # For level 10, we use level 5 logging (includes statistics) + effective_level = min(_API_LOG_LEVEL, 5) # Cap at 5 for logging + _log_function_inputs(f, func_name, args, kwargs, effective_level) except Exception as e: _logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}") @@ -552,10 +1548,19 @@ def wrapper(*args, **kwargs): try: if _API_LOG_LEVEL >= 3: # Level 3+: Log outputs (inputs were already logged above) - _log_function_outputs(func_name, result, _API_LOG_LEVEL) + effective_level = min(_API_LOG_LEVEL, 5) + _log_function_outputs(func_name, result, effective_level) except Exception as e: _logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}") + # Level 10: Dump outputs AFTER successful execution (crash-safe) + if _API_LOG_LEVEL >= 10 and dump_dir: + try: + 
_dump_function_outputs(dump_dir, result) + _logger.info(f"Outputs dumped to: {dump_dir}") + except Exception as e: + _logger.error(f"[DUMP ERROR (outputs) in {func_name}]: {e}") + return result return wrapper diff --git a/tests/utils/test_logging_replay.py b/tests/utils/test_logging_replay.py new file mode 100644 index 0000000000..11cd83fb80 --- /dev/null +++ b/tests/utils/test_logging_replay.py @@ -0,0 +1,1067 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Tests for Level 10 API logging with actual FlashInfer APIs. + +This test suite verifies that Level 10 logging (tensor dumping) works correctly +with all decorated FlashInfer APIs. For each API, we: +1. Run the API with Level 10 logging enabled +2. Verify a dump was created +3. Load the dump using replay_from_dump +4. Run the API again with replayed inputs +5. 
def _clean_flashinfer_modules():
    """Drop every imported flashinfer module so env vars are re-read on import."""
    stale = [name for name in list(sys.modules) if name.startswith("flashinfer")]
    for name in stale:
        del sys.modules[name]


@pytest.fixture
def level10_environment(tmp_path):
    """Configure Level 10 dump logging for a test and restore the env after."""
    # Snapshot the variables we are going to touch (None == was unset).
    tracked_vars = (
        "FLASHINFER_LOGLEVEL",
        "FLASHINFER_LOGDEST",
        "FLASHINFER_DUMP_DIR",
        "FLASHINFER_DUMP_INCLUDE",
        "FLASHINFER_DUMP_EXCLUDE",
        "FLASHINFER_DUMP_SAFETENSORS",
    )
    original_env = {name: os.environ.get(name) for name in tracked_vars}

    # Point dumps into the per-test tmp dir at Level 10.
    dump_dir = tmp_path / "test_dumps"
    os.environ["FLASHINFER_LOGLEVEL"] = "10"
    os.environ["FLASHINFER_LOGDEST"] = "stdout"
    os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir)
    os.environ["FLASHINFER_DUMP_MAX_COUNT"] = "1000"
    os.environ["FLASHINFER_DUMP_MAX_SIZE_GB"] = "10"
    # No filters and no safetensors mode for these tests.
    for name in (
        "FLASHINFER_DUMP_INCLUDE",
        "FLASHINFER_DUMP_EXCLUDE",
        "FLASHINFER_DUMP_SAFETENSORS",
    ):
        os.environ.pop(name, None)

    # Force a reimport so flashinfer picks up the new environment variables.
    _clean_flashinfer_modules()

    yield dump_dir

    # Restore the original environment.
    for name, value in original_env.items():
        if value is not None:
            os.environ[name] = value
        else:
            os.environ.pop(name, None)

    # Force reimport
    _clean_flashinfer_modules()


def verify_and_replay_dump(
    dump_dir, original_output, func_to_replay, expected_dumps=1, dump_idx=0
):
    """Helper to verify dump creation and replay functionality."""
    from flashinfer.api_logging import replay_sequence

    replayed = replay_sequence(str(dump_dir), device="cuda")

    # Narrow to the target function's dumps when one is given.
    if func_to_replay:
        wanted = func_to_replay.__name__
        replayed = [
            entry
            for entry in replayed
            if entry.get("metadata", {}).get("function_name") == wanted
        ]

    assert len(replayed) >= expected_dumps, (
        f"Expected at least {expected_dumps} dumps, found {len(replayed)}"
    )

    # dump_idx == -1 selects the most recent matching dump.
    entry = replayed[-1] if dump_idx == -1 else replayed[dump_idx]

    # Replay must have matched the stored outputs ...
    assert entry["comparison_match"] is True

    # ... and the re-executed result must match the in-memory original.
    assert torch.allclose(
        original_output, entry["execution_result"], atol=1e-5, rtol=1e-3
    )


def test_replay_sequence(level10_environment):
    """Test replaying a sequence of calls."""
    from flashinfer import single_decode_with_kv_cache
    from flashinfer.api_logging import replay_sequence

    # Two decode calls with independently sampled inputs.
    kv_len = 128
    num_qo_heads, num_kv_heads = 32, 8
    head_dim = 128

    def _random_decode_call():
        q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        v = torch.randn(
            kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16
        )
        return single_decode_with_kv_cache(q, k, v)

    _ = _random_decode_call()  # Call 1
    _ = _random_decode_call()  # Call 2

    # Both dumps should replay successfully and match their stored outputs.
    results = replay_sequence(str(level10_environment), device="cuda")

    assert len(results) == 2
    for res in results:
        assert res["comparison_match"] is True
        assert "execution_result" in res
single_decode_with_kv_cache(q2, k2, v2) + + # Replay sequence + results = replay_sequence(str(level10_environment), device="cuda") + + assert len(results) == 2 + for res in results: + assert res["comparison_match"] is True + assert "execution_result" in res + + +def test_mm_fp8_replay(level10_environment): + """Test Level 10 logging with mm_fp8 API.""" + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] not in [10]: + pytest.skip("mm_fp8 is only supported on Blackwell GPUs.") + + from flashinfer import mm_fp8, autotune, prepare_low_latency_gemm_weights + + # Test configuration + m, n, k = 4, 2560, 8192 + input_dtype = torch.float8_e4m3fn + mat2_dtype = torch.float8_e4m3fn + + # Create inputs + torch.manual_seed(42) + input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) + + mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) + + global_scale = input_inv_s * mat2_inv_s + + _cache_permute_indices = {} + prepared_weights = prepare_low_latency_gemm_weights( + mat2_fp8, _cache_permute_indices + ) + + # Run API with Level 10 logging (without pre-allocated output) + with autotune(): + original_output = mm_fp8( + input_fp8, + prepared_weights, + global_scale, + ) + + verify_and_replay_dump(level10_environment, original_output, mm_fp8) + + +def test_bmm_fp8_replay(level10_environment): + """Test Level 10 logging with bmm_fp8 API.""" + + from flashinfer import bmm_fp8, autotune + + # Test configuration + b, m, n, k = 1, 48, 80, 64 + input_dtype = torch.float8_e4m3fn + mat2_dtype = torch.float8_e4m3fn + res_dtype = torch.bfloat16 + backend = "cudnn" + + # Create inputs + torch.manual_seed(42) + input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) + + mat2 = torch.randn([b, n, k], device="cuda", 
dtype=torch.bfloat16).transpose(-2, -1) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) + + # Run API with Level 10 logging + with autotune(): + original_output = bmm_fp8( + input_fp8, + mat2_fp8, + input_inv_s, + mat2_inv_s, + res_dtype, + backend=backend, + ) + + verify_and_replay_dump(level10_environment, original_output, bmm_fp8) + + +def test_mm_fp4_replay(level10_environment): + """Test Level 10 logging with mm_fp4 API.""" + compute_capability = get_compute_capability(torch.device("cuda")) + compute_capability_number = compute_capability[0] * 10 + compute_capability[1] + + from flashinfer import mm_fp4, autotune, nvfp4_quantize, SfLayout + + backend = "cudnn" + if not mm_fp4.is_backend_supported(backend, compute_capability_number): + pytest.skip( + f"Skipping test for {backend} because it is not supported on compute capability {compute_capability_number}." + ) + + # Test configuration + m, n, k = 48, 128, 128 + res_dtype = torch.bfloat16 + use_128x4_sf_layout = True + + # Create inputs + torch.manual_seed(42) + input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + + a_sf_layout = SfLayout.layout_128x4 if use_128x4_sf_layout else SfLayout.layout_8x4 + global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max() + global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max() + + input_fp4, input_inv_s = nvfp4_quantize( + input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False + ) + mat2_fp4, mat2_inv_s = nvfp4_quantize( + mat2, global_sf_mat2, sfLayout=SfLayout.layout_128x4, do_shuffle=False + ) + + alpha = 1.0 / (global_sf_input * global_sf_mat2) + block_size = 16 + + # Run API with Level 10 logging (without pre-allocated output) + with autotune(): + original_output = mm_fp4( + input_fp4, + mat2_fp4.T, + input_inv_s, + mat2_inv_s.T, + alpha, + res_dtype, + block_size=block_size, + use_8x4_sf_layout=not use_128x4_sf_layout, + backend=backend, + 
use_nvfp4=True, + skip_check=False, + ) + + verify_and_replay_dump(level10_environment, original_output, mm_fp4) + + +def test_single_prefill_with_kv_cache_replay(level10_environment): + """Test Level 10 logging with single_prefill_with_kv_cache API.""" + from flashinfer import single_prefill_with_kv_cache + + # Test configuration + qo_len, kv_len = 127, 501 + num_qo_heads, num_kv_heads = 4, 1 + head_dim = 128 + causal = False + + # Create inputs + torch.manual_seed(42) + q = torch.randn(qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + # Run API with Level 10 logging + o = single_prefill_with_kv_cache(q, k, v, causal=causal) + + original_output = o.clone() + + verify_and_replay_dump( + level10_environment, original_output, single_prefill_with_kv_cache + ) + + +def test_single_decode_with_kv_cache_replay(level10_environment): + """Test Level 10 logging with single_decode_with_kv_cache API.""" + from flashinfer import single_decode_with_kv_cache + + # Test configuration + kv_len = 1024 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Create inputs + torch.manual_seed(42) + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + + # Run API with Level 10 logging + o = single_decode_with_kv_cache(q, k, v) + + original_output = o.clone() + + verify_and_replay_dump( + level10_environment, original_output, single_decode_with_kv_cache + ) + + +def test_cli_replay(level10_environment): + """Test the CLI replay command.""" + from click.testing import CliRunner + from flashinfer.__main__ import cli + from flashinfer import single_decode_with_kv_cache + + # Create some data + kv_len = 
128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + v = torch.randn(kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16) + single_decode_with_kv_cache(q, k, v) + + runner = CliRunner() + result = runner.invoke(cli, ["replay", "--dir", str(level10_environment)]) + + assert result.exit_code == 0 + assert "Replaying session from" in result.output + assert "Passed" in result.output + assert "Summary: 1 passed" in result.output + + +# ============================================================================= +# Tests for FLASHINFER_DUMP_INCLUDE / FLASHINFER_DUMP_EXCLUDE filtering +# ============================================================================= + + +def test_dump_include_filter(tmp_path): + """Test that FLASHINFER_DUMP_INCLUDE only dumps matching functions.""" + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + } + + try: + # Set up environment with include filter for decode only + dump_dir = tmp_path / "test_dumps_include" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + os.environ["FLASHINFER_DUMP_INCLUDE"] = "*decode*" + if "FLASHINFER_DUMP_EXCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_EXCLUDE"] + + # Force reimport + _clean_flashinfer_modules() + + from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call 
decode (should be dumped) + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_decode_with_kv_cache(q, k, v) + + # Call prefill (should NOT be dumped due to include filter) + qo_len = 64 + q_prefill = torch.randn( + qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16 + ) + k_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill) + + # Check that only decode was dumped (filter for directories only, exclude session.jsonl) + dump_subdirs = ( + [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else [] + ) + assert len(dump_subdirs) == 1, f"Expected 1 dump, found {len(dump_subdirs)}" + + # Verify the dump is for decode + dump_name = dump_subdirs[0].name + assert "decode" in dump_name.lower(), f"Expected decode dump, got {dump_name}" + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + _clean_flashinfer_modules() + + +def test_dump_exclude_filter(tmp_path): + """Test that FLASHINFER_DUMP_EXCLUDE skips matching functions.""" + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + } + + try: + # Set up environment with exclude filter for prefill + dump_dir = tmp_path / 
"test_dumps_exclude" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + os.environ["FLASHINFER_DUMP_EXCLUDE"] = "*prefill*" + if "FLASHINFER_DUMP_INCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_INCLUDE"] + + # Force reimport + _clean_flashinfer_modules() + + from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call decode (should be dumped) + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_decode_with_kv_cache(q, k, v) + + # Call prefill (should NOT be dumped due to exclude filter) + qo_len = 64 + q_prefill = torch.randn( + qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16 + ) + k_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill) + + # Check that only decode was dumped (filter for directories only, exclude session.jsonl) + dump_subdirs = ( + [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else [] + ) + assert len(dump_subdirs) == 1, f"Expected 1 dump, found {len(dump_subdirs)}" + + # Verify the dump is for decode + dump_name = dump_subdirs[0].name + assert "decode" in dump_name.lower(), f"Expected decode dump, got {dump_name}" + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + _clean_flashinfer_modules() + + +def 
test_dump_include_and_exclude_combined(tmp_path): + """Test that FLASHINFER_DUMP_INCLUDE and FLASHINFER_DUMP_EXCLUDE work together.""" + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + } + + try: + # Set up environment: include all single_* APIs but exclude prefill + dump_dir = tmp_path / "test_dumps_combined" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + os.environ["FLASHINFER_DUMP_INCLUDE"] = "single_*" + os.environ["FLASHINFER_DUMP_EXCLUDE"] = "*prefill*" + + # Force reimport + _clean_flashinfer_modules() + + from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call decode (matches include, not excluded -> should be dumped) + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_decode_with_kv_cache(q, k, v) + + # Call prefill (matches include BUT also matches exclude -> NOT dumped) + qo_len = 64 + q_prefill = torch.randn( + qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16 + ) + k_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill) + + # Check that only decode was dumped (filter 
for directories only, exclude session.jsonl) + dump_subdirs = ( + [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else [] + ) + assert len(dump_subdirs) == 1, f"Expected 1 dump, found {len(dump_subdirs)}" + + # Verify the dump is for decode + dump_name = dump_subdirs[0].name + assert "decode" in dump_name.lower(), f"Expected decode dump, got {dump_name}" + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + _clean_flashinfer_modules() + + +def test_dump_include_no_match(tmp_path): + """Test that no dumps are created when include filter matches nothing.""" + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + } + + try: + # Set up environment with include filter that matches nothing + dump_dir = tmp_path / "test_dumps_nomatch" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + os.environ["FLASHINFER_DUMP_INCLUDE"] = "nonexistent_function_xyz" + if "FLASHINFER_DUMP_EXCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_EXCLUDE"] + + # Force reimport + _clean_flashinfer_modules() + + from flashinfer import single_decode_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call decode (should NOT be dumped - doesn't match include filter) + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + kv_len, num_kv_heads, 
head_dim, device="cuda", dtype=torch.float16 + ) + single_decode_with_kv_cache(q, k, v) + + # Check that no dumps were created (filter for directories only, exclude session.jsonl) + dump_subdirs = ( + [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else [] + ) + assert len(dump_subdirs) == 0, f"Expected 0 dumps, found {len(dump_subdirs)}" + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + _clean_flashinfer_modules() + + +def test_dump_multiple_include_patterns(tmp_path): + """Test that multiple comma-separated include patterns work.""" + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + } + + try: + # Set up environment with multiple include patterns + dump_dir = tmp_path / "test_dumps_multi" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + os.environ["FLASHINFER_DUMP_INCLUDE"] = "*decode*, *prefill*" + if "FLASHINFER_DUMP_EXCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_EXCLUDE"] + + # Force reimport + _clean_flashinfer_modules() + + from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call decode (should be dumped) + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", 
dtype=torch.float16 + ) + single_decode_with_kv_cache(q, k, v) + + # Call prefill (should also be dumped) + qo_len = 64 + q_prefill = torch.randn( + qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16 + ) + k_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v_prefill = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_prefill_with_kv_cache(q_prefill, k_prefill, v_prefill) + + # Check that both were dumped (filter for directories only, exclude session.jsonl) + dump_subdirs = ( + [d for d in dump_dir.iterdir() if d.is_dir()] if dump_dir.exists() else [] + ) + assert len(dump_subdirs) == 2, f"Expected 2 dumps, found {len(dump_subdirs)}" + + # Verify we have one decode and one prefill + dump_names = [d.name.lower() for d in dump_subdirs] + has_decode = any("decode" in name for name in dump_names) + has_prefill = any("prefill" in name for name in dump_names) + assert has_decode, "Expected decode dump not found" + assert has_prefill, "Expected prefill dump not found" + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + _clean_flashinfer_modules() + + +def test_jsonl_format(tmp_path): + """Test that JSONL format is used correctly for metadata files.""" + import json + + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + } + + try: + # Set up environment + dump_dir = tmp_path / "test_dumps_jsonl" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + 
os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + if "FLASHINFER_DUMP_INCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_INCLUDE"] + if "FLASHINFER_DUMP_EXCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_EXCLUDE"] + + # Force reimport + _clean_flashinfer_modules() + + from flashinfer import single_decode_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call decode + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_decode_with_kv_cache(q, k, v) + + # Verify dump directory was created + assert dump_dir.exists(), "Dump directory was not created" + + # Find the dump subdirectory + dump_subdirs = [d for d in dump_dir.iterdir() if d.is_dir()] + assert len(dump_subdirs) == 1, ( + f"Expected 1 dump subdir, found {len(dump_subdirs)}" + ) + dump_subdir = dump_subdirs[0] + + # Verify per-dump metadata.jsonl exists (not metadata.json) + metadata_jsonl_path = dump_subdir / "metadata.jsonl" + metadata_json_path = dump_subdir / "metadata.json" + assert metadata_jsonl_path.exists(), "metadata.jsonl was not created" + assert not metadata_json_path.exists(), "metadata.json should not exist" + + # Verify per-dump metadata.jsonl has 2 lines (inputs_saved + completed) + with open(metadata_jsonl_path, "r") as f: + lines = [line.strip() for line in f if line.strip()] + assert len(lines) == 2, ( + f"Expected 2 lines in metadata.jsonl, found {len(lines)}" + ) + + # Verify first line has execution_status="inputs_saved" + first_record = json.loads(lines[0]) + assert first_record["execution_status"] == "inputs_saved" + + # Verify second line has execution_status="completed" + second_record = json.loads(lines[1]) + assert second_record["execution_status"] == "completed" + assert "output_metadata" 
in second_record + assert second_record["tensor_info"]["output_tensor_keys"] + + # Verify central session.jsonl exists + session_jsonl_path = dump_dir / "session.jsonl" + assert session_jsonl_path.exists(), "session.jsonl was not created" + + # Verify session.jsonl has 2 lines (inputs_saved + completed for this call) + with open(session_jsonl_path, "r") as f: + session_lines = [line.strip() for line in f if line.strip()] + assert len(session_lines) == 2, ( + f"Expected 2 lines in session.jsonl, found {len(session_lines)}" + ) + + # Verify session.jsonl records match per-dump records + session_first = json.loads(session_lines[0]) + session_second = json.loads(session_lines[1]) + assert session_first["execution_status"] == "inputs_saved" + assert session_second["execution_status"] == "completed" + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + _clean_flashinfer_modules() + + +def test_session_jsonl_multiple_calls(tmp_path): + """Test that session.jsonl accumulates records from multiple API calls.""" + import json + + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + } + + try: + # Set up environment + dump_dir = tmp_path / "test_dumps_session" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + if "FLASHINFER_DUMP_INCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_INCLUDE"] + if "FLASHINFER_DUMP_EXCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_EXCLUDE"] + + # Force reimport + _clean_flashinfer_modules() + + 
from flashinfer import single_decode_with_kv_cache, single_prefill_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call 1: decode + q1 = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k1 = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v1 = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_decode_with_kv_cache(q1, k1, v1) + + # Call 2: prefill + qo_len = 64 + q2 = torch.randn( + qo_len, num_qo_heads, head_dim, device="cuda", dtype=torch.float16 + ) + k2 = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v2 = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_prefill_with_kv_cache(q2, k2, v2) + + # Verify session.jsonl has 4 lines (2 per call: inputs_saved + completed) + session_jsonl_path = dump_dir / "session.jsonl" + assert session_jsonl_path.exists(), "session.jsonl was not created" + + with open(session_jsonl_path, "r") as f: + lines = [line.strip() for line in f if line.strip()] + assert len(lines) == 4, f"Expected 4 lines in session.jsonl, found {len(lines)}" + + # Verify the structure: inputs_saved, completed, inputs_saved, completed + records = [json.loads(line) for line in lines] + assert records[0]["execution_status"] == "inputs_saved" + assert records[1]["execution_status"] == "completed" + assert records[2]["execution_status"] == "inputs_saved" + assert records[3]["execution_status"] == "completed" + + # Verify we have both function names + func_names = {r["function_name"] for r in records} + assert "single_decode_with_kv_cache" in func_names + assert "single_prefill_with_kv_cache" in func_names + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + 
_clean_flashinfer_modules() + + +def test_safetensors_format(tmp_path): + """Test that FLASHINFER_DUMP_SAFETENSORS uses safetensors format.""" + pytest.importorskip("safetensors", reason="safetensors package not installed") + + import json + + # Store original environment + original_env = { + "FLASHINFER_LOGLEVEL": os.environ.get("FLASHINFER_LOGLEVEL"), + "FLASHINFER_LOGDEST": os.environ.get("FLASHINFER_LOGDEST"), + "FLASHINFER_DUMP_DIR": os.environ.get("FLASHINFER_DUMP_DIR"), + "FLASHINFER_DUMP_INCLUDE": os.environ.get("FLASHINFER_DUMP_INCLUDE"), + "FLASHINFER_DUMP_EXCLUDE": os.environ.get("FLASHINFER_DUMP_EXCLUDE"), + "FLASHINFER_DUMP_SAFETENSORS": os.environ.get("FLASHINFER_DUMP_SAFETENSORS"), + } + + try: + # Set up environment with safetensors enabled + dump_dir = tmp_path / "test_dumps_safetensors" + os.environ["FLASHINFER_LOGLEVEL"] = "10" + os.environ["FLASHINFER_LOGDEST"] = "stdout" + os.environ["FLASHINFER_DUMP_DIR"] = str(dump_dir) + os.environ["FLASHINFER_DUMP_SAFETENSORS"] = "1" + if "FLASHINFER_DUMP_INCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_INCLUDE"] + if "FLASHINFER_DUMP_EXCLUDE" in os.environ: + del os.environ["FLASHINFER_DUMP_EXCLUDE"] + + # Force reimport + _clean_flashinfer_modules() + + from flashinfer import single_decode_with_kv_cache + + # Test configuration + kv_len = 128 + num_qo_heads, num_kv_heads = 32, 8 + head_dim = 128 + + # Call decode + q = torch.randn(num_qo_heads, head_dim, device="cuda", dtype=torch.float16) + k = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + v = torch.randn( + kv_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16 + ) + single_decode_with_kv_cache(q, k, v) + + # Verify dump directory was created + assert dump_dir.exists(), "Dump directory was not created" + + # Find the dump subdirectory + dump_subdirs = [d for d in dump_dir.iterdir() if d.is_dir()] + assert len(dump_subdirs) == 1, ( + f"Expected 1 dump subdir, found {len(dump_subdirs)}" + ) + 
dump_subdir = dump_subdirs[0] + + # Verify safetensors files exist (not .pt files) + inputs_safetensors = dump_subdir / "inputs.safetensors" + outputs_safetensors = dump_subdir / "outputs.safetensors" + inputs_pt = dump_subdir / "inputs.pt" + outputs_pt = dump_subdir / "outputs.pt" + + assert inputs_safetensors.exists(), "inputs.safetensors was not created" + assert outputs_safetensors.exists(), "outputs.safetensors was not created" + assert not inputs_pt.exists(), "inputs.pt should not exist in safetensors mode" + assert not outputs_pt.exists(), ( + "outputs.pt should not exist in safetensors mode" + ) + + # Verify metadata has tensor_format field + with open(dump_subdir / "metadata.jsonl", "r") as f: + lines = [line.strip() for line in f if line.strip()] + last_record = json.loads(lines[-1]) + assert last_record.get("tensor_format") == "safetensors", ( + "tensor_format should be 'safetensors'" + ) + + # Verify replay works with safetensors format + _clean_flashinfer_modules() + + from flashinfer.api_logging import replay_sequence + + results = replay_sequence(str(dump_dir), device="cuda") + assert len(results) == 1 + assert results[0]["comparison_match"] is True + + finally: + # Restore environment + for key, value in original_env.items(): + if value is not None: + os.environ[key] = value + elif key in os.environ: + del os.environ[key] + _clean_flashinfer_modules() + + +def test_safetensors_replay_auto_detection(tmp_path): + """Test that replay auto-detects safetensors format.""" + pytest.importorskip("safetensors", reason="safetensors package not installed") + + from safetensors.torch import save_file + + # Create a mock dump with safetensors files + dump_subdir = tmp_path / "mock_dump" + dump_subdir.mkdir() + + # Create mock input tensors + input_tensors = { + "arg_0": torch.randn(32, 128, dtype=torch.float16), + "arg_1": torch.randn(128, 8, 128, dtype=torch.float16), + } + save_file(input_tensors, str(dump_subdir / "inputs.safetensors")) + + # Create mock 
metadata.jsonl + import json + + metadata = { + "function_name": "test_function", + "module": "test_module", + "call_sequence": 1, + "timestamp": "20260108_120000_000", + "process_id": os.getpid(), + "input_metadata": {}, + "output_metadata": {}, + "tensor_info": { + "input_tensor_keys": ["arg_0", "arg_1"], + "output_tensor_keys": [], + "input_size_bytes": 0, + "input_size_mb": 0, + }, + "tensor_format": "safetensors", + "execution_status": "inputs_saved", + } + with open(dump_subdir / "metadata.jsonl", "w") as f: + f.write(json.dumps(metadata) + "\n") + + # Force reimport to get clean state + _clean_flashinfer_modules() + + from flashinfer.api_logging import replay_from_dump + + # Replay should auto-detect safetensors format + result = replay_from_dump( + str(dump_subdir), compare_outputs=False, device="cpu", run=False + ) + + # Verify tensors were loaded + assert len(result["args"]) == 2 + assert result["args"][0].shape == (32, 128) + assert result["args"][1].shape == (128, 8, 128) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])