# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

import ttnn
from models.common.lightweightmodule import LightweightModule
from models.demos.llama3.tt.llama_ccl import tt_all_reduce
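

# TT-NN implementation of the Llama feed-forward block (SwiGLU MLP: gate, up, and down
# projections), with the weights 2D-sharded across a mesh of Tenstorrent devices.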
class TtLlamaMLP(LightweightModule):
    def __init__(
        self, mesh_device, args, state_dict, weight_cache_path, layer_num, dtype, model_config, state_dict_prefix=None
    ):
        super().__init__()
        self.state_dict = state_dict
        self.mesh_device = mesh_device
        self.args = args
        self.dim = args.dim
        self.model_config = model_config

        state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num)
        torch_weight = lambda name: torch.transpose(self.state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1)

        if args.dummy_weights:
            cache_name = lambda _: None
        else:
            cache_name = lambda name: weight_cache_path / (state_dict_prefix + f".{name}")

        w1_w3_mem_config = args.create_dram_sharded_mem_config(args.dim, args.hidden_dim // args.num_devices)
        w2_mem_config = args.create_dram_sharded_mem_config(args.hidden_dim // args.num_devices, args.dim)
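
        # Helper that transposes a torch weight, tiles it, shards it 2D across the mesh
        # (split dims chosen per weight below), and caches the result on disk unless dummy weights are used.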
        # TODO: Clean up this code. With sharding, we load the normal weights and then shard them
        as_sharded_tensor = lambda name, type, dim: ttnn.as_tensor(
            torch_weight(name[:2]),  # Grab only the wX part of the name
            dtype=type,
            device=self.mesh_device,
            mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=dim, mesh_shape=args.cluster_shape),
            layout=ttnn.TILE_LAYOUT,
            memory_config=ttnn.DRAM_MEMORY_CONFIG
            if args.is_galaxy
            else w2_mem_config
            if "w2" in name
            else w1_w3_mem_config,
            cache_file_name=cache_name(name),
        )
        self.four_bit_mlp = args.optimizations.bfp4_mlp

        # Sharded weights
        w1_dim = (-1, -2) if args.is_galaxy else (-2, -1)
        w2_dim = (-2, -1) if args.is_galaxy else (-1, -2)

        self.w1 = as_sharded_tensor(
            "w1_sharded", ttnn.bfloat4_b if self.four_bit_mlp else ttnn.bfloat8_b, dim=w1_dim
        )  # bfp4 is normally fine here, but drops below 0.99 PCC with Llama 3.1 weights
        self.w2 = as_sharded_tensor("w2_sharded", ttnn.bfloat8_b, dim=w2_dim)
        self.w3 = as_sharded_tensor("w3_sharded", ttnn.bfloat4_b if self.four_bit_mlp else ttnn.bfloat8_b, dim=w1_dim)

    def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
        """
        w1 -> gate_proj
        w2 -> down_proj
        w3 -> up_proj
        HF reference: self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        """
        seq_len = x.shape[-2]
        TG = self.args.is_galaxy

        if mode == "decode":  # Sharded config
            if TG:  # TODO: Fix this when TG supports DRAM sharded matmuls
                pc_1 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None
                pc_2 = self.model_config["FF2_TG_PROGCFG"] if self.dim >= 4096 else None
                pc_3 = self.model_config["FF1_3_TG_PROGCFG"] if self.dim >= 4096 else None
            else:
                pc_1 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"]
                pc_2 = self.model_config["DECODE_MLP_W2_PRG_CONFIG"]
                pc_3 = self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"]
        else:  # Update the program configs for prefill
            if seq_len >= 1024:
                # Reshape input to fit on device and parallelize computation
                x = ttnn.reshape(x, [1, seq_len // 1024, 1024, -1])
            pc_1 = self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len)
            pc_2 = self.model_config["PREFILL_MLP_W2_PRG_CONFIG"](seq_len)
            pc_3 = self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"](seq_len)

        # In decode mode (seqlen <= 32) we do DRAM-sharded matmuls
        # These use HiFi2; this drops 1 bit of the activations but would be FLOP-bound on 12 cores with HiFi4
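        # FF1 (w1, gate) and FF3 (w3, up) projections are computed independently from the same input.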
        w1_out = ttnn.linear(
            x,
            self.w1,
            compute_kernel_config=self.args.compute_kernel_config_lofi
            if self.four_bit_mlp
            else self.args.compute_kernel_config_hifi2_fp16,
            core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_1 else None,
            dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16,
            program_config=pc_1,
            memory_config=x.memory_config(),
        )

        w3_out = ttnn.linear(
            x,
            self.w3,
            compute_kernel_config=self.args.compute_kernel_config_lofi
            if self.four_bit_mlp
            else self.args.compute_kernel_config_hifi2_fp16,
            core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_3 else None,
            dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16,
            program_config=pc_3,
            memory_config=x.memory_config(),
        )
        ttnn.deallocate(x)
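
        # On Galaxy (TG) the per-device FF1/FF3 outputs are partial results, so combine them across
        # cluster axis 1: reduce-scatter when dim == 8192 or in prefill, otherwise an all-reduce.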
        if TG:
            # if mode == "decode" and self.dim != 8192:
            #     w1_out = ttnn.to_memory_config(w1_out, ttnn.DRAM_MEMORY_CONFIG)
            #     w3_out = ttnn.to_memory_config(w3_out, ttnn.DRAM_MEMORY_CONFIG)
            if self.dim == 8192 or mode == "prefill":
                input_mem_cfg = w1_out.memory_config()
                w1_out = ttnn.reduce_scatter(
                    w1_out,
                    dim=3,
                    math_op=ttnn.ReduceType.Sum,
                    num_links=self.args.num_reduce_scatter_links,
                    cluster_axis=1,
                    mesh_device=self.mesh_device,
                    topology=ttnn.Topology.Linear,
                    memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None,
                )
                w3_out = ttnn.reduce_scatter(
                    w3_out,
                    dim=3,
                    math_op=ttnn.ReduceType.Sum,
                    num_links=1,
                    cluster_axis=1,
                    mesh_device=self.mesh_device,
                    topology=ttnn.Topology.Linear,
                    memory_config=self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] if mode == "decode" else None,
                )
            else:
                w1_out = tt_all_reduce(
                    w1_out,
                    self.mesh_device,
                    cluster_axis=1,
                    num_all_gather_links=2,
                    sharded=(mode == "decode"),
                    memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None,
                )
                w3_out = tt_all_reduce(
                    w3_out,
                    self.mesh_device,
                    cluster_axis=1,
                    num_all_gather_links=2,
                    sharded=(mode == "decode"),
                    memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None,
                )
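
        # SwiGLU: SiLU is fused into the elementwise multiply, giving silu(w1_out) * w3_out.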
        w2_in = ttnn.mul(
            w1_out,
            w3_out,
            input_tensor_a_activation=ttnn.UnaryOpType.SILU,
            dtype=ttnn.bfloat8_b,
            memory_config=w1_out.memory_config(),
        )

        if mode == "decode" and not TG:
            # w2 may use a different core grid; this is a no-op if they already match
            w2_in = ttnn.to_memory_config(w2_in, self.model_config["SHARDED_MLP2_INPUT_MEMCFG"])

        ttnn.deallocate(w3_out)
        ttnn.deallocate(w1_out)
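
        # On TG, when FF1/FF3 were reduce-scattered above, gather the hidden dimension back across
        # cluster axis 1 so each device holds the full FF2 input.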
        if TG and (self.dim == 8192 or mode == "prefill"):
            w2_in = ttnn.all_gather(
                w2_in,
                3,
                num_links=2,
                cluster_axis=1,
                mesh_device=self.mesh_device,
                topology=ttnn.Topology.Linear,
                memory_config=input_mem_cfg,
            )

        if mode == "decode":
            w2_in = ttnn.to_memory_config(w2_in, ttnn.L1_MEMORY_CONFIG)
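
        # FF2 (w2, down) projection back to the model dimension; with sharded weights each device
        # produces a partial output that is reduced below.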
        w2_out = ttnn.linear(
            w2_in,
            self.w2,
            compute_kernel_config=self.args.compute_kernel_config_hifi2_fp16,
            dtype=self.args.ccl_dtype if TG else ttnn.bfloat16,
            program_config=pc_2,
            memory_config=(ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG)
            if TG
            else w2_in.memory_config(),
            core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_2 else None,
        )
        ttnn.deallocate(w2_in)

        # if mode == "decode" and not TG:
        #     w2_out = ttnn.sharded_to_interleaved(w2_out, ttnn.DRAM_MEMORY_CONFIG)
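
        # Sum the partial FF2 outputs across devices so every device ends up with the complete MLP output.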
        w2_out_reduced = tt_all_reduce(
            w2_out,
            self.mesh_device,
            cluster_axis=0,
            dim=0 if (TG and self.dim < 8192) else 3,
            num_reduce_scatter_links=self.args.num_reduce_scatter_links,
            num_all_gather_links=self.args.num_all_gather_links,
            sharded=(mode == "decode"),
            memory_config=(self.model_config["FF2_OUT_REDUCE_SCATTER_MEMCFG"] if TG else w2_out.memory_config())
            if mode == "decode"
            else ttnn.DRAM_MEMORY_CONFIG,
            dtype=self.args.ccl_dtype,
            use_composite=(self.dim == 8192),
        )

        # Ensure dims 0 and 1 are 1
        original_shape = w2_out_reduced.shape
        w2_out_reduced = ttnn.reshape(
            w2_out_reduced, (1, 1, original_shape[-4] * original_shape[-3] * original_shape[-2], original_shape[-1])
        )

        if mode == "decode":
            w2_out_reduced = ttnn.to_memory_config(
                w2_out_reduced,
                self.model_config["SHARDED_ATTN_INPUT_MEMCFG"] if TG else self.model_config["DECODE_RESIDUAL_MEMCFG"],
            )

        # ttnn.deallocate(w2_out)
        return w2_out_reduced
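

# Minimal usage sketch (illustrative only; `mesh_device`, `model_args`, `state_dict`,
# `weight_cache_path`, `model_config`, and `x` are assumed to come from the surrounding
# Llama demo setup and are not defined in this file):
#
#   mlp = TtLlamaMLP(
#       mesh_device=mesh_device,
#       args=model_args,
#       state_dict=state_dict,
#       weight_cache_path=weight_cache_path,
#       layer_num=0,
#       dtype=ttnn.bfloat8_b,
#       model_config=model_config,
#       state_dict_prefix=None,
#   )
#   y = mlp.forward(x, mode="decode")  # x: ttnn.Tensor activations already placed on the mesh device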