diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 4ed6cc9af4be..e5d148d6eaba 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -104,7 +104,9 @@ m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd)); ## Kernel Size Analysis -Analyze CUDA kernel sizes in compiled wheel files to identify optimization opportunities: +Analyze CUDA kernel sizes in compiled wheel files to identify oversized kernels and template-instantiation bloat. + +This tool requires `cubloaty` (install with `pip install cubloaty`) to work. ```bash # Install cubloaty @@ -118,9 +120,9 @@ python analyze_whl_kernel_sizes.py path/to/sgl_kernel-*.whl --output my_analysis ``` The tool generates: -- Text report with kernel groups (by name prefix) and individual kernel sizes -- JSON file with detailed structured data -- Timing information for each analysis step +- A text report with: + - Top 20 kernel groups (by name prefix) + - Top 20 individual kernels (sorted by size) Use this to identify large kernels and potential template instantiation bloat. 
diff --git a/sgl-kernel/analyze_whl_kernel_sizes.py b/sgl-kernel/analyze_whl_kernel_sizes.py index f845c81c9a18..56e4ca6be61b 100644 --- a/sgl-kernel/analyze_whl_kernel_sizes.py +++ b/sgl-kernel/analyze_whl_kernel_sizes.py @@ -5,7 +5,6 @@ import subprocess import sys import tempfile -import time import zipfile from pathlib import Path @@ -53,39 +52,21 @@ def analyze_whl(whl_file): temp_dir = tempfile.mkdtemp(prefix="sgl_kernel_analysis_") try: - t0 = time.time() - print(f"Extracting {whl_file}...") extract_whl(whl_file, temp_dir) - print(f" Extraction took {time.time() - t0:.2f}s\n") - t0 = time.time() binary_files = find_binary_files(temp_dir) if not binary_files: print(f"No .so or .cubin files found in {whl_file}") return [] - print( - f"Found {len(binary_files)} binary files (took {time.time() - t0:.2f}s)\n" - ) - all_kernels = [] - total_analyzed = 0 - total_skipped = 0 for binary_file in binary_files: file_name = os.path.basename(binary_file) - t0 = time.time() - print(f"Analyzing {file_name}...", end=" ", flush=True) - data = run_cubloaty(binary_file) - elapsed = time.time() - t0 if not data or "kernels" not in data: - print(f"skipped (no CUDA code, {elapsed:.2f}s)") - total_skipped += 1 continue - - kernel_count = 0 for kernel in data["kernels"]: all_kernels.append( { @@ -96,14 +77,6 @@ def analyze_whl(whl_file): "size_mb": kernel.get("size", 0) / 1024 / 1024, } ) - kernel_count += 1 - - print(f"found {kernel_count} kernels ({elapsed:.2f}s)") - total_analyzed += 1 - - print( - f"\nSummary: {total_analyzed} files analyzed, {total_skipped} files skipped\n" - ) return all_kernels finally: @@ -121,14 +94,10 @@ def generate_report(all_kernels, output_file): print("No kernels found") return - t0 = time.time() - print("Generating report...") - sorted_kernels = sorted(all_kernels, key=lambda x: x["size"], reverse=True) total_size = sum(k["size"] for k in all_kernels) total_size_mb = total_size / 1024 / 1024 - # Group by kernel prefix from collections import 
defaultdict kernel_groups = defaultdict(lambda: {"size": 0, "count": 0}) @@ -151,16 +120,16 @@ def generate_report(all_kernels, output_file): lines.append(f"Average kernel size: {total_size / len(all_kernels) / 1024:.2f} KB") lines.append("") - # Grouped by kernel name prefix lines.append("=" * 140) - lines.append("Kernel Groups (by name prefix)") + lines.append("Kernel Groups (by name prefix) - Top 20") lines.append("=" * 140) lines.append( f"{'Rank':<6} {'Kernel Prefix':<80} {'Count':<8} {'Total (MB)':<12} {'%':<8}" ) lines.append("-" * 140) - for i, (prefix, stats) in enumerate(sorted_groups, 1): + TOP_N = 20 + for i, (prefix, stats) in enumerate(sorted_groups[:TOP_N], 1): percentage = (stats["size"] / total_size * 100) if total_size > 0 else 0 size_mb = stats["size"] / 1024 / 1024 @@ -172,16 +141,27 @@ def generate_report(all_kernels, output_file): f"{i:<6} {display_prefix:<80} {stats['count']:<8} {size_mb:<12.2f} {percentage:<8.2f}" ) + if len(sorted_groups) > TOP_N: + other_size = sum(stats["size"] for _, stats in sorted_groups[TOP_N:]) + other_count = sum(stats["count"] for _, stats in sorted_groups[TOP_N:]) + other_percentage = (other_size / total_size * 100) if total_size > 0 else 0 + other_size_mb = other_size / 1024 / 1024 + + lines.append( + f"{'Other':<6} {'(remaining ' + str(len(sorted_groups) - TOP_N) + ' kernel groups)':<80} " + f"{other_count:<8} {other_size_mb:<12.2f} {other_percentage:<8.2f}" + ) + lines.append("") lines.append("=" * 140) - lines.append("Individual Kernels (sorted by size)") + lines.append("Individual Kernels (sorted by size) - Top 20") lines.append("=" * 140) lines.append( f"{'Rank':<6} {'File':<40} {'Kernel Name':<70} {'Size (KB)':<12} {'Size (MB)':<12} {'%':<8}" ) lines.append("-" * 140) - for i, kernel in enumerate(sorted_kernels, 1): + for i, kernel in enumerate(sorted_kernels[:TOP_N], 1): percentage = (kernel["size"] / total_size * 100) if total_size > 0 else 0 kernel_name = kernel["name"] if len(kernel_name) > 67: @@ 
-196,39 +176,24 @@ def generate_report(all_kernels, output_file): f"{kernel['size_kb']:<12.2f} {kernel['size_mb']:<12.4f} {percentage:<8.2f}" ) + if len(sorted_kernels) > TOP_N: + other_size = sum(k["size"] for k in sorted_kernels[TOP_N:]) + other_count = len(sorted_kernels) - TOP_N + other_percentage = (other_size / total_size * 100) if total_size > 0 else 0 + other_size_kb = other_size / 1024 + other_size_mb = other_size / 1024 / 1024 + + lines.append( + f"{'Other':<6} {'(remaining ' + str(other_count) + ' kernels)':<40} " + f"{'':<70} {other_size_kb:<12.2f} {other_size_mb:<12.4f} {other_percentage:<8.2f}" + ) + report_text = "\n".join(lines) with open(output_file, "w") as f: f.write(report_text) print(f"Report saved to: {output_file}") - json_output = output_file.replace(".txt", ".json") - with open(json_output, "w") as f: - json.dump( - { - "total_kernels": len(all_kernels), - "total_size_bytes": total_size, - "total_size_mb": total_size_mb, - "kernel_groups": [ - { - "prefix": prefix, - "count": stats["count"], - "size_bytes": stats["size"], - "size_mb": stats["size"] / 1024 / 1024, - "percentage": ( - (stats["size"] / total_size * 100) if total_size > 0 else 0 - ), - } - for prefix, stats in sorted_groups - ], - "kernels": sorted_kernels, - }, - f, - indent=2, - ) - print(f"JSON data saved to: {json_output}") - print(f"Report generation took {time.time() - t0:.2f}s") - def main(): parser = argparse.ArgumentParser( @@ -244,13 +209,10 @@ def main(): print(f"Error: {args.whl} not found") sys.exit(1) - total_start = time.time() - print(f"Analyzing {args.whl}\n") all_kernels = analyze_whl(args.whl) if all_kernels: generate_report(all_kernels, args.output) - print(f"\nTotal time: {time.time() - total_start:.2f}s") else: print("No kernel information extracted")