distributed_norm.py
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import ttnn

from models.common.lightweightmodule import LightweightModule
from models.demos.llama3.tt.llama_ccl import tt_sharded_distributed_rmsnorm, tt_distributed_rmsnorm
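
# DistributedNorm wraps a per-device norm module for use on multi-chip meshes.
# On TG (Galaxy) systems it dispatches to the fused distributed RMSNorm ops from
# llama_ccl; otherwise it gathers the fractured activations when the wrapped norm
# is not itself distributed, applies the norm, and gathers afterwards when it is.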
class DistributedNorm(LightweightModule):
    def __init__(self, norm, args, TG=False):
        self.norm = norm
        self.args = args
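        # On a TG (Galaxy) mesh, pre-build the sharded memory and program configs
        # consumed by the fused distributed RMSNorm ops in forward().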
        if TG:
            core_grid_ln = (
                min(4, args.dim // 4 // 32 // 8),
                8,
            )  # dividing by 4 and 8 for num_cols and num_rows of mesh, and 32 for tile size
            num_cores_ln = core_grid_ln[0] * core_grid_ln[1]
            hidden_size_per_device_distributed_ln = args.dim // 4
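            # Width-sharded memory config for the input to the sharded distributed RMSNorm (decode mode).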
            self.gather_in_mem_cfg = ttnn.create_sharded_memory_config(
                shape=(1, 1, 32, hidden_size_per_device_distributed_ln),
                core_grid=ttnn.CoreGrid(y=core_grid_ln[0], x=core_grid_ln[1]),
                strategy=ttnn.ShardStrategy.WIDTH,
            )
            self.ln_prg_cfg = ttnn.LayerNormShardedMultiCoreProgramConfig(
                compute_with_storage_grid_size=(core_grid_ln[1], core_grid_ln[0]),
                subblock_w=(hidden_size_per_device_distributed_ln // num_cores_ln) // 32,
                block_h=1,
                block_w=(hidden_size_per_device_distributed_ln // num_cores_ln) // 32,
                inplace=False,
            )
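            # Single-core width-sharded config for the gathered per-device RMSNorm statistics
            # (one 32x32 tile per device along the 4-way fractured hidden dim).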
            self.ln_sharded_stats_memcfg = ttnn.create_sharded_memory_config(
                shape=[1, 1, 32, 32 * 4],
                core_grid=ttnn.CoreGrid(y=1, x=1),
                strategy=ttnn.ShardStrategy.WIDTH,
            )
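            # Compute-kernel settings passed to tt_distributed_rmsnorm
            # (HiFi2 math fidelity, fp32 accumulation disabled).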
            self.ln_cfg = ttnn.WormholeComputeKernelConfig(
                math_fidelity=ttnn.MathFidelity.HiFi2,
                math_approx_mode=False,
                fp32_dest_acc_en=False,
                packer_l1_acc=False,
            )

        self.TG = TG

    def forward(self, x, mode):
        """Apply a norm, possibly gathering inputs if required."""
        if self.TG:
            if mode == "decode":
                return tt_sharded_distributed_rmsnorm(
                    x,
                    epsilon=self.norm.eps,
                    gamma=self.norm.weight_distributed,
                    mesh_device=self.args.mesh_device,
                    ln_sharded_input_memcfg=self.gather_in_mem_cfg,
                    ln_sharded_progcfg=self.ln_prg_cfg,
                    ln_sharded_stats_memcfg=self.ln_sharded_stats_memcfg,
                )
            else:
                return tt_distributed_rmsnorm(
                    x,
                    epsilon=self.norm.eps,
                    gamma=self.norm.weight_distributed,
                    mesh_device=self.args.mesh_device,
                    compute_kernel_config=self.ln_cfg,
                )
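
        # Non-TG path: in decode mode the input follows the norm's sharded memory config,
        # otherwise it is kept interleaved in DRAM.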
        input_mem_cfg = self.norm.sharded_output_config if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG

        # Distributed norm already performs a gather
        if self.args.is_multichip and not self.args.is_distributed_norm(mode):
            x = ttnn.all_gather(x, dim=3, num_links=1, topology=self.args.ccl_topology(), memory_config=input_mem_cfg)
        else:
            x = ttnn.to_memory_config(x, input_mem_cfg)

        x = self.norm(x, mode=mode, in_sharded=(mode == "decode"), out_sharded=(mode == "decode"))

        # Distributed norm requires a gather
        if self.args.is_distributed_norm(mode):
            x = ttnn.all_gather(x, dim=3, num_links=1, topology=self.args.ccl_topology())

        return x
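

# Usage sketch (illustrative only; `rms_norm` and `model_args` are hypothetical stand-ins
# for the wrapped per-device norm module and the model-args object, and `model_args.is_galaxy`
# is an assumed flag identifying a TG mesh):
#
#     norm = DistributedNorm(rms_norm, model_args, TG=model_args.is_galaxy)
#     x = norm(x, mode="decode")  # or mode="prefill"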