# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
import math
import os
import json
import ttnn
from pathlib import Path
from loguru import logger
import torch
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer
from models.demos.llama3.tt.llama_common import (
precompute_freqs,
freqs_to_rotation_matrix,
num_to_core_range_set,
calculate_hidden_dim,
get_out_subblock_w,
)
from typing import Tuple
from models.utility_functions import nearest_32
from tqdm import tqdm
class TtModelArgs:
paged_attention_config = None
# TODO Update these params. In __init__ max_seq_len is reduced to 4K for 1- and 2-chip devices
max_batch_size = 1
# Context length for Llama models (reduced to 4K in __init__ for 1- and 2-chip devices)
max_seq_len = 8192 * 16 # 128k
kv_seq_len = 8192 * 16 # 128k
sliding_window = 8192 * 16 # 128k
tile_size = 32
OP_KEYS = (
# Embedding
"EMB_WEIGHTS",
# Feed forward
"MLP_WEIGHTS",
"FF1_OUTPUT",
"FF3_OUTPUT",
"FF2_OUTPUT",
"MLP_W_LAYOUT",
# Attention
"ATTN_WEIGHTS",
"XQKV_MM_OUTPUT",
"QKV_HEADS_OUTPUT",
"QV_ROT_EMB_OUTPUT",
"KV_UNPAD_OUTPUT",
"QK_MM_OUTPUT",
"QKV_MM_OUTPUT",
"CONCAT_HEADS_OUTPUT",
"ATTN_OUTPUT",
"ATTN_W_LAYOUT",
# Decoder
"DECODE_RESIDUAL",
"OUTPUT_MM",
)
LOCAL_LLAMA_PARAMS = {
"LLAMA3_2_1B_PARAMS": "models/demos/llama3/model_params/Llama3.2-1B-Instruct",
"LLAMA3_2_3B_PARAMS": "models/demos/llama3/model_params/Llama3.2-3B-Instruct",
"LLAMA3_1_8B_PARAMS": "models/demos/llama3/model_params/Llama3.1-8B-Instruct",
"LLAMA3_2_11B_PARAMS": "models/demos/llama3/model_params/Llama3.2-11B-Vision-Instruct",
"LLAMA3_1_70B_PARAMS": "models/demos/llama3/model_params/Llama3.1-70B-Instruct",
}
def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_size=1):
# Device / mesh configuration
self.num_devices = mesh_device.get_num_devices() if mesh_device else 0
self.mesh_device = mesh_device
self.device_name = {0: "CPU", 1: "N150", 2: "N300", 8: "T3K", 32: "TG"}[self.num_devices]
self.is_large_model = False
self.model_name = "Unknown" # Llama model name will be dependent on the checkpoint directory
LLAMA_DIR = os.getenv("LLAMA_DIR")
if LLAMA_DIR:
if any([os.getenv("LLAMA_CKPT_DIR"), os.getenv("LLAMA_TOKENIZER_PATH"), os.getenv("LLAMA_CACHE_PATH")]):
logger.warning(
"LLAMA_DIR is set and will override LLAMA_CKPT_DIR, LLAMA_TOKENIZER_PATH, and LLAMA_CACHE_PATH"
)
self.DEFAULT_CKPT_DIR = LLAMA_DIR
self.DEFAULT_TOKENIZER_PATH = LLAMA_DIR
self.DEFAULT_CACHE_PATH = os.path.join(LLAMA_DIR, self.device_name)
else:
assert "Please set $LLAMA_DIR to a valid checkpoint directory"
if not dummy_weights:
# Assert if all folders and files exist
assert os.path.exists(
self.DEFAULT_CKPT_DIR
), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please set LLAMA_DIR=... or LLAMA_CKPT_DIR=..."
assert os.path.isfile(
self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model"
), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please set LLAMA_TOKENIZER_PATH=..."
if not os.path.exists(self.DEFAULT_CACHE_PATH):
os.makedirs(self.DEFAULT_CACHE_PATH)
assert os.path.exists(
self.DEFAULT_CACHE_PATH
), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please set LLAMA_CACHE_PATH=..."
# Check if weights exist in the specified folder. If not warn the user to run the download and untar script.
# assert os.path.isfile(
# self.DEFAULT_CKPT_DIR + "/consolidated.00.pth"
# ), f"weights consolidated.00.pth file does not exist. Please use the script `models/demos/llama3/scripts/get_weights.py` to download and untar the weights."
logger.info(f"Checkpoint directory: {self.DEFAULT_CKPT_DIR}")
logger.info(f"Tokenizer file: {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'}")
logger.info(f"Cache directory: {self.DEFAULT_CACHE_PATH}")
# Set the model name based on the checkpoint directory being loaded
if "3.2-1B" in LLAMA_DIR:
local_params = "LLAMA3_2_1B_PARAMS"
self.model_name = "3.2-1B"
elif "3.2-3B" in LLAMA_DIR:
local_params = "LLAMA3_2_3B_PARAMS"
self.model_name = "3.2-3B"
elif "3.1-8B" in LLAMA_DIR:
local_params = "LLAMA3_1_8B_PARAMS"
self.model_name = "3.1-8B"
elif "3.2-11B" in LLAMA_DIR:
local_params = "LLAMA3_2_11B_PARAMS"
self.model_name = "3.2-11B"
elif "3.1-70B" in LLAMA_DIR:
local_params = "LLAMA3_1_70B_PARAMS"
self.model_name = "3.1-70B"
self.is_large_model = True
else:
raise ValueError(f"Unsupported LLAMA model: {LLAMA_DIR}")
# Load model params
if not dummy_weights:
self._set_llama_params(self.DEFAULT_CKPT_DIR)
else: # With Dummy weights, set the params from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders.
self._set_llama_params(self.LOCAL_LLAMA_PARAMS[local_params])
# Reduce full 128k context length for combinations with memory constraints
# Currently: n150 8b and t3k 70b with 8b/8b/8b MLPs
# Default folder location for weights and cached files
# FIXME: Setup the max cache size accordingly depending on the target model, architecture and test type.
if (
self.num_devices <= 2
): # for 1-chip or 2-chip devices limit the seqlen to 4K (to avoid OoM on N150/N300 CI tests)
self.max_seq_len = 1024 * 4
self.kv_seq_len = 1024 * 4
self.sliding_window = 1024 * 4
if (
self.n_layers == 1
): # When running a single layer just reduce the seq len to 128, since we won't be decoding that many iterations
self.max_seq_len = 128
self.kv_seq_len = 128
self.sliding_window = 128
# Some consumers like SentencePiece only accept str not Path for files
self.model_base_path = Path(self.DEFAULT_CKPT_DIR)
self.model_cache_path = Path(self.DEFAULT_CACHE_PATH)
# Load weights and tokenizer
self.consolidated_weights_path = self.DEFAULT_CKPT_DIR + "/consolidated.00.pth"
self.tokenizer_path = self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model"
self.instruct = instruct
# If the cache path contains the keyword `instruct`, also set self.instruct to True
if "instruct" in self.DEFAULT_CACHE_PATH.lower():
self.instruct = True
self.dummy_weights = dummy_weights
self.max_batch_size = max_batch_size
self.tile_padded_batch_rows = self.tile_size * int(math.ceil(self.max_batch_size / self.tile_size))
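# For example (assumed batch sizes, not from the original): max_batch_size=1 pads up to a single
# 32-row tile, and max_batch_size=32 is already tile-aligned, so both give tile_padded_batch_rows=32.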
# Enable workarounds by default until di/dt issues are fixed
self.di_dt_workaround = os.getenv("DISABLE_DI_DT_WORKAROUND") != "1"
if not self.di_dt_workaround:
logger.info("Disabling di/dt workaround, re-enable if you see hangs")
DRAM_MEMCFG = ttnn.DRAM_MEMORY_CONFIG
L1_MEMCFG = ttnn.L1_MEMORY_CONFIG
self.model_config = {}
# Update memory configs (weights->DRAM, activations->L1)
self.model_config.update(
{f"{key}_MEMCFG": DRAM_MEMCFG if "WEIGHTS" in key else L1_MEMCFG for key in self.OP_KEYS}
)
# Update memory layouts (Tile, except MLP)
self.model_config.update({f"{key}_TILE": ttnn.TILE_LAYOUT for key in self.OP_KEYS if "LAYOUT" in key})
self.cos, self.sin = precompute_freqs(
self.head_dim, self.max_seq_len * 2, self.rope_theta, self.use_scaled_rope
) # for prefill
self.rot_emb = freqs_to_rotation_matrix(self.cos, self.sin) # for decode
device = mesh_device.get_devices()[0] if mesh_device is not None else None
if device is not None: # Avoid issue with test_llama_torch.py not having a device
self.n_local_heads = self.n_heads // self.num_devices
grid = device.compute_with_storage_grid_size()
self.max_grid_size = ttnn.CoreGrid(x=grid.x, y=grid.y)
# DRAM weight grid specs for dram sharding matmuls
self.dram_weight_grid = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1),
)
}
)
# Compute kernels. FP32 acc does not appear to be needed for accuracy in model tests or demo runs.
self.compute_kernel_config_lofi = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.LoFi,
math_approx_mode=False,
fp32_dest_acc_en=False,
packer_l1_acc=True,
)
self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi2,
math_approx_mode=False,
fp32_dest_acc_en=False,
packer_l1_acc=True,
)
self.compute_kernel_config_hifi4 = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
math_approx_mode=False,
fp32_dest_acc_en=True,
packer_l1_acc=True,
)
self.compute_kernel_config_sdpa = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
math_approx_mode=False,
fp32_dest_acc_en=True,
packer_l1_acc=False,
)
self.model_config["COMPUTE_KERNEL_CONFIG_HIFI2"] = self.compute_kernel_config_hifi2
residual_grid = self.dram_shard_core_grid_for_k(self.dim // self.num_devices)
self.model_config["DECODE_RESIDUAL_MEMCFG"] = ttnn.create_sharded_memory_config(
(
self.tile_padded_batch_rows,
self.dim // residual_grid.num_cores // self.num_devices,
),
residual_grid,
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
# Chunk values based on what works best empirically
self.model_config["SDPA_PROGCFG"] = lambda seqlen: ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=(8, 8),
q_chunk_size=256 if seqlen >= 2048 else 64,
k_chunk_size=256 if seqlen >= 2048 else 64,
)
def find_largest_divisor(n, max_divisor=8):
for i in range(max_divisor, 0, -1):
if n % i == 0:
return i
return 1 # Fallback to 1 if no divisor found
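# Illustrative behaviour (assumed inputs, not from the original): find_largest_divisor(16) -> 8,
# find_largest_divisor(12) -> 6, and a prime such as 11 has no divisor <= 8 besides 1, so it returns 1.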
# nlp_concat_heads_decode will shard the data across this number of cores
assert (
self.n_heads % self.num_devices == 0
), f"n_heads must be divisible by num_devices: {self.n_heads} % {self.num_devices}"
self.model_config["ATTN_OUTPUT_PROGCFG"] = self.dram_matmul_config(
m=self.tile_padded_batch_rows,
k=self.dim // self.num_devices,
n=self.dim,
num_cores=self.n_heads // self.num_devices,
)
# All Gather Matmul for Dense Out (DO)
# TODO: Is there a better way to decide if fused all gather matmul should be used? And is there a better way to use the flag, instead of passing it into model_config?
# NOTE: Fused all gather matmul only supports a core grid of size num_devices x 1
self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] = (
self.ccl_topology() == ttnn.Topology.Ring
and (self.dim // self.tile_size // self.num_devices) % self.num_devices == 0
)
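# Illustrative check (assumed 70B on T3K: dim=8192, 8 devices, Ring topology):
# (8192 // 32 // 8) % 8 == 32 % 8 == 0, so the fused all-gather matmul path would be enabled.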
if self.model_config["USE_FUSED_ALL_GATHER_MATMUL"]:
do_core_grid_size = (8, 1)
do_per_core_N = (
self.dim // self.num_devices // self.tile_size // (do_core_grid_size[0] * do_core_grid_size[1])
)
self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"] = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=do_core_grid_size,
in0_block_w=self.dim
// self.tile_size
// (do_core_grid_size[0] * do_core_grid_size[1]), # [32 x 8k] x [8k x 1k] = [32 x 1k]
out_subblock_h=1,
out_subblock_w=get_out_subblock_w(
do_per_core_N, out_subblock_h=1
), # Max out_subblock_w = 4, needs to be divisible by per_core_N
per_core_M=self.tile_padded_batch_rows // self.tile_size,
per_core_N=do_per_core_N,
fuse_batch=True,
fused_activation=None,
mcast_in0=True,
)
else:
self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"] = None
self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"] = self.matmul_config(
m=1024, k=self.dim, n=self.hidden_dim // self.num_devices, grid_size=(8, 8)
)
self.model_config["PREFILL_MLP_W2_PRG_CONFIG"] = self.matmul_config(
m=1024, k=self.hidden_dim, n=self.dim, grid_size=(8, 8)
)
self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG_128"] = lambda seq_len: self.matmul_config(
m=seq_len, k=self.dim, n=self.hidden_dim // self.num_devices, grid_size=(8, 4)
)
self.model_config["PREFILL_MLP_W2_PRG_CONFIG_128"] = lambda seq_len: self.matmul_config(
m=seq_len, k=self.hidden_dim, n=self.dim, grid_size=(8, 4)
)
self.model_config["WO_PREFILL_PROGCFG"] = lambda seq_len: self.matmul_config(
m=min(seq_len, 2048),
k=self.dim,
n=self.dim,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= 2048,
)
# Calculate the largest lm_head_num_rows such that self.dim % (32 * lm_head_num_rows * 8) == 0
lm_head_num_rows = 8
while lm_head_num_rows > 0 and self.dim % (32 * lm_head_num_rows * 8) != 0:
lm_head_num_rows -= 1
assert (
lm_head_num_rows > 0
), f"Could not find an lm_head_num_rows such that self.dim(={self.dim}) % (32 * lm_head_num_rows * 8) == 0"
self.lm_head_core_grid = ttnn.CoreGrid(y=lm_head_num_rows, x=8)
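# Worked example (assumed dims, not from this file): dim=4096 satisfies 4096 % (32*8*8) == 0, so
# lm_head_num_rows stays 8 (an 8x8 grid); dim=3072 drops to 6 rows (3072 % (32*6*8) == 0), a 6x8 grid.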
self.model_config["LM_HEAD_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
(
self.tile_padded_batch_rows,
nearest_32(self.dim // self.lm_head_core_grid.num_cores),
), # Shard shape: [32, 128] -> 1 shard per core
self.lm_head_core_grid,
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.qkv_size = self.head_dim * (2 * self.n_kv_heads + self.n_heads)
self.model_config["XQKV_PREFILL_PROGCFG"] = lambda seq_len: ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=(8, 8),
in0_block_w=1, # FIXME: optimize this config for prefill, careful use DI_DT_WORKAROUND if necessary
out_subblock_h=1, # Must be divisible by per_core_M
out_subblock_w=1, # Must be divisible by per_core_N, out_subblock_w * out_subblock_h <= 4
per_core_M=max(
1, 8 if seq_len >= 2048 else seq_len // self.tile_size // 8 # 8 rows
), # M / TILE_HEIGHT / Grid_Size (dynamic based on seqlen)
per_core_N=math.ceil(self.qkv_size / self.num_devices / 32 / 8), # N / TILE_WIDTH / grid width
transpose_mcast=False,
fused_activation=None,
fuse_batch=seq_len <= 2048,
)
assert self.n_kv_heads % self.num_devices == 0, "n_kv_heads must be divisible by num_devices"
self.min_kv_prefill_shard_seqlen = (self.tile_size * 8 * 8) / (self.n_kv_heads // self.num_devices)
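# Example (assumed n_kv_heads=8 on 2 devices): (32*8*8) / (8 // 2) = 512, presumably the smallest
# seq_len at which each of the 8x8 cores receives at least one full 32-row tile of the KV shard below.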
self.model_config["KV_PREFILL_MEM_CFG"] = lambda seq_len: ttnn.create_sharded_memory_config(
(((self.n_kv_heads // self.num_devices) * seq_len // (8 * 8)), self.head_dim),
ttnn.CoreGrid(y=8, x=8),
ttnn.ShardStrategy.HEIGHT,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.model_config["SDPA_DECODE_PROGCFG"] = ttnn.SDPAProgramConfig(
compute_with_storage_grid_size=(8, 8),
q_chunk_size=32,
k_chunk_size=32,
)
self.model_config["SDPA_DECODE_COMPUTE_PROGCFG"] = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi2,
math_approx_mode=False,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)
# Useful core grid based on batch size
if self.max_batch_size == 32:
grid_by_batch = (8, 4)
elif self.max_batch_size == 16:
grid_by_batch = (8, 2)
elif self.max_batch_size == 8:
grid_by_batch = (8, 1)
elif self.max_batch_size == 4:
grid_by_batch = (4, 1)
elif self.max_batch_size == 2:
grid_by_batch = (2, 1)
elif self.max_batch_size == 1:
grid_by_batch = (1, 1)
else:
raise ValueError(f"Batch size {self.max_batch_size} not supported")
core_grid_by_batch = ttnn.CoreGrid(y=grid_by_batch[1], x=grid_by_batch[0])
core_range_set_by_batch = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(grid_by_batch[0] - 1, grid_by_batch[1] - 1),
),
}
)
self.model_config["SCORES_BATCHED_MM_OUTPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
shape=(math.ceil(self.n_local_heads / 32) * 32, self.head_dim), # self.n_heads padded to tile size
core_grid=core_grid_by_batch,
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.model_config["ROT_MAT_BMM_PROGCFG"] = lambda m, k, n: ttnn.MatmulMultiCoreReuseProgramConfig(
compute_with_storage_grid_size=grid_by_batch,
in0_block_w=math.ceil(k / 32),
out_subblock_h=1,
out_subblock_w=1, # TODO How to choose this subblock size?
per_core_M=math.ceil(m / 32),
per_core_N=math.ceil(n / 32),
)
self.model_config["ROT_MAT_MEMCONFIG"] = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
ttnn.BufferType.L1,
ttnn.ShardSpec(
core_range_set_by_batch,
[
128,
128,
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
),
)
# Width sharded
# mlp_core_grid = self.dram_shard_core_grid_for_k(self.dim)
mlp_core_grid = self.dram_shard_core_grid_for_k_and_n(self.dim, self.hidden_dim // self.num_devices)
self.model_config["SHARDED_MLP_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
(
self.tile_padded_batch_rows,
self.dim // mlp_core_grid.num_cores,
), # Shard shape: [32, 128] -> 1 shard per core
mlp_core_grid,
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.model_config["DECODE_MLP_W1_W3_PRG_CONFIG"] = self.dram_matmul_config(
m=self.tile_padded_batch_rows,
k=self.dim,
n=self.hidden_dim // self.num_devices,
num_cores=mlp_core_grid.num_cores,
)
# mlp2_core_grid = self.dram_shard_core_grid_for_k(self.hidden_dim // self.num_devices)
mlp2_core_grid = self.dram_shard_core_grid_for_k_and_n(self.hidden_dim // self.num_devices, self.dim)
self.model_config["SHARDED_MLP2_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
(
self.tile_padded_batch_rows,
self.hidden_dim // self.num_devices // mlp2_core_grid.num_cores,
),
mlp2_core_grid,
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.model_config["DECODE_MLP_W2_PRG_CONFIG"] = self.dram_matmul_config(
m=self.tile_padded_batch_rows,
k=self.hidden_dim // self.num_devices,
n=self.dim,
num_cores=mlp2_core_grid.num_cores,
)
attn_input_grid = self.dram_shard_core_grid_for_k(self.dim)
self.model_config["SHARDED_ATTN_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config(
(
self.tile_padded_batch_rows,
self.dim // attn_input_grid.num_cores,
),
attn_input_grid,
ttnn.ShardStrategy.WIDTH,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.model_config["XQKV_DECODE_PROGCFG"] = self.dram_matmul_config(
m=self.tile_padded_batch_rows,
k=self.dim,
n=self.qkv_size // self.num_devices,
num_cores=attn_input_grid.num_cores,
)
# Vision model configs
self.model_config["IMAGE_MLP_FC_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config(
m=min(seq_len, max_seq),
k=self.vision_dim,
n=self.vision_hidden_dim // self.num_devices,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= max_seq,
)
self.model_config["IMAGE_MLP_PROJ_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config(
m=min(seq_len, max_seq),
k=self.vision_hidden_dim // self.num_devices,
n=self.vision_dim,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= max_seq,
)
self.model_config["IMAGE_ATTN_QKV_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config(
m=min(seq_len, max_seq),
k=self.vision_dim,
n=(nearest_32(self.vision_head_dim) * self.vision_attn_n_heads * 3)
// self.num_devices, # Head dim was padded to nearest 32
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= max_seq,
)
self.model_config["IMAGE_ATTN_OUT_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config(
m=min(seq_len, max_seq),
k=(nearest_32(self.vision_head_dim) * self.vision_attn_n_heads * 3) // self.num_devices,
n=self.vision_dim,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= max_seq,
)
self.model_config["VISION_XATTN_Q_PROGCFG"] = lambda seq_len: self.matmul_config(
m=min(seq_len, 1024),
k=self.dim,
n=(self.head_dim * self.n_heads) // self.num_devices,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= 1024,
)
self.model_config["VISION_XATTN_KV_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config(
m=min(seq_len, max_seq),
k=self.dim,
n=(self.head_dim * self.n_kv_heads) // self.num_devices,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= max_seq,
)
self.model_config["VISION_XATTN_SCORE_PROGCFG"] = lambda seq_len, cache_seq_len: self.matmul_config(
m=seq_len,
k=self.head_dim,
n=cache_seq_len,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=False,
)
self.model_config["VISION_XATTN_OUTPUT_PROGCFG"] = lambda seq_len, cache_seq_len: self.matmul_config(
m=seq_len,
k=cache_seq_len,
n=self.head_dim,
grid_size=(8, 8),
# in0_block_w=1, # TODO: Remove this when we get non-causal FlashDecode
fuse_batch=False,
)
self.model_config["VISION_XATTN_DENSE_PROGCFG"] = lambda seq_len: self.matmul_config(
m=seq_len,
k=self.dim // self.num_devices,
n=self.dim,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=False,
)
self.model_config["VISION_PROJ_PROGCFG"] = lambda seq_len: self.matmul_config(
m=seq_len,
k=self.vision_dim * 6,
n=self.dim // self.num_devices,
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=False,
)
self.model_config["CROSS_TRANSFORMER_TEXT_OUTPUT_PROGCFG"] = lambda seq_len, max_seq: self.matmul_config(
m=min(seq_len, max_seq),
k=self.dim,
n=self.vocab_size // 8, # Magic number. LM Head always contains 8 splits
grid_size=(8, 8),
in0_block_w=1,
fuse_batch=seq_len <= max_seq,
)
xattn_cache_y_cores = (
16 // self.num_devices
) # Based on seqlen, this formula gives us a valid number of y cores
xattn_cache_x_cores = 8
self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = lambda seq_len: ttnn.create_sharded_memory_config(
# using n_heads since xattn repeats KV to match Q
(
nearest_32(
(self.n_heads // self.num_devices) * seq_len // (xattn_cache_y_cores * xattn_cache_x_cores)
),
self.head_dim,
),
ttnn.CoreGrid(y=xattn_cache_y_cores, x=xattn_cache_x_cores),
ttnn.ShardStrategy.HEIGHT,
ttnn.ShardOrientation.ROW_MAJOR,
use_height_and_width_as_shard_shape=True,
)
self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok)
# RMS NORM
self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"] = self.create_sharded_norm_config(attn_input_grid)
self.model_config["SHARDED_NORM_MLP_PRGM_CFG"] = self.create_sharded_norm_config(mlp_core_grid)
self.model_config["SHARDED_NORM_LM_HEAD_PRGM_CFG"] = self.create_sharded_norm_config(self.lm_head_core_grid)
# All gather matmuls currently only supported on T3K
# We need it sharded on num_cores = num_devices
self.model_config["ATTN_ALL_GATHER_MATMUL_OUTPUT_MEMCFG"] = ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.WIDTH_SHARDED,
ttnn.BufferType.L1,
ttnn.ShardSpec(
num_to_core_range_set(self.num_devices),
[
self.tile_padded_batch_rows,
self.dim // self.num_devices,
],
ttnn.ShardOrientation.ROW_MAJOR,
False,
),
)
self.is_2d_fracturing = all([dim > 1 for dim in self.mesh_device.shape]) if self.mesh_device else False
self.is_multichip = self.num_devices > 1
def is_distributed_norm(self, mode):
if not self.is_multichip:
return False
if all([dim > 1 for dim in self.mesh_device.shape]): # 2D grid
return True
elif self.dim >= 8192 and mode == "prefill": # Somewhere between 4k and 8k WH runs out of L1 if not distributed
return True
return False
def ccl_topology(self):
if self.num_devices == 8: # T3K
return ttnn.Topology.Ring
elif self.num_devices > 1: # All other multi chip devices
return ttnn.Topology.Linear
return None
def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False, on_host=False):
"""
Prepare inputs for decode mode.
x: (batch, seq, dim)
"""
mesh_mapper = (
ttnn.ReplicateTensorToMesh(self.mesh_device)
if force_replicated
else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1)
)
if len(x.shape) == 3:
batch = x.shape[0]
seq_len = x.shape[1]
assert x.shape[2] == self.dim
elif len(x.shape) == 4:
seq_len = x.shape[0]
assert x.shape[1] == 1
batch = x.shape[2]
assert x.shape[3] == self.dim
assert seq_len == 1, "Only supporting decode mode"
# Support input on device
if torch.is_tensor(x): # Input on host -> Use torch
x = x.transpose(0, 1).unsqueeze(1) # [seq_len, 1, batch, dim]
# Pad small batches to 32
if batch < 32:
zeros = torch.zeros(1, seq_len, 32, self.dim)
zeros[:, :, :batch, :] = x
x = zeros
elif len(x.shape) == 3: # Input on device -> Use ttnn
x = ttnn.reshape(x, (batch, seq_len, 1, self.dim)) # [batch, seqlen, dim] -> [batch, seqlen, 1, dim]
x = ttnn.permute(x, (1, 2, 0, 3)) # [seq_len, 1, batch, dim]
elif len(x.shape) == 4:
pass # already in [seq_len, 1, batch, dim]
if torch.is_tensor(x):
x = ttnn.from_torch(
x,
device=self.mesh_device if not on_host else None,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=mesh_mapper,
memory_config=input_mem_cfg if not on_host else None,
)
else: # Convert the row major layout from embedding back to tile layout
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT)
return x
def prepare_inputs_ttnn_prefill(self, x_bsh, force_replicated=False):
"""
Prepare inputs for prefill mode.
x: (batch, seq, hidden_dim)
B: batch (1)
S: sequence len
H: dim
"""
x_1BSH = x_bsh.unsqueeze(0)
mesh_mapper = (
ttnn.ReplicateTensorToMesh(self.mesh_device)
if force_replicated
else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1)
)
# input goes to DRAM
xs_1BSH = ttnn.from_torch(
x_1BSH,
device=self.mesh_device,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
mesh_mapper=mesh_mapper,
)
return xs_1BSH
def _set_llama_params_from_dict(self, params):
# Text params
self.dim = params["dim"]
self.ffn_dim_multiplier = params["ffn_dim_multiplier"]
self.multiple_of = params["multiple_of"]
self.n_heads = params["n_heads"]
self.n_kv_heads = params["n_kv_heads"]
self.n_layers = params["n_layers"]
self.norm_eps = params["norm_eps"]
self.rope_theta = params["rope_theta"]
self.use_scaled_rope = params["use_scaled_rope"]
self.vocab_size = params["vocab_size"]
self.head_dim = self.dim // self.n_heads
self.hidden_dim = calculate_hidden_dim(self.dim, self.ffn_dim_multiplier, self.multiple_of)
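# Worked example (params taken from the published Llama 3.1 8B config, not from this file):
# dim=4096, n_heads=32, ffn_dim_multiplier=1.3, multiple_of=1024 -> head_dim=128, hidden_dim=14336.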
# Vision params
self.vision_chunk_size = params.get("vision_chunk_size", -1)
self.vision_max_num_chunks = params.get("vision_max_num_chunks", 4)
self.vision_num_cross_attention_layers = params.get("vision_num_cross_attention_layers", -1)
# Vision constants
self.vision_dim = 1280
self.vision_mlp_ratio = 4
self.vision_hidden_dim = int(self.vision_dim * self.vision_mlp_ratio)
self.vision_act_layer = ttnn.UnaryOpType.GELU
self.vision_dropout = 0.0
self.vision_attn_n_heads = 16
self.vision_head_dim = self.vision_dim // self.vision_attn_n_heads
self.vision_n_layers = 32
self.vision_n_global_layers = 8
self.vision_max_num_tiles = 4
self.vision_patch_size = 14
self.vision_in_channels = 3
@property
def vision_chunk_ntok(self):
"""
Returns the number of tokens per chunk, accounting for the extra class token
"""
return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1
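# Example (assuming the usual vision_chunk_size=560 with patch size 14):
# (560 // 14) ** 2 + 1 = 40 ** 2 + 1 = 1601 tokens per chunk.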
def _set_llama_params(self, checkpoint_dir):
params_file = os.path.join(checkpoint_dir, "params.json")
assert os.path.exists(params_file), f"params.json file not found at {params_file}"
with open(params_file, "r") as f:
params = json.load(f)
self._set_llama_params_from_dict(params)
def __repr__(self):
return f"""ModelArgs(
dim={self.dim},
n_layers={self.n_layers},
n_heads={self.n_heads},
n_kv_heads={self.n_kv_heads},
vocab_size={self.vocab_size},
multiple_of={self.multiple_of},
ffn_dim_multiplier={self.ffn_dim_multiplier},
norm_eps={self.norm_eps},
rope_theta={self.rope_theta},
use_scaled_rope={self.use_scaled_rope},
max_batch_size={self.max_batch_size},
max_seq_len={self.max_seq_len},
vision_chunk_size={self.vision_chunk_size},
vision_max_num_chunks={self.vision_max_num_chunks},
vision_num_cross_attention_layers={self.vision_num_cross_attention_layers}
)"""
def is_vision(self):
return self.vision_chunk_size > 0
def get_state_dict_prefix(self, module_name, layer_num):
text_prefix = "text_model." if self.is_vision() else ""
layer_prefix = f"layers.{layer_num}." if layer_num is not None else ""
module_map = {
"TtLlamaMLP": "feed_forward",
"TtLlamaAttention": "attention",
"TtTransformerBlock": "",
"": "", # If no module is given, just get layer prefix
}
return text_prefix + layer_prefix + module_map[module_name]
def weight_cache_path(self, dtype):
# Keep the weight cache separate for generative and instruct weights
if self.instruct:
return (
self.model_cache_path
/ {ttnn.bfloat16: "tensor_cache_instruct_bf16", ttnn.bfloat8_b: "tensor_cache_instruct_bfp8"}[dtype]
)
else:
return (
self.model_cache_path / {ttnn.bfloat16: "tensor_cache_bf16", ttnn.bfloat8_b: "tensor_cache_bfp8"}[dtype]
)
def get_model_config(self):
return self.model_config
# TODO Update function for large models: For 1 layer tests we only want to load 1 checkpoint file, instead of all.
def load_state_dict(self):
"""Generate or load state_dict for n_layers of the model"""
if self.dummy_weights:
reference_model = Transformer(self)
state_dict = reference_model.state_dict()
state_dict_prefix = self.get_state_dict_prefix("", None)
state_dict = {f"{state_dict_prefix}{k}": torch.randn_like(v) for k, v in state_dict.items()}
else:
state_dict = load_llama_state_dict(self.DEFAULT_CKPT_DIR, self.n_layers)
keys_dict = list(state_dict.keys())[:]
remv = [
f"layers.{i}." for i in list(range(self.n_layers, 32))
] # TODO, this is not generalized to all models. it assumes max layers = 32
for k in keys_dict:
if any([r in k for r in remv]):
state_dict.pop(k)
return state_dict
def create_dram_sharded_mem_config(self, k, n):
"""Create DRAM-sharded memory config for width-sharded tensors"""
dram_cores = 12
padded_size = math.ceil(n / (self.tile_size * dram_cores)) * (self.tile_size * dram_cores)
shard_spec = ttnn.ShardSpec(
self.dram_weight_grid, (k, padded_size // dram_cores), ttnn.ShardOrientation.ROW_MAJOR, False
)
return ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.DRAM, shard_spec)
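# Worked example (assumed n=4096 with the 12 DRAM cores above):
# padded_size = ceil(4096 / 384) * 384 = 4224, so each DRAM core holds a (k, 352)-wide shard.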
def matmul_config(
self,
m: int,
k: int,
n: int,
grid_size: Tuple[int, int],
in0_block_w: int = None,
fuse_batch: bool = False,
fused_activation=None,
) -> ttnn.MatmulMultiCoreReuseMultiCastProgramConfig:
per_core_M = math.ceil(m / (self.tile_size * grid_size[1]))
per_core_N = math.ceil(n / (self.tile_size * grid_size[0]))
out_subblock_h = 1
out_subblock_w = get_out_subblock_w(per_core_N, out_subblock_h)
if in0_block_w is None:
in0_block_w = min(4, max(1, k // (self.tile_size * grid_size[0])))
return ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
compute_with_storage_grid_size=grid_size,
in0_block_w=in0_block_w,
out_subblock_h=out_subblock_h,
out_subblock_w=out_subblock_w,
per_core_M=per_core_M,
per_core_N=per_core_N,
transpose_mcast=False,
fused_activation=fused_activation,
fuse_batch=fuse_batch,
)
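# Worked example (assumed m=1024, k=4096, n=1792 on an (8, 8) grid):
# per_core_M = ceil(1024 / 256) = 4, per_core_N = ceil(1792 / 256) = 7,
# and with in0_block_w unset it defaults to min(4, max(1, 4096 // 256)) = 4.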
def dram_shard_core_grid_for_k(self, k: int) -> Tuple[int, int]:
rows, cols = self.find_grid(k // self.tile_size)
return ttnn.CoreGrid(x=cols, y=rows)
def find_grid(self, N):
"""
Find the number of rows and columns for a grid of cores such that
the total number of tiles N can be evenly divided among the cores.
Each core will have the same integer number of tiles.
The grid size is limited to a maximum of 8 rows and 8 columns.
Parameters:
N (int): Total number of tiles to be distributed.
Returns:
tuple: A tuple (rows, cols) representing the grid dimensions.
Raises:
AssertionError: If it's not possible to find such a grid configuration.
"""
max_rows = 8
max_cols = 8
max_cores = max_rows * max_cols
# Find all possible numbers of cores that divide N and are less than or equal to max_cores
target = 32
possible_cores = [k for k in range(1, max_cores + 1) if N % k == 0]
possible_cores.sort(key=lambda x: abs(x - target)) # Sort by closest to target
for cores in possible_cores:
# Try to find a grid configuration with the current number of cores
for rows in range(1, max_rows + 1):
if cores % rows == 0:
cols = cores // rows
if cols <= max_cols:
return rows, cols
# If no configuration is found, assert an error
raise AssertionError(
f"Cannot find a grid configuration for {N} tiles that evenly divides into {max_cores} cores of max size {max_rows}x{max_cols}."
)
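# Worked example (assumed dim=4096, i.e. N = 128 tiles): the divisors of 128 that fit in 64 cores
# are sorted by closeness to the target of 32, so 32 cores is tried first and the first
# (rows, cols) pair that fits the 8x8 grid is (4, 8).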
def dram_shard_core_grid_for_k_and_n(self, k: int, n: int) -> Tuple[int, int]:
rows, cols = self.find_grid_k_n(k // self.tile_size, n // self.tile_size)
return ttnn.CoreGrid(x=cols, y=rows)
def find_grid_k_n(self, K, N):
"""
Find the number of rows and columns for a grid of cores such that
both the K tiles and the N tiles can be evenly divided among the cores.
Each core will have the same integer number of tiles.
The grid size is limited to a maximum of 4 rows and 8 columns.
Parameters:
K (int): Number of K tiles to be distributed.
N (int): Number of N tiles to be distributed.
Returns:
tuple: A tuple (rows, cols) representing the grid dimensions.
Raises:
AssertionError: If it's not possible to find such a grid configuration.
"""
max_rows = 4
max_cols = 8 # Maximum number of columns
max_cores = max_rows * max_cols # Maximum number of cores (4x8 grid)
# Find all core counts that divide both K and N and are less than or equal to max_cores
possible_cores = [c for c in range(1, max_cores + 1) if K % c == 0 and N % c == 0]
possible_cores.sort(reverse=True) # Start checking from the largest number of cores
for cores in possible_cores:
# Try to find a grid configuration with the current number of cores
for rows in range(1, max_rows + 1):
if cores % rows == 0:
cols = cores // rows
if cols <= max_cols:
return rows, cols
# If no configuration is found, assert an error
raise AssertionError(
f"Cannot find a grid configuration such that both {K} and {N} tiles evenly divide into cores of max size {max_rows}x{max_cols}."
)
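# Worked example (assumed K=128, N=56 tiles, e.g. dim=4096 and hidden_dim // num_devices = 1792):
# the largest core count dividing both is 8, and the first fitting (rows, cols) pair is (1, 8).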
def dram_matmul_config(
self, m: int, k: int, n: int, num_cores=None
) -> ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig:
# in0_block_w must evenly divide k and be no larger than tile_size * num_cores
if num_cores is None:
# num_cores = self.dram_shard_core_grid_for_k_and_n(k).num_cores
num_cores = self.dram_shard_core_grid_for_k_and_n(k, n).num_cores
assert (
k % (self.tile_size * num_cores) == 0
), f"k must be divisible by tile_size * num_cores: {k} % {self.tile_size * num_cores} != 0"
# assert n % (self.tile_size * num_cores) == 0, f"n must be divisible by tile_size * num_cores: {n} % {self.tile_size * num_cores} != 0"
return ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig(
in0_block_w=math.ceil(k / (self.tile_size * num_cores)),
per_core_M=math.ceil(m / self.tile_size),
per_core_N=math.ceil(n / (self.tile_size * num_cores)),
fused_activation=None,
)
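# Worked example (assumed m=32, k=4096, n=512 on 8 cores): k % (32 * 8) == 0 passes,
# in0_block_w = ceil(4096 / 256) = 16, per_core_M = 1, per_core_N = ceil(512 / 256) = 2.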
def create_sharded_norm_config(self, grid):
"""Helper function to create LayerNormShardedMultiCoreProgramConfig for RMS NORM.
Args:
grid (ttnn.CoreGrid): Grid specification for the norm operation
"""
block_w = self.dim // grid.num_cores // self.tile_size
# Find largest value <= 4 that evenly divides block_w
subblock_w = 4
while subblock_w > 0:
if block_w % subblock_w == 0:
break
subblock_w -= 1
return ttnn.LayerNormShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=[grid.x, grid.y],
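# NOTE: the source listing is truncated at this point; the remaining kwargs below are a
# best-effort completion using the block_w / subblock_w values computed above.
subblock_w=subblock_w,
block_h=self.tile_padded_batch_rows // self.tile_size,
block_w=block_w,
inplace=False,
)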