219 changes: 219 additions & 0 deletions benchmark/kernels/deepep/deepep_utils.py
@@ -0,0 +1,219 @@
# COPIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py

import os
import sys
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist


def init_dist(local_rank: int, num_local_ranks: int, args):
    # NOTES: you may rewrite this function with your own cluster settings
    ip = args.master_addr
    port = args.master_port
    num_nodes = args.nnodes
    node_rank = args.node_rank
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{ip}:{port}",
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank,
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    return (
        dist.get_rank(),
        dist.get_world_size(),
        dist.new_group(list(range(num_local_ranks * num_nodes))),
    )


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()


def per_token_cast_to_fp8(x: torch.Tensor):
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)


def inplace_unique(x: torch.Tensor, num_slots: int):
    assert x.dim() == 2
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]


def create_grouped_scores(
    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)


def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    times = np.array(
        [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
    )[1:]
    return np.average(times), np.min(times), np.max(times)


class empty_suppress:
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass


class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()
Comment on lines +122 to +154 (Contributor, severity: medium)

The suppress_stdout_stderr class uses low-level file descriptor manipulation (os.dup, os.dup2) to suppress output. While this is effective for C-level library output, have you considered whether contextlib.redirect_stdout and contextlib.redirect_stderr (in the standard library since Python 3.4 and 3.5, respectively) could offer a simpler solution if only Python-level output needs suppressing?

If C-level output suppression is a firm requirement (e.g., from underlying CUDA libraries or C extensions), the current approach is understandable. However, if not, using contextlib could improve readability and reduce complexity. What are your thoughts on this trade-off? A sketch of the contextlib alternative follows.
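
For reference, a minimal sketch of the contextlib-based alternative (the helper name suppress_python_output is hypothetical). Note that it only silences Python-level writes to sys.stdout/sys.stderr, not output emitted by C extensions or CUDA libraries, which is what the fd-duplication approach in this file covers:

import contextlib
import os


@contextlib.contextmanager
def suppress_python_output():
    # Redirect Python-level stdout/stderr to os.devnull for the duration
    # of the block. C-level writes to file descriptors 1/2 are NOT affected.
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            yield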



def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert (
            sum([name in line for line in prof_lines]) == 1
        ), f"Errors of the kernel {name} in the profiling table"
Comment on lines +194 to +196 (Contributor, severity: medium)

The assertion sum([name in line for line in prof_lines]) == 1 checks that each kernel name appears exactly once in the profiler output. This could be fragile if a kernel name is a substring of another kernel name or appears in descriptive text within the profiler output lines.

Could this lead to false positives or negatives? Perhaps a more robust check, like ensuring the name matches a whole token or a more specific pattern in the line, would be safer? A sketch of such a match follows.
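
For illustration, a hedged sketch of a stricter whole-token match (the helper name count_kernel_rows is hypothetical; it assumes kernel names in the key_averages() table consist of word characters and colons, e.g. aten::mm, so a name no longer matches inside a longer identifier):

import re


def count_kernel_rows(prof_lines, name):
    # Count rows where `name` appears as a standalone token rather than as a
    # substring of a longer kernel name (e.g. "aten::mm" inside "aten::mm_out").
    pattern = re.compile(rf"(?<![\w:]){re.escape(name)}(?![\w:])")
    return sum(bool(pattern.search(line)) for line in prof_lines)


# The original assertion could then become, for example:
# assert count_kernel_rows(prof_lines, name) == 1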


    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, "")) / scale)
                        break
Comment on lines +208 to +212 (Contributor, severity: medium)

Parsing time_str with line.split()[-2] assumes a fixed format for the profiler output table. If the Kineto profiler's table format changes in a future PyTorch version (e.g., columns are added or spacing changes), this parsing logic might break.

Would it be more resilient to parse based on the column headers, or to use regular expressions, if the format is mostly stable but allows minor variations? A sketch of header-based parsing follows.
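
For illustration, a hedged sketch of header-based parsing (the helper name parse_cuda_time_avg_s is hypothetical; it assumes the averages table has a "CUDA time avg" column and that columns are separated by two or more spaces, which may not hold for every PyTorch/Kineto version; like the existing code, it returns seconds):

import re


def parse_cuda_time_avg_s(prof_lines, kernel_row):
    # Find the header row and locate the "CUDA time avg" column by name
    # instead of taking a fixed position from the end of the line.
    header = next(line for line in prof_lines if "CUDA time avg" in line)
    col_idx = re.split(r"\s{2,}", header.strip()).index("CUDA time avg")

    # Split the kernel row on runs of two or more spaces and pick the same column.
    time_str = re.split(r"\s{2,}", kernel_row.strip())[col_idx]
    match = re.fullmatch(r"([\d.]+)(us|ms|s)", time_str)
    if match is None:
        raise ValueError(f"unexpected duration format: {time_str!r}")
    value, unit = match.groups()
    return float(value) * {"us": 1e-6, "ms": 1e-3, "s": 1.0}[unit]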

                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]


def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()