 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem

+import kraken
 from kraken import _logging as log
-from kraken.all_reduce_fusion import (
-    rms_norm,
-    one_shot_all_reduce_bias,
-    one_shot_all_reduce_bias_rms_norm,
-    two_shot_all_reduce_bias,
-    two_shot_all_reduce_bias_rms_norm,
-)


 def one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+    kraken.all_reduce_fusion.one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
     return y


 def one_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
-    return rms_norm(y, rms_weight)
+    kraken.all_reduce_fusion.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)


 def two_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+    kraken.all_reduce_fusion.two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
     return y


 def two_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
-    return rms_norm(y, rms_weight)
+    kraken.all_reduce_fusion.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)


 def nccl_all_reduce_bias_rms_norm(x, bias, rms_weight):
     dist.all_reduce(x)
     y = x + bias
-    return rms_norm(y, rms_weight)
+    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)


 def create_benchmarks(b, t, d_size, device, dtype):
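For context, the sketch below shows one plausible way to drive these benchmark wrappers on a multi-GPU node: allocate the shared buffer with symm_mem.empty, establish the peer mapping with symm_mem.rendezvous, then call a fused wrapper alongside the NCCL baseline. This is not code from this commit; the example_run name, the shapes and dtype, the one-process-per-GPU launch, and whether the rank-local input must first be staged into the shared buffer are all assumptions. It is written as if it lived in the same module, so the imports at the top of the file apply.

def example_run(b=8, t=128, d=4096, dtype=torch.bfloat16):
    # Hypothetical driver, not part of this commit: one process per GPU,
    # launched e.g. via torchrun. Calls follow the wrapper signatures above.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Symmetric-memory buffer shared across ranks; rendezvous exchanges peer
    # handles so the fused one-shot/two-shot kernels can access remote data.
    symm_mem_input = symm_mem.empty(b * t, d, dtype=dtype, device=device)
    symm_mem.rendezvous(symm_mem_input, dist.group.WORLD.group_name)

    x = torch.randn(b * t, d, dtype=dtype, device=device)
    bias = torch.randn(d, dtype=dtype, device=device)
    rms_weight = torch.randn(d, dtype=dtype, device=device)

    # Fused path vs. NCCL reference. Whether x needs to be copied into
    # symm_mem_input first depends on the kraken kernel and is not assumed here.
    y_fused = one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input)
    y_ref = nccl_all_reduce_bias_rms_norm(x.clone(), bias, rms_weight)

    dist.destroy_process_group()
    return y_fused, y_ref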