Support tuning DeepEP configs #6742
Changes from 35 commits
The diff adds a new 219-line utilities file, copied from DeepEP's `tests/utils.py`:

```python
# COPIED FROM https://github.com/deepseek-ai/DeepEP/blob/main/tests/utils.py
import os
import sys
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist


def init_dist(local_rank: int, num_local_ranks: int, args):
    # NOTES: you may rewrite this function with your own cluster settings
    ip = args.master_addr
    port = args.master_port
    num_nodes = args.nnodes
    node_rank = args.node_rank
    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{ip}:{port}",
        world_size=num_nodes * num_local_ranks,
        rank=node_rank * num_local_ranks + local_rank,
    )
    torch.set_default_dtype(torch.bfloat16)
    torch.set_default_device("cuda")
    torch.cuda.set_device(local_rank)

    return (
        dist.get_rank(),
        dist.get_world_size(),
        dist.new_group(list(range(num_local_ranks * num_nodes))),
    )


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double() + 1, y.double() + 1
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return (1 - sim).item()


def per_token_cast_to_fp8(x: torch.Tensor):
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor):
    x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128)
    x_scales = x_scales.view(x_fp8.size(0), -1, 1)
    return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16)


def inplace_unique(x: torch.Tensor, num_slots: int):
    assert x.dim() == 2
    mask = x < 0
    x_padded = x.masked_fill(mask, num_slots)
    bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device)
    bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded))
    bin_count = bin_count[:, :num_slots]
    sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True)
    sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1)
    sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values
    x[:, :].fill_(-1)
    valid_len = min(num_slots, x.size(1))
    x[:, :valid_len] = sorted_bin_idx[:, :valid_len]


def create_grouped_scores(
    scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int
):
    num_tokens, num_experts = scores.shape
    scores = scores.view(num_tokens, num_groups, -1)
    mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device)
    mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores)
    return (scores * mask).view(num_tokens, num_experts)


def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None):
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Warmup
    for _ in range(num_warmups):
        fn()

    # Flush L2
    cache.zero_()

    # Testing
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)]
    for i in range(num_tests):
        # Record
        start_events[i].record()
        fn()
        end_events[i].record()
        if post_fn is not None:
            post_fn()
    torch.cuda.synchronize()

    times = np.array(
        [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)]
    )[1:]
    return np.average(times), np.min(times), np.max(times)


class empty_suppress:
    def __enter__(self):
        return self

    def __exit__(self, *_):
        pass


class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: Optional[str] = None,
    barrier_comm_profiling: bool = False,
):
    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
        ) as prof:
            for i in range(2):
                # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
                if barrier_comm_profiling:
                    lhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    rhs = torch.randn((8192, 8192), dtype=torch.float, device="cuda")
                    lhs @ rhs
                    dist.all_reduce(torch.ones(1, dtype=torch.float, device="cuda"))
                for _ in range(num_tests):
                    fn()
                prof.step()

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tupled = isinstance(kernel_names, tuple)
    prof_lines = (
        prof.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    for name in kernel_names:
        assert (
            sum([name in line for line in prof_lines]) == 1
        ), f"Errors of the kernel {name} in the profiling table"
```
Comment on lines +194 to +196

Contributor: The assertion requires each kernel name to match exactly one row of the profiling table, but it does so with a plain substring test (`name in line`). Could this lead to false positives or negatives? Perhaps a more robust check, like ensuring the name is a whole word or matches a more specific pattern in the line, would be safer?
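For illustration only (not part of the PR; the helper name is hypothetical), a stricter check along those lines could look like:

```python
import re

# Hypothetical stricter check: count rows where the kernel name appears as a
# whole token rather than as an arbitrary substring.
def count_exact_matches(name: str, lines) -> int:
    pattern = re.compile(rf"(?<!\w){re.escape(name)}(?!\w)")
    return sum(bool(pattern.search(line)) for line in lines)

# The assertion would then become:
# assert count_exact_matches(name, prof_lines) == 1
```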
The diff continues inside `bench_kineto`:

```python
    # Save chrome traces
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                for unit, scale in units.items():
                    if unit in time_str:
                        kernel_times.append(float(time_str.replace(unit, "")) / scale)
                        break
```
Comment on lines +208 to +212

Contributor: Parsing the profiler table by splitting each row on whitespace and taking the second-to-last token (`line.split()[-2]`) is brittle if the table layout changes. Would it be more resilient to parse based on column headers, or use regular expressions, if the format is somewhat stable but allows for minor variations?
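As a sketch of that suggestion (the helper and regex are illustrative, not from the PR), the time token could be extracted with a regular expression instead of a fixed split index; matching against the header row to pick the intended column would be more robust still:

```python
import re
from typing import Optional

# Hypothetical regex-based extraction of the last "<number><unit>" token in a
# profiler table row. Assuming the final column is a unitless call count, this
# corresponds to the position that line.split()[-2] targets.
_TIME_RE = re.compile(r"(\d+(?:\.\d+)?)\s*(ms|us)\b")

def parse_kernel_time_seconds(line: str) -> Optional[float]:
    units = {"ms": 1e3, "us": 1e6}
    matches = _TIME_RE.findall(line)
    if not matches:
        return None
    value, unit = matches[-1]
    return float(value) / units[unit]
```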
The remainder of the file:

```python
                break
    return tuple(kernel_times) if is_tupled else kernel_times[0]


def hash_tensor(t: torch.Tensor):
    return t.view(torch.int64).sum().item()
```
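For context, a minimal usage sketch of these helpers (the tensor shapes and the matmul workload are illustrative, not taken from the PR):

```python
import torch

# Illustrative shapes; per_token_cast_to_fp8 requires a 2-D tensor whose last
# dimension is a multiple of 128.
x = torch.randn(4096, 7168, dtype=torch.bfloat16, device="cuda")
w = torch.randn(7168, 4096, dtype=torch.bfloat16, device="cuda")

# FP8 round trip plus the calc_diff error metric.
x_fp8, x_scales = per_token_cast_to_fp8(x)
x_back = per_token_cast_back(x_fp8, x_scales)
print(f"fp8 round-trip diff: {calc_diff(x, x_back):.3e}")

# Time a bf16 matmul with the L2-flushing benchmark helper (times in seconds).
avg_t, min_t, max_t = bench(lambda: x @ w)
print(f"avg {avg_t * 1e6:.1f} us, min {min_t * 1e6:.1f} us, max {max_t * 1e6:.1f} us")
```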
The `suppress_stdout_stderr` class uses low-level file descriptor manipulation (`os.dup`, `os.dup2`) to suppress output. While this is effective for C-level library output, have you considered whether `contextlib.redirect_stdout` and `contextlib.redirect_stderr` (available since Python 3.4) could offer a simpler, standard-library-based solution if only Python-level output needs suppression?

If C-level output suppression is a firm requirement (e.g., from underlying CUDA libraries or C extensions), the current approach is understandable. However, if not, using `contextlib` could improve readability and reduce complexity. What are your thoughts on this trade-off?
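For comparison only (not part of the PR), a Python-level-only alternative along the lines the reviewer mentions could look like the sketch below; it would not silence output written directly to the underlying file descriptors by C extensions or CUDA libraries, which is presumably why the PR keeps the `os.dup2` approach:

```python
import contextlib
import os

# Hypothetical simpler context manager: redirects Python-level sys.stdout and
# sys.stderr only; C-level writes to fd 1/2 are unaffected.
@contextlib.contextmanager
def suppress_python_output():
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            yield
```

Usage would mirror the existing `with suppress():` call in `bench_kineto`.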