diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py
index 734d6bae28..b46fa97162 100644
--- a/flashinfer/api_logging.py
+++ b/flashinfer/api_logging.py
@@ -20,7 +20,8 @@
 import logging
 import os
 import sys
-from typing import Any, Callable
+import uuid
+from typing import Any, Callable, Dict, Optional
 import contextlib
 
 import torch
@@ -42,6 +43,23 @@ def _substitute_process_id(path: str) -> str:
 _API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))
 _API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout"))
 
+# Bench logging environment variables
+_BENCH_LOG_ENABLED = os.environ.get("FLASHINFER_BENCH_LOG", "0").lower() in (
+    "1",
+    "true",
+    "yes",
+)
+# Default bench log directory is under the flashinfer package directory
+_FLASHINFER_DIR = os.path.dirname(os.path.abspath(__file__))
+_DEFAULT_BENCH_LOG_DIR = os.path.join(_FLASHINFER_DIR, ".flashinfer_bench_cache")
+_BENCH_LOG_DIR = _substitute_process_id(
    os.environ.get("FLASHINFER_BENCH_LOG_DIR", _DEFAULT_BENCH_LOG_DIR)
+)
+# Maximum file size for bench dump (default: 10GB)
+_BENCH_MAX_FILE_SIZE = int(
+    os.environ.get("MAX_FLASHINFER_BENCH_DUMP_FILE_SIZE", str(10 * 1024 * 1024 * 1024))
+)
+
 # Create logger using Python's logging library
 _logger = logging.getLogger("flashinfer.api")
 
@@ -461,6 +479,155 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None:
     _logger.debug("\n".join(lines))
 
 
+# =============================================================================
+# Bench Logging Functions (for flashinfer-bench workload dumping)
+# =============================================================================
+
+
+def _sanitize_api_name(func_name: str) -> str:
+    """
+    Sanitize API name to create a valid directory name.
+
+    Replaces dots and other special characters with underscores.
+    """
+    # Replace dots (from class.method) and other special chars with underscores
+    sanitized = func_name.replace(".", "_").replace("-", "_")
+    # Remove any other non-alphanumeric characters except underscore
+    sanitized = "".join(c if c.isalnum() or c == "_" else "_" for c in sanitized)
+    return sanitized
+
+
+def _dump_workload(
+    func: Callable, func_name: str, args: tuple, kwargs: dict
+) -> Optional[str]:
+    """
+    Dump all function arguments to a safetensors file.
+
+    Saves all tensor and scalar arguments to a safetensors file named by UUID,
+    organized under a folder named by the function.
+
+    Directory structure:
+        FLASHINFER_BENCH_LOG_DIR/
+            function_name/
+                <uuid>.safetensors
+                <uuid>.safetensors
+                ...
+
+    Parameters
+    ----------
+    func : Callable
+        The function being called
+    func_name : str
+        Name of the function (may include class name)
+    args : tuple
+        Positional arguments
+    kwargs : dict
+        Keyword arguments
+
+    Returns
+    -------
+    Optional[str]
+        Path to the safetensors file, or None if dumping failed
+    """
+    try:
+        from safetensors.torch import save_file
+    except ImportError:
+        _logger.warning(
+            "[BENCH LOG] safetensors not installed, skipping workload dump. "
+            "Install with: pip install safetensors"
+        )
+        return None
+
+    try:
+        # Create function-specific directory
+        sanitized_name = _sanitize_api_name(func_name)
+        func_dir = os.path.join(_BENCH_LOG_DIR, sanitized_name)
+        os.makedirs(func_dir, exist_ok=True)
+
+        # Get parameter names from function signature
+        try:
+            sig = inspect.signature(func)
+            param_names = list(sig.parameters.keys())
+        except Exception:
+            param_names = [f"arg{i}" for i in range(len(args))]
+
+        # Collect all tensors to save
+        tensors_to_save: Dict[str, torch.Tensor] = {}
+
+        # Process positional arguments
+        for i, arg in enumerate(args):
+            param_name = param_names[i] if i < len(param_names) else f"arg{i}"
+            _collect_tensors(arg, param_name, tensors_to_save)
+
+        # Process keyword arguments
+        for key, value in kwargs.items():
+            _collect_tensors(value, key, tensors_to_save)
+
+        if not tensors_to_save:
+            _logger.debug(f"[BENCH LOG] No tensors to dump for {func_name}")
+            return None
+
+        # Estimate file size (sum of tensor bytes)
+        estimated_size = sum(
+            tensor.numel() * tensor.element_size()
+            for tensor in tensors_to_save.values()
+            if tensor is not None
+        )
+
+        # Skip if estimated size exceeds limit
+        if estimated_size > _BENCH_MAX_FILE_SIZE:
+            _logger.debug(
+                f"[BENCH LOG] Skipping dump for {func_name}: estimated size "
+                f"{estimated_size / (1024**3):.2f}GB exceeds limit "
+                f"{_BENCH_MAX_FILE_SIZE / (1024**3):.2f}GB"
+            )
+            return None
+
+        # Move tensors to CPU for saving
+        cpu_tensors = {
+            key: tensor.detach().cpu()
+            for key, tensor in tensors_to_save.items()
+            if tensor is not None
+        }
+
+        # Generate UUID and save
+        file_uuid = uuid.uuid4().hex
+        file_path = os.path.join(func_dir, f"{file_uuid}.safetensors")
+        save_file(cpu_tensors, file_path)
+
+        return file_path
+
+    except Exception as e:
+        _logger.warning(f"[BENCH LOG] Failed to dump workload for {func_name}: {e}")
+        return None
+
+
+def _collect_tensors(value: Any, name: str, tensors: Dict[str, torch.Tensor]) -> None:
+    """
+    Recursively collect tensors from a value into the tensors dict.
+
+    Parameters
+    ----------
+    value : Any
+        The value to extract tensors from
+    name : str
+        The parameter name (used as key prefix)
+    tensors : Dict[str, torch.Tensor]
+        Dictionary to collect tensors into (modified in place)
+    """
+    if value is None:
+        return
+
+    if isinstance(value, torch.Tensor):
+        tensors[name] = value
+    elif isinstance(value, (list, tuple)):
+        for i, item in enumerate(value):
+            _collect_tensors(item, f"{name}_{i}", tensors)
+    elif isinstance(value, dict):
+        for key, item in value.items():
+            _collect_tensors(item, f"{name}_{key}", tensors)
+
+
 def flashinfer_api(func: Callable = None) -> Callable:
     """
     Decorator to FlashInfer's APIs.
@@ -469,6 +636,8 @@ def flashinfer_api(func: Callable = None) -> Callable:
     This decorator integrates with Python's standard logging infrastructure while
     maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0).
 
+    Additionally supports workload dumping for benchmarking when FLASHINFER_BENCH_LOG is enabled.
+
     NOTE/TODO: Not all FlashInfer APIs are decorated with this decorator yet. This is a work in progress.
 
     Environment Variables
@@ -485,6 +654,22 @@ def flashinfer_api(func: Callable = None) -> Callable:
         - <path>: Log to specified file path
         - Use %i in path for process ID substitution (e.g., "log_%i.txt" -> "log_12345.txt")
 
+    FLASHINFER_BENCH_LOG : str (default: "0")
+        - "0", "false", "no": Disable workload dumping (zero overhead)
+        - "1", "true", "yes": Enable workload dumping
+        When enabled, dumps all function tensor arguments to safetensors files.
+        Files are organized by function name and named by UUID:
+        FLASHINFER_BENCH_LOG_DIR/function_name/<uuid>.safetensors
+
+    FLASHINFER_BENCH_LOG_DIR : str (default: "<flashinfer_dir>/.flashinfer_bench_cache")
+        Directory where workload safetensors files are saved.
+        By default, the cache is placed under the flashinfer package directory.
+        - Use %i in path for process ID substitution (e.g., "bench_%i/" -> "bench_12345/")
+
+    MAX_FLASHINFER_BENCH_DUMP_FILE_SIZE : int (default: 10737418240, i.e., 10GB)
+        Maximum estimated file size in bytes for workload dumps.
+        Workloads exceeding this size will be skipped.
+
     Examples
     --------
     Basic usage:
@@ -493,11 +678,17 @@ def flashinfer_api(func: Callable = None) -> Callable:
     >>> @flashinfer_api
     ... def my_function(x, y):
     ...     return x + y
 
+    Enable workload dumping:
+
+    >>> # Set environment variables before importing:
+    >>> # export FLASHINFER_BENCH_LOG=1
+    >>> # export FLASHINFER_BENCH_LOG_DIR=/path/to/workloads
+
     Notes
     -----
     - Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS]
       (e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information")
-    - When FLASHINFER_LOGLEVEL=0, the decorator has truly zero overhead
+    - When FLASHINFER_LOGLEVEL=0 and FLASHINFER_BENCH_LOG=0, the decorator has truly zero overhead
       as it returns the original function unchanged.
@@ -510,9 +701,13 @@ def flashinfer_api(func: Callable = None) -> Callable:
       The message "[statistics skipped: CUDA graph capture in progress]" will be logged.
     - The %i pattern is automatically replaced with the process ID for multi-process environments.
     - The logger does not propagate to the root logger to avoid duplicate logs.
+    - **Workload Dumping**: When FLASHINFER_BENCH_LOG=1, all tensor arguments are saved to
+      safetensors files organized by function name (e.g., "func_name/<uuid>.safetensors").
+      Each API call creates a new safetensors file containing all input tensors.
+      Requires the safetensors package: pip install safetensors
     """
-    # If logging is disabled, return original function with zero overhead
-    if _API_LOG_LEVEL == 0:
+    # If both logging and bench logging are disabled, return original function with zero overhead
+    if _API_LOG_LEVEL == 0 and not _BENCH_LOG_ENABLED:
         if func is None:
             return lambda f: f
         return func
@@ -533,28 +728,38 @@ def wrapper(*args, **kwargs):
                 pass
 
         # Log BEFORE execution (crash-safe for all levels!)
-        try:
-            if _API_LOG_LEVEL == 1:
-                # Level 1: Just log function name before execution (crash-safe)
-                _logger.debug(
-                    f"{_get_timestamp()} FlashInfer API Call: {func_name}"
+        if _API_LOG_LEVEL > 0:
+            try:
+                if _API_LOG_LEVEL == 1:
+                    # Level 1: Just log function name before execution (crash-safe)
+                    _logger.debug(
+                        f"{_get_timestamp()} FlashInfer API Call: {func_name}"
+                    )
+                elif _API_LOG_LEVEL >= 3:
+                    # Level 3+: Log full inputs before execution (crash-safe)
+                    _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL)
+            except Exception as e:
+                _logger.error(
+                    f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}"
                 )
-            elif _API_LOG_LEVEL >= 3:
-                # Level 3+: Log full inputs before execution (crash-safe)
-                _log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL)
-        except Exception as e:
-            _logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}")
+
+        # Dump workload BEFORE execution (crash-safe) if bench logging is enabled
+        if _BENCH_LOG_ENABLED:
+            try:
+                _dump_workload(f, func_name, args, kwargs)
+            except Exception as e:
+                _logger.warning(f"[BENCH LOG ERROR in {func_name}]: {e}")
 
         # Call the original function (may crash here with CUDA errors)
         result = f(*args, **kwargs)
 
         # Log outputs AFTER successful execution (level 3+ only)
-        try:
-            if _API_LOG_LEVEL >= 3:
+        if _API_LOG_LEVEL >= 3:
+            try:
                 # Level 3+: Log outputs (inputs were already logged above)
                 _log_function_outputs(func_name, result, _API_LOG_LEVEL)
-        except Exception as e:
-            _logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}")
+            except Exception as e:
+                _logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}")
 
         return result
diff --git a/tests/log/test_log_workload.py b/tests/log/test_log_workload.py
new file mode 100644
index 0000000000..7d7da84e13
--- /dev/null
+++ b/tests/log/test_log_workload.py
@@ -0,0 +1,365 @@
+"""
+Test workload dumping functionality for FlashInfer API logging.
+
+This test verifies that when FLASHINFER_BENCH_LOG is enabled, the decorator
+correctly dumps tensor arguments to safetensors files.
+"""
+
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+
+# Set environment variables BEFORE importing flashinfer
+# This ensures the decorator picks up the settings at module load time
+_TEST_BENCH_DIR = tempfile.mkdtemp(prefix="flashinfer_bench_test_")
+os.environ["FLASHINFER_BENCH_LOG"] = "1"
+os.environ["FLASHINFER_BENCH_LOG_DIR"] = _TEST_BENCH_DIR
+
+import flashinfer
+
+
+def generate_random_inputs(
+    batch_size: int,
+    max_seq_len: int,
+    num_attention_heads: int = 32,
+    num_key_value_heads: int = 4,
+    head_dim: int = 128,
+    page_size: int = 1,
+    device: str = "cuda",
+):
+    """Generate random inputs for testing batch decode attention."""
+    # Generate random sequence lengths for each batch
+    seq_lens = torch.randint(
+        1, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device
+    )
+
+    # Calculate total pages needed (page_size = 1 means num_pages = total_tokens)
+    total_pages_needed = seq_lens.sum().item()
+
+    # Generate kv_indptr based on sequence lengths
+    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
+
+    # Generate kv_indices (page indices for each sequence)
+    kv_indices = torch.arange(total_pages_needed, dtype=torch.int32, device=device)
+
+    # For page_size=1, last page always has 1 token
+    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device=device)
+
+    # Generate query tensor
+    q = torch.randn(
+        batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device
+    )
+
+    # Generate K and V caches with extra pages
+    num_pages = total_pages_needed + 100
+    k_cache = torch.randn(
+        num_pages,
+        page_size,
+        num_key_value_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        device=device,
+    )
+    v_cache = torch.randn(
+        num_pages,
+        page_size,
+        num_key_value_heads,
+        head_dim,
+        dtype=torch.bfloat16,
+        device=device,
+    )
+
+    # Attention scale
+    sm_scale = 1.0 / np.sqrt(head_dim)
+
+    return {
+        "q": q,
+        "k_cache": k_cache,
+        "v_cache": v_cache,
+        "kv_indptr": kv_indptr,
+        "kv_indices": kv_indices,
+        "kv_last_page_len": kv_last_page_len,
+        "sm_scale": sm_scale,
+        "seq_lens": seq_lens,
+    }
+
+
+def get_dumped_safetensors_files(func_name: str) -> list:
+    """Get list of safetensors files for a given function."""
+    func_dir = os.path.join(_TEST_BENCH_DIR, func_name)
+    if not os.path.exists(func_dir):
+        return []
+    return [
+        os.path.join(func_dir, f)
+        for f in os.listdir(func_dir)
+        if f.endswith(".safetensors")
+    ]
+
+
+def load_safetensors(file_path: str) -> dict:
+    """Load tensors from a safetensors file."""
+    from safetensors.torch import load_file
+
+    return load_file(file_path)
+
+
+@pytest.fixture(scope="module", autouse=True)
+def cleanup_bench_dir():
+    """Clean up bench directory after all tests."""
+    yield
+    # Cleanup after tests
+    if os.path.exists(_TEST_BENCH_DIR):
+        shutil.rmtree(_TEST_BENCH_DIR)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_workload_dump_batch_decode():
+    """Test that batch decode attention dumps workloads correctly."""
+    # Skip if safetensors not installed
+    pytest.importorskip("safetensors")
+
+    device = "cuda"
+    batch_size = 4
+    max_seq_len = 32
+
+    # Constants
+    num_attention_heads = 32
+    num_key_value_heads = 4
+    head_dim = 128
+    page_size = 1
+
+    # Generate inputs
+    inputs = generate_random_inputs(
+        batch_size,
+        max_seq_len,
+        num_attention_heads,
+        num_key_value_heads,
+        head_dim,
+        page_size,
+        device,
+    )
+
+    # Setup FlashInfer wrapper
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
+
+    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer,
+        kv_layout="NHD",
+    )
+
+    # Plan the attention computation
+    decode_wrapper.plan(
+        indptr=inputs["kv_indptr"],
+        indices=inputs["kv_indices"],
+        last_page_len=inputs["kv_last_page_len"],
+        num_qo_heads=num_attention_heads,
+        num_kv_heads=num_key_value_heads,
+        head_dim=head_dim,
+        page_size=page_size,
+        pos_encoding_mode="NONE",
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+        sm_scale=inputs["sm_scale"],
+    )
+
+    # Store original tensors for comparison
+    original_q = inputs["q"].clone()
+
+    # Run FlashInfer - this should trigger workload dump
+    output, lse = decode_wrapper.run(
+        inputs["q"],
+        (inputs["k_cache"], inputs["v_cache"]),
+        return_lse=True,
+    )
+
+    # Verify output is valid
+    assert output is not None
+    assert output.shape == (batch_size, num_attention_heads, head_dim)
+
+    # Check that safetensors files were created
+    # The run method is decorated, so look for BatchDecodeWithPagedKVCacheWrapper_run
+    func_name = "BatchDecodeWithPagedKVCacheWrapper_run"
+    safetensors_files = get_dumped_safetensors_files(func_name)
+
+    assert len(safetensors_files) > 0, (
+        f"No safetensors files found in {_TEST_BENCH_DIR}/{func_name}. "
+        f"Directory contents: {os.listdir(_TEST_BENCH_DIR) if os.path.exists(_TEST_BENCH_DIR) else 'N/A'}"
+    )
+
+    # Load the most recent safetensors file
+    latest_file = max(safetensors_files, key=os.path.getmtime)
+    loaded_tensors = load_safetensors(latest_file)
+
+    # Verify that the key tensors were dumped
+    # The parameter names should match the function signature
+    assert "q" in loaded_tensors, (
+        f"'q' not found in dumped tensors: {list(loaded_tensors.keys())}"
+    )
+
+    # For paged_kv_cache which is a tuple, it should be dumped as paged_kv_cache_0 and paged_kv_cache_1
+    assert "paged_kv_cache_0" in loaded_tensors or "k_cache" in loaded_tensors, (
+        f"KV cache not found in dumped tensors: {list(loaded_tensors.keys())}"
+    )
+
+    # Verify tensor values match (compare on CPU)
+    loaded_q = loaded_tensors["q"]
+    original_q_cpu = original_q.cpu()
+    assert torch.allclose(
+        loaded_q.to(original_q_cpu.dtype), original_q_cpu, atol=1e-6
+    ), "Dumped 'q' tensor does not match original"
+
+    print(" - Safetensors file: ", latest_file)
+    print(" - Dumped tensors: ", list(loaded_tensors.keys()))
+    print(" - q shape: ", loaded_tensors["q"].shape)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_workload_dump_tensor_shapes():
+    """Test that dumped tensors have correct shapes."""
+    pytest.importorskip("safetensors")
+
+    device = "cuda"
+    batch_size = 2
+    max_seq_len = 16
+    num_attention_heads = 32
+    num_key_value_heads = 4
+    head_dim = 128
+    page_size = 1
+
+    inputs = generate_random_inputs(
+        batch_size,
+        max_seq_len,
+        num_attention_heads,
+        num_key_value_heads,
+        head_dim,
+        page_size,
+        device,
+    )
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
+    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer,
+        kv_layout="NHD",
+    )
+
+    decode_wrapper.plan(
+        indptr=inputs["kv_indptr"],
+        indices=inputs["kv_indices"],
+        last_page_len=inputs["kv_last_page_len"],
+        num_qo_heads=num_attention_heads,
+        num_kv_heads=num_key_value_heads,
+        head_dim=head_dim,
+        page_size=page_size,
+        pos_encoding_mode="NONE",
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+        sm_scale=inputs["sm_scale"],
+    )
+
+    # Run to trigger dump
+    _ = decode_wrapper.run(inputs["q"], (inputs["k_cache"], inputs["v_cache"]))
+
+    # Check dumped files
+    func_name = "BatchDecodeWithPagedKVCacheWrapper_run"
+    safetensors_files = get_dumped_safetensors_files(func_name)
+
+    assert len(safetensors_files) > 0, "No safetensors files created"
+
+    latest_file = max(safetensors_files, key=os.path.getmtime)
+    loaded_tensors = load_safetensors(latest_file)
+
+    # Verify q shape
+    if "q" in loaded_tensors:
+        expected_q_shape = (batch_size, num_attention_heads, head_dim)
+        assert loaded_tensors["q"].shape == expected_q_shape, (
+            f"q shape mismatch: expected {expected_q_shape}, got {loaded_tensors['q'].shape}"
+        )
+
+    print(" - Shapes verified correctly")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_multiple_runs_create_multiple_files():
+    """Test that multiple runs create multiple safetensors files."""
+    pytest.importorskip("safetensors")
+
+    device = "cuda"
+    batch_size = 2
+    max_seq_len = 8
+    num_attention_heads = 32
+    num_key_value_heads = 4
+    head_dim = 128
+    page_size = 1
+
+    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
+    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer,
+        kv_layout="NHD",
+    )
+
+    func_name = "BatchDecodeWithPagedKVCacheWrapper_run"
+
+    # Count existing files before test
+    initial_files = set(get_dumped_safetensors_files(func_name))
+
+    num_runs = 3
+    for _ in range(num_runs):
+        inputs = generate_random_inputs(
+            batch_size,
+            max_seq_len,
+            num_attention_heads,
+            num_key_value_heads,
+            head_dim,
+            page_size,
+            device,
+        )
+
+        decode_wrapper.plan(
+            indptr=inputs["kv_indptr"],
+            indices=inputs["kv_indices"],
+            last_page_len=inputs["kv_last_page_len"],
+            num_qo_heads=num_attention_heads,
+            num_kv_heads=num_key_value_heads,
+            head_dim=head_dim,
+            page_size=page_size,
+            pos_encoding_mode="NONE",
+            q_data_type=torch.bfloat16,
+            kv_data_type=torch.bfloat16,
+            sm_scale=inputs["sm_scale"],
+        )
+
+        _ = decode_wrapper.run(inputs["q"], (inputs["k_cache"], inputs["v_cache"]))
+
+    # Count files after runs
+    final_files = set(get_dumped_safetensors_files(func_name))
+    new_files = final_files - initial_files
+
+    assert len(new_files) >= num_runs, (
+        f"Expected at least {num_runs} new safetensors files, got {len(new_files)}"
+    )
+
+    print(f" - Created {len(new_files)} new safetensors files from {num_runs} runs")
+
+
+if __name__ == "__main__":
+    # Run tests manually
+    print("Running workload dump tests...")
+    print(f"Bench log directory: {_TEST_BENCH_DIR}")
+
+    try:
+        test_workload_dump_batch_decode()
+        test_workload_dump_tensor_shapes()
+        test_multiple_runs_create_multiple_files()
+        print("\n" + "=" * 60)
+        print("All tests passed!")
+        print("=" * 60)
+    finally:
+        # Cleanup
+        if os.path.exists(_TEST_BENCH_DIR):
+            print("Cleaning up ", _TEST_BENCH_DIR)
+            shutil.rmtree(_TEST_BENCH_DIR)