Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 223 additions & 18 deletions flashinfer/api_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import logging
import os
import sys
from typing import Any, Callable
import uuid
from typing import Any, Callable, Dict, Optional
import contextlib
import torch

Expand All @@ -42,6 +43,23 @@ def _substitute_process_id(path: str) -> str:
_API_LOG_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0"))
_API_LOG_DEST = _substitute_process_id(os.environ.get("FLASHINFER_LOGDEST", "stdout"))

# Bench logging environment variables
# FLASHINFER_BENCH_LOG toggles workload dumping; any of "1"/"true"/"yes"
# (case-insensitive) enables it, every other value leaves it disabled.
_BENCH_LOG_ENABLED = os.environ.get("FLASHINFER_BENCH_LOG", "0").lower() in (
    "1",
    "true",
    "yes",
)
# Default bench log directory is under the flashinfer package directory
_FLASHINFER_DIR = os.path.dirname(os.path.abspath(__file__))
_DEFAULT_BENCH_LOG_DIR = os.path.join(_FLASHINFER_DIR, ".flashinfer_bench_cache")
# Any %i in the configured path is replaced with the current process id
# (see _substitute_process_id above) so multi-process runs get separate dirs.
_BENCH_LOG_DIR = _substitute_process_id(
    os.environ.get("FLASHINFER_BENCH_LOG_DIR", _DEFAULT_BENCH_LOG_DIR)
)
# Maximum file size for bench dump (default: 10GB)
# Workloads whose estimated serialized size exceeds this are skipped.
_BENCH_MAX_FILE_SIZE = int(
    os.environ.get("MAX_FLASHINFER_BENCH_DUMP_FILE_SIZE", str(10 * 1024 * 1024 * 1024))
)

# Create logger using Python's logging library
_logger = logging.getLogger("flashinfer.api")

Expand Down Expand Up @@ -461,6 +479,155 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None:
_logger.debug("\n".join(lines))


# =============================================================================
# Bench Logging Functions (for flashinfer-bench workload dumping)
# =============================================================================


def _sanitize_api_name(func_name: str) -> str:
"""
Sanitize API name to create a valid directory name.

Replaces dots and other special characters with underscores.
"""
# Replace dots (from class.method) and other special chars with underscores
sanitized = func_name.replace(".", "_").replace("-", "_")
# Remove any other non-alphanumeric characters except underscore
sanitized = "".join(c if c.isalnum() or c == "_" else "_" for c in sanitized)
return sanitized


def _dump_workload(
    func: Callable, func_name: str, args: tuple, kwargs: dict
) -> Optional[str]:
    """
    Dump all function arguments to a safetensors file.

    Saves all tensor arguments to a safetensors file named by UUID,
    organized under a folder named after the function.

    Directory structure:
        FLASHINFER_BENCH_LOG_DIR/
            function_name/
                <uuid>.safetensors
                <uuid>.safetensors
                ...

    Parameters
    ----------
    func : Callable
        The function being called (used only to recover parameter names).
    func_name : str
        Name of the function (may include class name).
    args : tuple
        Positional arguments.
    kwargs : dict
        Keyword arguments.

    Returns
    -------
    Optional[str]
        Path to the safetensors file, or None if dumping was skipped or failed.
    """
    try:
        from safetensors.torch import save_file
    except ImportError:
        _logger.warning(
            "[BENCH LOG] safetensors not installed, skipping workload dump. "
            "Install with: pip install safetensors"
        )
        return None

    try:
        # Create function-specific directory
        sanitized_name = _sanitize_api_name(func_name)
        func_dir = os.path.join(_BENCH_LOG_DIR, sanitized_name)
        os.makedirs(func_dir, exist_ok=True)

        # Get parameter names from function signature.
        # NOTE(review): relies on `inspect` being imported at module scope
        # (not visible in this chunk); the fallback covers any failure.
        try:
            sig = inspect.signature(func)
            param_names = list(sig.parameters.keys())
        except Exception:
            param_names = [f"arg{i}" for i in range(len(args))]

        # Collect all tensors to save
        tensors_to_save: Dict[str, torch.Tensor] = {}

        # Process positional arguments
        for i, arg in enumerate(args):
            param_name = param_names[i] if i < len(param_names) else f"arg{i}"
            _collect_tensors(arg, param_name, tensors_to_save)

        # Process keyword arguments
        for key, value in kwargs.items():
            _collect_tensors(value, key, tensors_to_save)

        if not tensors_to_save:
            _logger.debug(f"[BENCH LOG] No tensors to dump for {func_name}")
            return None

        # Estimate file size (sum of tensor bytes). _collect_tensors never
        # stores None, so no per-tensor None check is needed.
        estimated_size = sum(
            tensor.numel() * tensor.element_size()
            for tensor in tensors_to_save.values()
        )

        # Skip if estimated size exceeds limit
        if estimated_size > _BENCH_MAX_FILE_SIZE:
            _logger.debug(
                f"[BENCH LOG] Skipping dump for {func_name}: estimated size "
                f"{estimated_size / (1024**3):.2f}GB exceeds limit "
                f"{_BENCH_MAX_FILE_SIZE / (1024**3):.2f}GB"
            )
            return None

        # Move tensors to CPU for saving. `.contiguous()` is required:
        # safetensors' save_file rejects non-contiguous tensors (e.g.
        # transposed views), which tensor arguments commonly are.
        # NOTE(review): save_file also rejects tensors that share storage;
        # that case still falls through to the warning below.
        cpu_tensors = {
            key: tensor.detach().cpu().contiguous()
            for key, tensor in tensors_to_save.items()
        }

        # Generate UUID and save
        file_uuid = uuid.uuid4().hex
        file_path = os.path.join(func_dir, f"{file_uuid}.safetensors")
        save_file(cpu_tensors, file_path)

        return file_path

    except Exception as e:
        # Best-effort: dumping must never break the decorated API call.
        _logger.warning(f"[BENCH LOG] Failed to dump workload for {func_name}: {e}")
        return None


def _collect_tensors(value: Any, name: str, tensors: Dict[str, torch.Tensor]) -> None:
"""
Recursively collect tensors from a value into the tensors dict.

Parameters
----------
value : Any
The value to extract tensors from
name : str
The parameter name (used as key prefix)
tensors : Dict[str, torch.Tensor]
Dictionary to collect tensors into (modified in place)
"""
if value is None:
return

if isinstance(value, torch.Tensor):
tensors[name] = value
elif isinstance(value, (list, tuple)):
for i, item in enumerate(value):
_collect_tensors(item, f"{name}_{i}", tensors)
elif isinstance(value, dict):
for key, item in value.items():
_collect_tensors(item, f"{name}_{key}", tensors)


def flashinfer_api(func: Callable = None) -> Callable:
"""
Decorator to FlashInfer's APIs.
Expand All @@ -469,6 +636,8 @@ def flashinfer_api(func: Callable = None) -> Callable:
This decorator integrates with Python's standard logging infrastructure while
maintaining zero overhead when disabled (FLASHINFER_LOGLEVEL=0).

Additionally supports workload dumping for benchmarking when FLASHINFER_BENCH_LOG is enabled.

NOTE/TODO: Not all FlashInfer APIs are decorated with this decorator yet. This is a work in progress.

Environment Variables
Expand All @@ -485,6 +654,22 @@ def flashinfer_api(func: Callable = None) -> Callable:
- <path>: Log to specified file path
- Use %i in path for process ID substitution (e.g., "log_%i.txt" -> "log_12345.txt")

FLASHINFER_BENCH_LOG : str (default: "0")
- "0", "false", "no": Disable workload dumping (zero overhead)
- "1", "true", "yes": Enable workload dumping
When enabled, dumps all function tensor arguments to safetensors files.
Files are organized by function name and named by UUID:
FLASHINFER_BENCH_LOG_DIR/function_name/<uuid>.safetensors

FLASHINFER_BENCH_LOG_DIR : str (default: "<flashinfer_package>/.flashinfer_bench_cache")
Directory where workload safetensors files are saved.
By default, the cache is placed under the flashinfer package directory.
- Use %i in path for process ID substitution (e.g., "bench_%i/" -> "bench_12345/")

MAX_FLASHINFER_BENCH_DUMP_FILE_SIZE : int (default: 10737418240, i.e., 10GB)
Maximum estimated file size in bytes for workload dumps.
Workloads exceeding this size will be skipped.

Examples
--------
Basic usage:
Expand All @@ -493,11 +678,17 @@ def flashinfer_api(func: Callable = None) -> Callable:
... def my_function(x, y):
... return x + y

Enable workload dumping:

>>> # Set environment variables before importing:
>>> # export FLASHINFER_BENCH_LOG=1
>>> # export FLASHINFER_BENCH_LOG_DIR=/path/to/workloads

Notes
-----
- Key header lines include a timestamp in the format: [YYYY-MM-DD HH:MM:SS]
(e.g., "FlashInfer API Call: function_name", "FlashInfer API Logging - System Information")
- When FLASHINFER_LOGLEVEL=0, the decorator has truly zero overhead
- When FLASHINFER_LOGLEVEL=0 and FLASHINFER_BENCH_LOG=0, the decorator has truly zero overhead
as it returns the original function unchanged.
- Function names and inputs are logged BEFORE execution:
- Level 1: Function name only
Expand All @@ -510,9 +701,13 @@ def flashinfer_api(func: Callable = None) -> Callable:
The message "[statistics skipped: CUDA graph capture in progress]" will be logged.
- The %i pattern is automatically replaced with the process ID for multi-process environments.
- The logger does not propagate to the root logger to avoid duplicate logs.
- **Workload Dumping**: When FLASHINFER_BENCH_LOG=1, all tensor arguments are saved to
safetensors files organized by function name (e.g., "func_name/<uuid>.safetensors").
Each API call creates a new safetensors file containing all input tensors.
Requires the safetensors package: pip install safetensors
"""
# If logging is disabled, return original function with zero overhead
if _API_LOG_LEVEL == 0:
# If both logging and bench logging are disabled, return original function with zero overhead
if _API_LOG_LEVEL == 0 and not _BENCH_LOG_ENABLED:
if func is None:
return lambda f: f
return func
Expand All @@ -533,28 +728,38 @@ def wrapper(*args, **kwargs):
pass

# Log BEFORE execution (crash-safe for all levels!)
try:
if _API_LOG_LEVEL == 1:
# Level 1: Just log function name before execution (crash-safe)
_logger.debug(
f"{_get_timestamp()} FlashInfer API Call: {func_name}"
if _API_LOG_LEVEL > 0:
try:
if _API_LOG_LEVEL == 1:
# Level 1: Just log function name before execution (crash-safe)
_logger.debug(
f"{_get_timestamp()} FlashInfer API Call: {func_name}"
)
elif _API_LOG_LEVEL >= 3:
# Level 3+: Log full inputs before execution (crash-safe)
_log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL)
except Exception as e:
_logger.error(
f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}"
)
elif _API_LOG_LEVEL >= 3:
# Level 3+: Log full inputs before execution (crash-safe)
_log_function_inputs(f, func_name, args, kwargs, _API_LOG_LEVEL)
except Exception as e:
_logger.error(f"[LOGGING ERROR in {func_name} (pre-execution)]: {e}")

# Dump workload BEFORE execution (crash-safe) if bench logging is enabled
if _BENCH_LOG_ENABLED:
try:
_dump_workload(f, func_name, args, kwargs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The function _dump_workload is called here, but it is not defined anywhere in the file. This will cause a NameError at runtime when FLASHINFER_BENCH_LOG is enabled. Please add the implementation for the _dump_workload function.

except Exception as e:
_logger.warning(f"[BENCH LOG ERROR in {func_name}]: {e}")

# Call the original function (may crash here with CUDA errors)
result = f(*args, **kwargs)

# Log outputs AFTER successful execution (level 3+ only)
try:
if _API_LOG_LEVEL >= 3:
if _API_LOG_LEVEL >= 3:
try:
# Level 3+: Log outputs (inputs were already logged above)
_log_function_outputs(func_name, result, _API_LOG_LEVEL)
except Exception as e:
_logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}")
except Exception as e:
_logger.error(f"[LOGGING ERROR in {func_name} (outputs)]: {e}")

return result

Expand Down
Loading