Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sgl-kernel/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

## Kernel Size Analysis

Analyze CUDA kernel sizes in compiled wheel files to identify optimization opportunities:
Analyze CUDA kernel sizes in compiled wheel files to identify oversized kernels and template-instantiation bloat.

This tool requires `cubloaty`; install it with `pip install cubloaty` before running the analysis:

```bash
# Install cubloaty
Expand All @@ -118,9 +120,9 @@ python analyze_whl_kernel_sizes.py path/to/sgl_kernel-*.whl --output my_analysis
```

The tool generates:
- Text report with kernel groups (by name prefix) and individual kernel sizes
- JSON file with detailed structured data
- Timing information for each analysis step
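- A text report with:
  - Kernel groups (by name prefix): the top 20 by total size, with the remaining groups summed into an "Other" row
  - Individual kernel sizes (sorted by size): the top 20, with the remaining kernels summed into an "Other" row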

Use this to identify large kernels and potential template instantiation bloat.

Expand Down
94 changes: 28 additions & 66 deletions sgl-kernel/analyze_whl_kernel_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import subprocess
import sys
import tempfile
import time
import zipfile
from pathlib import Path

Expand Down Expand Up @@ -53,39 +52,21 @@ def analyze_whl(whl_file):
temp_dir = tempfile.mkdtemp(prefix="sgl_kernel_analysis_")

try:
t0 = time.time()
print(f"Extracting {whl_file}...")
extract_whl(whl_file, temp_dir)
print(f" Extraction took {time.time() - t0:.2f}s\n")

t0 = time.time()
binary_files = find_binary_files(temp_dir)
if not binary_files:
print(f"No .so or .cubin files found in {whl_file}")
return []

print(
f"Found {len(binary_files)} binary files (took {time.time() - t0:.2f}s)\n"
)

all_kernels = []
total_analyzed = 0
total_skipped = 0

for binary_file in binary_files:
file_name = os.path.basename(binary_file)
t0 = time.time()
print(f"Analyzing {file_name}...", end=" ", flush=True)

data = run_cubloaty(binary_file)
elapsed = time.time() - t0

if not data or "kernels" not in data:
print(f"skipped (no CUDA code, {elapsed:.2f}s)")
total_skipped += 1
continue

kernel_count = 0
for kernel in data["kernels"]:
all_kernels.append(
{
Expand All @@ -96,14 +77,6 @@ def analyze_whl(whl_file):
"size_mb": kernel.get("size", 0) / 1024 / 1024,
}
)
kernel_count += 1

print(f"found {kernel_count} kernels ({elapsed:.2f}s)")
total_analyzed += 1

print(
f"\nSummary: {total_analyzed} files analyzed, {total_skipped} files skipped\n"
)
return all_kernels

finally:
Expand All @@ -121,14 +94,10 @@ def generate_report(all_kernels, output_file):
print("No kernels found")
return

t0 = time.time()
print("Generating report...")

sorted_kernels = sorted(all_kernels, key=lambda x: x["size"], reverse=True)
total_size = sum(k["size"] for k in all_kernels)
total_size_mb = total_size / 1024 / 1024

# Group by kernel prefix
from collections import defaultdict

kernel_groups = defaultdict(lambda: {"size": 0, "count": 0})
Expand All @@ -151,16 +120,16 @@ def generate_report(all_kernels, output_file):
lines.append(f"Average kernel size: {total_size / len(all_kernels) / 1024:.2f} KB")
lines.append("")

# Grouped by kernel name prefix
lines.append("=" * 140)
lines.append("Kernel Groups (by name prefix)")
lines.append("Kernel Groups (by name prefix) - Top 20")
lines.append("=" * 140)
lines.append(
f"{'Rank':<6} {'Kernel Prefix':<80} {'Count':<8} {'Total (MB)':<12} {'%':<8}"
)
lines.append("-" * 140)

for i, (prefix, stats) in enumerate(sorted_groups, 1):
TOP_N = 20
for i, (prefix, stats) in enumerate(sorted_groups[:TOP_N], 1):
percentage = (stats["size"] / total_size * 100) if total_size > 0 else 0
size_mb = stats["size"] / 1024 / 1024

Expand All @@ -172,16 +141,27 @@ def generate_report(all_kernels, output_file):
f"{i:<6} {display_prefix:<80} {stats['count']:<8} {size_mb:<12.2f} {percentage:<8.2f}"
)

if len(sorted_groups) > TOP_N:
other_size = sum(stats["size"] for _, stats in sorted_groups[TOP_N:])
other_count = sum(stats["count"] for _, stats in sorted_groups[TOP_N:])
other_percentage = (other_size / total_size * 100) if total_size > 0 else 0
other_size_mb = other_size / 1024 / 1024

lines.append(
f"{'Other':<6} {'(remaining ' + str(len(sorted_groups) - TOP_N) + ' kernel groups)':<80} "
f"{other_count:<8} {other_size_mb:<12.2f} {other_percentage:<8.2f}"
)

lines.append("")
lines.append("=" * 140)
lines.append("Individual Kernels (sorted by size)")
lines.append("Individual Kernels (sorted by size) - Top 20")
lines.append("=" * 140)
lines.append(
f"{'Rank':<6} {'File':<40} {'Kernel Name':<70} {'Size (KB)':<12} {'Size (MB)':<12} {'%':<8}"
)
lines.append("-" * 140)

for i, kernel in enumerate(sorted_kernels, 1):
for i, kernel in enumerate(sorted_kernels[:TOP_N], 1):
percentage = (kernel["size"] / total_size * 100) if total_size > 0 else 0
kernel_name = kernel["name"]
if len(kernel_name) > 67:
Expand All @@ -196,39 +176,24 @@ def generate_report(all_kernels, output_file):
f"{kernel['size_kb']:<12.2f} {kernel['size_mb']:<12.4f} {percentage:<8.2f}"
)

if len(sorted_kernels) > TOP_N:
other_size = sum(k["size"] for k in sorted_kernels[TOP_N:])
other_count = len(sorted_kernels) - TOP_N
other_percentage = (other_size / total_size * 100) if total_size > 0 else 0
other_size_kb = other_size / 1024
other_size_mb = other_size / 1024 / 1024

lines.append(
f"{'Other':<6} {'(remaining ' + str(other_count) + ' kernels)':<40} "
f"{'':<70} {other_size_kb:<12.2f} {other_size_mb:<12.4f} {other_percentage:<8.2f}"
)

report_text = "\n".join(lines)

with open(output_file, "w") as f:
f.write(report_text)
print(f"Report saved to: {output_file}")

json_output = output_file.replace(".txt", ".json")
with open(json_output, "w") as f:
json.dump(
{
"total_kernels": len(all_kernels),
"total_size_bytes": total_size,
"total_size_mb": total_size_mb,
"kernel_groups": [
{
"prefix": prefix,
"count": stats["count"],
"size_bytes": stats["size"],
"size_mb": stats["size"] / 1024 / 1024,
"percentage": (
(stats["size"] / total_size * 100) if total_size > 0 else 0
),
}
for prefix, stats in sorted_groups
],
"kernels": sorted_kernels,
},
f,
indent=2,
)
print(f"JSON data saved to: {json_output}")
print(f"Report generation took {time.time() - t0:.2f}s")


def main():
parser = argparse.ArgumentParser(
Expand All @@ -244,13 +209,10 @@ def main():
print(f"Error: {args.whl} not found")
sys.exit(1)

total_start = time.time()
print(f"Analyzing {args.whl}\n")
all_kernels = analyze_whl(args.whl)

if all_kernels:
generate_report(all_kernels, args.output)
print(f"\nTotal time: {time.time() - total_start:.2f}s")
else:
print("No kernel information extracted")

Expand Down
Loading