-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathgptAttentionPlugin.cpp
1396 lines (1276 loc) · 67.1 KB
/
gptAttentionPlugin.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gptAttentionPlugin.h"
#include "tensorrt_llm/batch_manager/contextProgress.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
#include "tensorrt_llm/kernels/gptKernels.h"
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
#include "tensorrt_llm/plugins/common/checkMacrosPlugin.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h"
#include "tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommonImpl.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include <NvInferRuntimeBase.h>
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
using namespace nvinfer1;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::common;
using tensorrt_llm::plugins::GPTAttentionPluginCreator;
using tensorrt_llm::plugins::GPTAttentionPlugin;
static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"};
static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
/// Constructs the plugin from an explicit attention configuration.
/// Every argument is forwarded unchanged, and in the same order, to GPTAttentionPluginCommon.
/// Afterwards initEntryIdx() builds the table that maps each IdxEntry to its position among the
/// plugin inputs. That table depends on the configuration flags read by isEntryUsed(), which are
/// presumably initialised by the base constructor; confirm in gptAttentionCommon.
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length,
int num_kv_heads, int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling,
float attn_logit_softcapping_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
float rotary_embedding_scale, float rotary_embedding_short_m_scale,
float rotary_embedding_long_m_scale, // magnitude scaling factors for Phi-3 long RoPE
int rotary_embedding_max_positions, int rotary_embedding_original_max_positions, int tp_size,
int tp_rank, // for ALiBi
bool unfuse_qkv_gemm, // for AutoPP
bool use_logn_scaling, // for LognScaling
tensorrt_llm::kernels::ContextFMHAType context_fmha_type, int kv_cache_quant_mode, bool remove_input_padding,
tensorrt_llm::kernels::AttentionMaskType mask_type, tensorrt_llm::kernels::BlockSparseParams block_sparse_params,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_fp8_context_fmha, bool has_full_attention_mask, bool use_cache,
bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable,
int spec_decoding_max_generation_length, bool is_mla_enabled, int q_lora_rank, int kv_lora_rank,
int qk_nope_head_dim, int qk_rope_head_dim, int v_head_dim, bool skip_attn, int cp_size, int cp_rank,
std::set<int32_t> cp_group)
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, layer_idx_in_cache_pool,
head_size, unidirectional, q_scaling, attn_logit_softcapping_scale, position_embedding_type,
rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale,
rotary_embedding_short_m_scale, rotary_embedding_long_m_scale, rotary_embedding_max_positions,
rotary_embedding_original_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, use_logn_scaling, context_fmha_type,
kv_cache_quant_mode, remove_input_padding, mask_type, block_sparse_params, paged_kv_cache, tokens_per_block,
type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, has_full_attention_mask, use_cache,
is_spec_decoding_enabled, spec_decoding_is_generation_length_variable, spec_decoding_max_generation_length,
is_mla_enabled, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, skip_attn, cp_size,
cp_rank, cp_group)
{
// Must run after the base is built: the input layout depends on the configuration flags.
initEntryIdx();
}
/// Constructs the plugin from a serialized buffer of `length` bytes at `data`. This is presumably the
/// buffer written by the plugin's own serialization; parsing is done entirely by
/// GPTAttentionPluginCommon(data, length). The IdxEntry -> input-position table is then rebuilt from
/// the restored configuration.
GPTAttentionPlugin::GPTAttentionPlugin(void const* data, size_t length)
: GPTAttentionPluginCommon(data, length)
{
// The entry table is not serialized here; recompute it from the restored configuration flags.
initEntryIdx();
}
/// Returns the enumerator name of `entry`, used in diagnostics such as getIdx()'s failure message.
/// Every IdxEntry enumerator has a case. A value outside the enum logs a trace message and yields "".
std::string GPTAttentionPlugin::toString(IdxEntry const& entry) const
{
    switch (entry)
    {
    case IdxEntry::QKV_TENSOR: return "QKV_TENSOR";
    case IdxEntry::K_TENSOR: return "K_TENSOR";
    case IdxEntry::V_TENSOR: return "V_TENSOR";
    case IdxEntry::ATTENTION_MASK: return "ATTENTION_MASK";
    case IdxEntry::ATTENTION_PACKED_MASK: return "ATTENTION_PACKED_MASK";
    case IdxEntry::SEQUENCE_LENGTH: return "SEQUENCE_LENGTH";
    case IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS: return "HOST_PAST_KEY_VALUE_LENGTHS";
    case IdxEntry::HOST_MAX_ATTENTION_WINDOW: return "HOST_MAX_ATTENTION_WINDOW";
    case IdxEntry::HOST_SINK_TOKEN_LENGTH: return "HOST_SINK_TOKEN_LENGTH";
    case IdxEntry::CONTEXT_LENGTHS: return "CONTEXT_LENGTHS";
    case IdxEntry::CACHE_INDIR: return "CACHE_INDIR";
    case IdxEntry::REQUEST_TYPES: return "REQUEST_TYPES";
    case IdxEntry::KV_CACHE_BLOCK_OFFSETS: return "KV_CACHE_BLOCK_OFFSETS";
    case IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS: return "HOST_KV_CACHE_BLOCK_OFFSETS";
    case IdxEntry::HOST_KV_CACHE_POOL_POINTERS: return "HOST_KV_CACHE_POOL_POINTERS";
    case IdxEntry::HOST_KV_CACHE_POOL_MAPPING: return "HOST_KV_CACHE_POOL_MAPPING";
    case IdxEntry::PAST_KEY_VALUE: return "PAST_KEY_VALUE";
    case IdxEntry::KV_CACHE_QUANTIZATION_SCALE: return "KV_CACHE_QUANTIZATION_SCALE";
    case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return "KV_CACHE_DEQUANTIZATION_SCALE";
    case IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE: return "ATTENTION_OUTPUT_QUANTIZATION_SCALE";
    case IdxEntry::ROTARY_INV_FREQ: return "ROTARY_INV_FREQ";
    case IdxEntry::ROTARY_COS_SIN: return "ROTARY_COS_SIN";
    case IdxEntry::ALIBI_SLOPES: return "ALIBI_SLOPES";
    case IdxEntry::RELATIVE_ATTENTION_BIAS: return "RELATIVE_ATTENTION_BIAS";
    case IdxEntry::CROSS_KV: return "CROSS_KV";
    case IdxEntry::CROSS_KV_LENGTH: return "CROSS_KV_LENGTH";
    // The following entries are consulted by isEntryUsed() but previously had no name here.
    case IdxEntry::LOGN_SCALING: return "LOGN_SCALING";
    case IdxEntry::ENCODER_INPUT_LENGTH: return "ENCODER_INPUT_LENGTH";
    case IdxEntry::HOST_CONTEXT_LENGTH: return "HOST_CONTEXT_LENGTH";
    case IdxEntry::QKV_BIAS_TENSOR: return "QKV_BIAS_TENSOR";
    case IdxEntry::SPEC_DECODING_GENERATION_LENGTHS: return "SPEC_DECODING_GENERATION_LENGTHS";
    case IdxEntry::SPEC_DECODING_PACKED_MASK: return "SPEC_DECODING_PACKED_MASK";
    case IdxEntry::SPEC_DECODING_POSITION_OFFSETS: return "SPEC_DECODING_POSITION_OFFSETS";
    case IdxEntry::SPEC_DECODING_USE: return "SPEC_DECODING_USE";
    case IdxEntry::LONG_ROPE_ROTARY_INV_FREQ: return "LONG_ROPE_ROTARY_INV_FREQ";
    case IdxEntry::LONG_ROPE_ROTARY_COS_SIN: return "LONG_ROPE_ROTARY_COS_SIN";
    case IdxEntry::MROPE_ROTARY_COS_SIN: return "MROPE_ROTARY_COS_SIN";
    case IdxEntry::MROPE_POSITION_DELTAS: return "MROPE_POSITION_DELTAS";
    case IdxEntry::HOST_RUNTIME_PERF_KNOBS: return "HOST_RUNTIME_PERF_KNOBS";
    case IdxEntry::HOST_CONTEXT_PROGRESS: return "HOST_CONTEXT_PROGRESS";
    case IdxEntry::MLA_FUSED_Q_PROJ_TENSOR: return "MLA_FUSED_Q_PROJ_TENSOR";
    case IdxEntry::MLA_Q_B_PROJ_TENSOR: return "MLA_Q_B_PROJ_TENSOR";
    case IdxEntry::MLA_KV_B_PROJ_TENSOR: return "MLA_KV_B_PROJ_TENSOR";
    case IdxEntry::SKIP_ATTN: return "SKIP_ATTN";
    case IdxEntry::ENUM_SIZE: return "ENUM_SIZE";
    }
    // Only reachable for a value outside the enum (e.g. a bad cast).
    TLLM_LOG_TRACE(common::fmtstr("Missing string description for IdxEntry enum %lu.\n", static_cast<size_t>(entry)));
    return "";
}
/// Reports whether `entry` is an actual input of the plugin under the current configuration.
/// initEntryIdx() skips entries for which this is false when numbering inputs. getIdx() refuses to
/// return an index for them.
bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const
{
    switch (entry)
    {
    // Inputs present in every configuration.
    case IdxEntry::QKV_TENSOR:
    case IdxEntry::HOST_MAX_ATTENTION_WINDOW:
    case IdxEntry::HOST_SINK_TOKEN_LENGTH:
    case IdxEntry::CONTEXT_LENGTHS:
    case IdxEntry::REQUEST_TYPES:
    case IdxEntry::HOST_RUNTIME_PERF_KNOBS:
    case IdxEntry::HOST_CONTEXT_PROGRESS: return true;

    // Separate K and V inputs exist only when the QKV GEMM is unfused.
    case IdxEntry::K_TENSOR:
    case IdxEntry::V_TENSOR: return mUnfuseQkvGemm;

    case IdxEntry::ATTENTION_MASK: return useFullCustomMask();
    case IdxEntry::ATTENTION_PACKED_MASK: return useCustomMask();

    // KV-cache bookkeeping.
    case IdxEntry::SEQUENCE_LENGTH:
    case IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS:
    case IdxEntry::CACHE_INDIR: return useKVCache();

    // Paged cache inputs vs. the contiguous past_key_value tensor.
    case IdxEntry::KV_CACHE_BLOCK_OFFSETS:
    case IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS:
    case IdxEntry::HOST_KV_CACHE_POOL_POINTERS:
    case IdxEntry::HOST_KV_CACHE_POOL_MAPPING: return useKVCache() && mPagedKVCache;
    case IdxEntry::PAST_KEY_VALUE: return useKVCache() && !mPagedKVCache;

    // Quantization scales.
    case IdxEntry::KV_CACHE_QUANTIZATION_SCALE:
    case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
    case IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE: return mFP8ContextFMHA && mKVCacheQuantMode.hasFp8Qdq();

    // Position-embedding inputs.
    case IdxEntry::ROTARY_INV_FREQ: return isRoPE();
    case IdxEntry::ROTARY_COS_SIN: return isRoPE() || mIsMLAEnabled;
    case IdxEntry::LONG_ROPE_ROTARY_INV_FREQ:
    case IdxEntry::LONG_ROPE_ROTARY_COS_SIN: return isLongRoPE();
    case IdxEntry::MROPE_ROTARY_COS_SIN:
    case IdxEntry::MROPE_POSITION_DELTAS: return isMRoPE();
    case IdxEntry::ALIBI_SLOPES: return isALiBi();
    case IdxEntry::RELATIVE_ATTENTION_BIAS: return isRelativePosition();
    case IdxEntry::LOGN_SCALING: return isLognScaling();

    // Cross-attention inputs.
    case IdxEntry::CROSS_KV:
    case IdxEntry::CROSS_KV_LENGTH:
    case IdxEntry::ENCODER_INPUT_LENGTH: return isCrossAttention();

    case IdxEntry::HOST_CONTEXT_LENGTH: return mRemovePadding;
    case IdxEntry::QKV_BIAS_TENSOR: return mQKVBiasEnabled;

    // Speculative-decoding inputs.
    case IdxEntry::SPEC_DECODING_GENERATION_LENGTHS:
    case IdxEntry::SPEC_DECODING_PACKED_MASK:
    case IdxEntry::SPEC_DECODING_POSITION_OFFSETS:
    case IdxEntry::SPEC_DECODING_USE: return mIsSpecDecodingEnabled;

    // MLA projection weights.
    case IdxEntry::MLA_FUSED_Q_PROJ_TENSOR:
    case IdxEntry::MLA_Q_B_PROJ_TENSOR:
    case IdxEntry::MLA_KV_B_PROJ_TENSOR: return mIsMLAEnabled;

    case IdxEntry::SKIP_ATTN: return mSkipAttn;

    // ENUM_SIZE (and anything not listed) is never an input.
    default: return false;
    }
}
/// Fills mEntryIdx so that mEntryIdx[e] is the position of entry `e` among the plugin inputs. That
/// position is the number of *used* entries (per isEntryUsed()) that precede `e` in enum order.
/// Unused entries share the index of the next used one, which is why getIdx() checks isEntryUsed()
/// before returning.
void GPTAttentionPlugin::initEntryIdx()
{
    auto const numEntries = static_cast<size_t>(IdxEntry::ENUM_SIZE);
    mEntryIdx.resize(numEntries);
    size_t entryIdx = 0;
    // size_t index: the original `int` index was compared against a size_t bound (sign-compare).
    for (size_t i = 0; i < numEntries; ++i)
    {
        mEntryIdx[i] = entryIdx;
        entryIdx += isEntryUsed(static_cast<IdxEntry>(i)) ? 1 : 0;
    }
}
/// Position of `entry` among the plugin inputs, as computed by initEntryIdx().
/// Asking for an entry that is not an input in this configuration is a programming error and trips
/// TLLM_CHECK_WITH_INFO.
GPTAttentionPlugin::IndexType GPTAttentionPlugin::getIdx(IdxEntry const& entry) const
{
    bool const present = isEntryUsed(entry);
    TLLM_CHECK_WITH_INFO(
        present, common::fmtstr("getIdx() should not be used with entry %s.\n", toString(entry).c_str()));
    auto const slot = static_cast<size_t>(entry);
    return mEntryIdx[slot];
}
// IPluginV2DynamicExt Methods
/// Returns a copy of this plugin created by the shared cloneImpl helper, cast back to the concrete type.
GPTAttentionPlugin* GPTAttentionPlugin::clone() const noexcept
{
    auto* const copy = this->cloneImpl<GPTAttentionPlugin>();
    return dynamic_cast<GPTAttentionPlugin*>(copy);
}
// Index of the hidden dimension of the QKV/output tensor. With padding removed the layout is
// [num_tokens, hidden] (dim 1); otherwise it is [batch_size, seq_len, hidden] (dim 2).
static int getPackedTensorHiddenDimIndex(bool removePadding)
{
    if (removePadding)
    {
        return 1;
    }
    return 2;
}
// NOTE: generation input length might be larger than one in the spec decoding mode.
// Returns the number of input tokens per generation request for this enqueue.
// localNbSeq / localNbTokens are the request and token counts of the generation slice being processed.
int GPTAttentionPlugin::getGenerationInputSequenceLength(
    nvinfer1::PluginTensorDesc const* inputDesc, int32_t localNbSeq, int32_t localNbTokens) const
{
    // Padded input has no in-flight batching, so the length is simply the seq_len dimension of
    // QKV_TENSOR: [batch_size, seq_len, local_hidden_size].
    if (!mRemovePadding)
    {
        return inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[1];
    }

    // Speculative decoding may use a variable generation length, read from
    // SPEC_DECODING_POSITION_OFFSETS: [batch_size, max_generation_input_length].
    bool const specDecodingActive = mIsSpecDecodingEnabled && mUseSpecDecoding;
    if (specDecodingActive)
    {
        TLLM_CHECK_WITH_INFO(mCpSize <= 1, "Context Parallel does not support speculative decoding mode for now");
        return inputDesc[getIdx(IdxEntry::SPEC_DECODING_POSITION_OFFSETS)].dims.d[1];
    }

    if (mCpSize > 1)
    {
        // localNbTokens == (beamSize * localNbSeq + mCpSize - 1) / mCpSize, but when mCpSize - 1 > localNbSeq
        // several beamSize values satisfy it, so beamSize == 1 is assumed and verified here.
        TLLM_CHECK_WITH_INFO(localNbTokens == (localNbSeq + mCpSize - 1) / mCpSize,
            "Context Parallel does not support beamSize > 1 for non-speculative decoding mode, "
            "localNbTokens=%d, localNbSeq=%d",
            localNbTokens, localNbSeq);
        return 1;
    }

    // Packed [num_tokens, local_hidden_size] with num_tokens = batch_size * generation_input_length,
    // so every request must contribute the same number of tokens.
    TLLM_CHECK_WITH_INFO(localNbTokens % localNbSeq == 0,
        "seq_len should be same for all generation requests, localNbTokens=%d, localNbSeq=%d", localNbTokens,
        localNbSeq);
    return localNbTokens / localNbSeq;
}
// outputs
// output 0, output_tensor: [batch_size, seq_len, local_hidden_size] or [num_tokens, local_hidden_size]
// output 1, present_key_value_pool: only exists when the KV cache is used and is NOT paged;
//   [batch_size, 2, local_num_kv_heads, max_seq_len, head_size]
nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    TLLM_CHECK(outputIndex == 0 || (!mPagedKVCache && useKVCache() && outputIndex == 1));

    // The present KV cache has exactly the shape of the past KV cache input.
    if (outputIndex != 0)
    {
        return inputs[getIdx(IdxEntry::PAST_KEY_VALUE)];
    }

    // The attention output mirrors QKV_TENSOR, except its hidden dimension becomes num_heads * head_size.
    // With MLA enabled the per-head output size is v_head_dim.
    auto dims = inputs[getIdx(IdxEntry::QKV_TENSOR)];
    auto const headSize = mIsMLAEnabled ? mMLAParams.v_head_dim : mHeadSize;
    auto const* headSizeExpr = exprBuilder.constant(headSize);
    auto const* numHeadsExpr = exprBuilder.constant(mNumHeads);
    dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)]
        = exprBuilder.operation(DimensionOperation::kPROD, *headSizeExpr, *numHeadsExpr);
    return dims;
}
bool GPTAttentionPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
bool result = false;
int posCaseLine = -1;
if (pos == getIdx(IdxEntry::CONTEXT_LENGTHS) || pos == getIdx(IdxEntry::REQUEST_TYPES)
|| pos == getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW) || pos == getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)
|| (isEntryUsed(IdxEntry::SPEC_DECODING_PACKED_MASK) && pos == getIdx(IdxEntry::SPEC_DECODING_PACKED_MASK))
|| (isEntryUsed(IdxEntry::SPEC_DECODING_POSITION_OFFSETS)
&& pos == getIdx(IdxEntry::SPEC_DECODING_POSITION_OFFSETS))
|| (isEntryUsed(IdxEntry::SPEC_DECODING_GENERATION_LENGTHS)
&& pos == getIdx(IdxEntry::SPEC_DECODING_GENERATION_LENGTHS))
|| (isEntryUsed(IdxEntry::SPEC_DECODING_USE) && pos == getIdx(IdxEntry::SPEC_DECODING_USE)))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (isMRoPE() && (pos == getIdx(IdxEntry::MROPE_ROTARY_COS_SIN)))
{
return inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (isMRoPE() && (pos == getIdx(IdxEntry::MROPE_POSITION_DELTAS)))
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (pos == getIdx(IdxEntry::HOST_RUNTIME_PERF_KNOBS) || pos == getIdx(IdxEntry::HOST_CONTEXT_PROGRESS))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT64;
}
else if (useKVCache()
&& (pos == getIdx(IdxEntry::SEQUENCE_LENGTH) || pos == getIdx(IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS)
|| pos == getIdx(IdxEntry::CACHE_INDIR)))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (isRoPE() && (pos == getIdx(IdxEntry::ROTARY_INV_FREQ) || pos == getIdx(IdxEntry::ROTARY_COS_SIN)))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (mIsMLAEnabled && (pos == getIdx(IdxEntry::ROTARY_COS_SIN)))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (isLongRoPE()
&& (pos == getIdx(IdxEntry::LONG_ROPE_ROTARY_INV_FREQ) || pos == getIdx(IdxEntry::LONG_ROPE_ROTARY_COS_SIN)))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (useKVCache() && mKVCacheQuantMode.hasKvCacheQuant()
&& (pos == getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)
|| pos == getIdx(IdxEntry::KV_CACHE_QUANTIZATION_SCALE)))
{
// kv_scale for mType->int8/fp8 and int8/fp8->mType conversion
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (mFP8ContextFMHA && pos == getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (useFullCustomMask() && pos == getIdx(IdxEntry::ATTENTION_MASK))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kBOOL && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (useCustomMask() && pos == getIdx(IdxEntry::ATTENTION_PACKED_MASK))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (useKVCache() && mPagedKVCache
&& (pos == getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS) || pos == getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)))
{
// kv cache block offsets
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (useKVCache() && mPagedKVCache && (pos == getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)))
{
// kv cache pool pointers
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (useKVCache() && mPagedKVCache && (pos == getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING)))
{
// kv cache pool mapping
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (useKVCache() && mKVCacheQuantMode.hasInt8KvCache()
&& (!mPagedKVCache && (pos == getIdx(IdxEntry::PAST_KEY_VALUE) || pos == nbInputs + 1)))
{
// If use Int8 K/V cache we require I/O KV values to int8
posCaseLine = __LINE__;
result = (inOut[pos].type == nvinfer1::DataType::kINT8) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (useKVCache() && mKVCacheQuantMode.hasFp8KvCache()
&& (!mPagedKVCache && (pos == getIdx(IdxEntry::PAST_KEY_VALUE) || pos == nbInputs + 1)))
{
// If use FP8 K/V cache we require I/O KV values to FP8
posCaseLine = __LINE__;
result = (inOut[pos].type == nvinfer1::DataType::kFP8) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (mRemovePadding && (pos == getIdx(IdxEntry::HOST_CONTEXT_LENGTH)))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (mCrossAttention
&& (pos == getIdx(IdxEntry::CROSS_KV_LENGTH) || pos == getIdx(IdxEntry::ENCODER_INPUT_LENGTH)))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (isLognScaling() && pos == getIdx(IdxEntry::LOGN_SCALING))
{
return inOut[pos].type == nvinfer1::DataType::kFLOAT;
}
else if (pos == nbInputs && mFP8ContextFMHA)
{
// Output tensor now supports fp8 data type.
posCaseLine = __LINE__;
result = (inOut[pos].type == nvinfer1::DataType::kFP8) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (mSkipAttn && pos == getIdx(IdxEntry::SKIP_ATTN))
{
posCaseLine = __LINE__;
result = inOut[pos].type == nvinfer1::DataType::kBOOL && inOut[pos].format == TensorFormat::kLINEAR;
}
else
{
posCaseLine = __LINE__;
result = (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
}
TLLM_LOG_DEBUG(
"%s: pos: %d, result: %d, posCaseLine: %d", __PRETTY_FUNCTION__, pos, static_cast<int>(result), posCaseLine);
return result;
}
/// Build-time preparation for compute type T and KV-cache container KVCacheBuffer. It does two things:
/// - calls prepareEnqueueGeneration() with a placeholder EnqueueGenerationParams. Only the shape-related
///   fields are filled (beam width, attention window, request count); every pointer is null.
/// - reserves the semaphore array used by multi-block MMHA for the largest batch*beam the engine allows.
template <typename T, typename KVCacheBuffer>
void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
    TLLM_CHECK(mHeadSize > 0);

    // CACHE_INDIR dim 1 is the beam width. Without a self-attention KV cache the beam width is 1.
    int beamWidth = 1;
    if (!isCrossAttention() && useKVCache())
    {
        // desc_val == -1 means beam_width is not static, we should look at min/max/opt.
        //
        // In prepareEnqueueGeneration, we'll prepare for all cases where beam_width doesn't exceed max.
        // TODO(minwei): pass min AND max to prepareEnqueueGeneration instead of max only.
        int const desc_val = in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1];
        int const max_val = in[getIdx(IdxEntry::CACHE_INDIR)].max.d[1];
        beamWidth = desc_val == -1 ? max_val : desc_val;
    }
    TLLM_CHECK(beamWidth != -1);

    // Commonly, cyclic_attention_window_size and max_attention_window_size are the same unless each layer
    // has different attention window sizes. Both are set to the KV-cache capacity: the encoder length for
    // cross attention, otherwise CACHE_INDIR dim 2.
    // NOTE(review): these read desc.dims rather than max; a dynamic dimension would read as -1 here — confirm
    // these dims are static at build time.
    int const max_encoder_context_len = isCrossAttention() ? in[getIdx(IdxEntry::CROSS_KV_LENGTH)].desc.dims.d[0] : 0;
    int const max_attention_window_size = isCrossAttention()
        ? max_encoder_context_len
        : (useKVCache() ? in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[2] : 0);
    int const cyclic_attention_window_size = max_attention_window_size;
    // Placeholder values used only for this build-time preparation.
    int const num_requests = 256;
    int const sink_token_length = 0;
    EnqueueGenerationParams<T> enqueueParams{
        /*attention_input=*/nullptr,
        /*qkv_bias=*/nullptr,
        /*attention_mask*/ nullptr,
        /*rotary_inv_freq*/ nullptr,
        /*input_seq_length=*/0,
        /*sequence_lengths=*/nullptr,
        /*past_kv_length=*/0,
        beamWidth,
        /*context_lengths=*/nullptr,
        /*kv_scale_orig_quant=*/nullptr,
        /*kv_scale_quant_orig=*/nullptr,
        /*attention_out_orig_quant=*/nullptr,
        /*alibi_slopes=*/nullptr,
        /*context_buf_=*/nullptr,
        /*key_value_cache=*/nullptr,
        /*block_offsets=*/nullptr,
        /*host_primary_pool_pointer=*/nullptr,
        /*host_secondary_pool_pointer=*/nullptr,
        /*attention_mask_stride*/ 0,
        max_attention_window_size,
        cyclic_attention_window_size,
        cyclic_attention_window_size,
        /*can_use_one_more_block=*/false,
        sink_token_length,
        num_requests,
        /*max_blocks_per_sequence=*/0,
        /*cache_indir=*/nullptr,
        /*workspace=*/nullptr,
        /*max_context_kv_len_list=*/nullptr,
        /*mrope_position_deltas*/ nullptr,
    };
    prepareEnqueueGeneration<T, KVCacheBuffer>(enqueueParams);

    // Always reserve SemaphoreArray (for multi-block mode) as MMHA may enable multi-block mode when shared memory is
    // not enough.
    auto const& ctxLenTensor = in[getIdx(IdxEntry::CONTEXT_LENGTHS)];
    TLLM_CHECK_DEBUG(ctxLenTensor.max.nbDims == 1);
    // Use the descriptor that was just validated instead of indexing `in` a second time.
    int32_t const max_batch_beam = ctxLenTensor.max.d[0];
    reserveSemaphoreArray(mNumHeads * max_batch_beam);
}
/// Chooses the KV-cache container type for configurePluginImpl: KVBlockArray for a paged cache,
/// KVLinearBuffer otherwise.
template <typename T>
void GPTAttentionPlugin::configurePluginDispatchKVCacheType(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
    if (!mPagedKVCache)
    {
        configurePluginImpl<T, KVLinearBuffer>(in, nbInputs, out, nbOutputs);
        return;
    }
    configurePluginImpl<T, KVBlockArray>(in, nbInputs, out, nbOutputs);
}
/// TensorRT configuration hook. Dispatches on the plugin compute type mType: half, float, and bf16
/// (only when built with ENABLE_BF16). Any other type is left unconfigured here.
void GPTAttentionPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
    switch (mType)
    {
    case nvinfer1::DataType::kHALF: configurePluginDispatchKVCacheType<half>(in, nbInputs, out, nbOutputs); break;
    case nvinfer1::DataType::kFLOAT: configurePluginDispatchKVCacheType<float>(in, nbInputs, out, nbOutputs); break;
#ifdef ENABLE_BF16
    case nvinfer1::DataType::kBF16:
        configurePluginDispatchKVCacheType<__nv_bfloat16>(in, nbInputs, out, nbOutputs);
        break;
#endif
    default: break;
    }
}
/// Scratch memory needed per enqueue. The total is:
/// - max(context-phase, generation-phase) requirement (taking the max implies both phases share one
///   region), plus
/// - a staging buffer for the attention input when MLA is enabled or the QKV GEMM is unfused.
size_t GPTAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
    nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
    int const max_context_length = mMaxContextLength;
    int const cross_kv_length = isCrossAttention() ? inputs[getIdx(IdxEntry::CROSS_KV_LENGTH)].dims.d[0] : 0;
    int const max_num_seq = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0];
    auto const type = inputs[getIdx(IdxEntry::QKV_TENSOR)].type;
    int const max_kv_cache_length
        = isCrossAttention() ? cross_kv_length : (useKVCache() ? inputs[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0);
    // Packed input: QKV dim 0 is the token count. Padded input: every sequence spans max_context_length.
    int const max_num_tokens
        = mRemovePadding ? inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0] : max_num_seq * max_context_length;
    size_t const context_workspace_size
        = getWorkspaceSizeForContext(type, max_num_seq, max_context_length, cross_kv_length, max_num_tokens);

    // Generation tokens: one per sequence, or up to num_spec_dec_tokens per sequence with speculative
    // decoding, never more than the total token count.
    int32_t const num_spec_dec_tokens
        = mIsSpecDecodingEnabled ? inputs[getIdx(IdxEntry::SPEC_DECODING_POSITION_OFFSETS)].dims.d[1] : 1;
    int32_t const max_batch_beam = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0];
    int32_t const max_num_gen_tokens = std::min(max_num_tokens, num_spec_dec_tokens * max_batch_beam);
    // Size the generation workspace by the generation-token bound, not by the total token count.
    size_t const generation_workspace_size
        = getWorkspaceSizeForGeneration(type, max_num_seq, max_kv_cache_length, max_num_gen_tokens);

    size_t attention_input_workspace_size = 0;
    if (mIsMLAEnabled)
    {
        // Per-token, per-head staging large enough for either the decompressed Q/K/V layout or the
        // compressed (kv_lora_rank + rope) layout, whichever is bigger.
        int32_t const size_per_head
            = 2 * (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim) + mMLAParams.v_head_dim;
        size_t const size = tensorrt_llm::runtime::BufferDataType(type).getSize();
        size_t const attention_input_size = size * max_num_tokens * mNumHeads
            * std::max(size_per_head, mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim);
        size_t workspaces[1];
        workspaces[0] = attention_input_size;
        attention_input_workspace_size = tensorrt_llm::common::calculateTotalWorkspaceSize(workspaces, 1);
    }
    else if (mUnfuseQkvGemm)
    {
        // Q, K and V arrive as separate tensors; stage them into one contiguous Q|K|V buffer per token.
        int const local_hidden_units_q
            = inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)];
        int const local_hidden_units_kv
            = inputs[getIdx(IdxEntry::K_TENSOR)].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)];
        size_t const size = tensorrt_llm::runtime::BufferDataType(type).getSize();
        size_t const attention_input_size = size * max_num_tokens * (local_hidden_units_q + 2 * local_hidden_units_kv);
        size_t workspaces[1];
        workspaces[0] = attention_input_size;
        attention_input_workspace_size = tensorrt_llm::common::calculateTotalWorkspaceSize(workspaces, 1);
    }
    return std::max(context_workspace_size, generation_workspace_size) + attention_input_workspace_size;
}
// Stride, in elements, of dimension `n` for a dense row-major tensor: the product of all dimensions
// after `n` (1 for the last dimension). Requires 0 <= n < dims.nbDims.
static size_t getStride(nvinfer1::Dims const& dims, int n)
{
    TLLM_CHECK(n >= 0 && n < dims.nbDims);
    // std::accumulate's accumulator has the type of its initial value. With the int literal `1` each
    // partial product would be narrowed back to int, truncating strides above INT_MAX, so seed with size_t.
    return std::accumulate(dims.d + n + 1, dims.d + dims.nbDims, size_t{1}, std::multiplies<size_t>{});
}
/// Runs attention for one batch. REQUEST_TYPES must list every context request before every generation
/// request; this is verified below. The leading context requests are enqueued as one enqueueSome() call
/// and the trailing generation requests as a second one. Mixing both kinds in one batch requires
/// padding removal and a paged KV cache.
template <typename T, typename AttentionOutT, typename KVCacheBuffer>
int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream)
{
    TLLM_LOG_TRACE("Attention plugin start at layer %d", mLayerIdx);
    using runtime::RequestType;
    int32_t const nbSeq = inputDesc[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0];
    RequestType const* reqTypes = static_cast<RequestType const*>(inputs[getIdx(IdxEntry::REQUEST_TYPES)]);
    auto const& qkvDims = inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims;
    // HOST_CONTEXT_LENGTH is an input only when padding is removed; look it up once, outside the loop.
    int32_t const* hostContextLengths
        = mRemovePadding ? static_cast<int32_t const*>(inputs[getIdx(IdxEntry::HOST_CONTEXT_LENGTH)]) : nullptr;

    int32_t nbContextRequests = 0;
    int32_t contextTokenIdxEnd = 0;
    // Like contextTokenIdxEnd, but counting this rank's share under context parallelism:
    // each request contributes ceil(length / mCpSize) tokens.
    int32_t contextTokenIdxEndForCp = 0;
    // count context requests (the leading run of kCONTEXT entries)
    for (int32_t seqIdx = 0; seqIdx < nbSeq && reqTypes[seqIdx] == RequestType::kCONTEXT; ++seqIdx)
    {
        int32_t const seqLen = mRemovePadding ? hostContextLengths[seqIdx] : qkvDims.d[1];
        ++nbContextRequests;
        contextTokenIdxEnd += seqLen;
        contextTokenIdxEndForCp += (seqLen + mCpSize - 1) / mCpSize;
    }
    for (int32_t seqIdx = nbContextRequests; seqIdx < nbSeq; seqIdx++)
    {
        TLLM_CHECK(reqTypes[seqIdx] == RequestType::kGENERATION);
    }
    // mixed requests require mRemovePadding and mPagedKVCache
    if (nbContextRequests != 0 && nbContextRequests != nbSeq)
    {
        TLLM_CHECK(mRemovePadding && mPagedKVCache);
    }
    if (nbContextRequests > 0)
    {
        auto seqIdxBeg = 0;
        auto tokenIdxBeg = 0;
        auto localNbTokens = contextTokenIdxEnd;
        enqueueSome<T, AttentionOutT, KVCacheBuffer>(seqIdxBeg, nbContextRequests, tokenIdxBeg, localNbTokens,
            inputDesc, outputDesc, inputs, outputs, workspace, stream);
    }
    if (auto nbGenerationSeq = nbSeq - nbContextRequests; nbGenerationSeq > 0)
    {
        auto seqIdxBeg = nbContextRequests;
        auto tokenIdxBeg = mCpSize > 1 ? contextTokenIdxEndForCp : contextTokenIdxEnd;
        // if mRemovePadding is true, we may have IFB, and need to remove context tokens.
        // if mRemovePadding is false, it is only generation requests, so just multiply batch_beam and seq_len (May not
        // 1 for Parallel Decoding)
        auto localNbTokens = mRemovePadding ? qkvDims.d[0] - tokenIdxBeg : qkvDims.d[0] * qkvDims.d[1];
        enqueueSome<T, AttentionOutT, KVCacheBuffer>(seqIdxBeg, nbGenerationSeq, tokenIdxBeg, localNbTokens, inputDesc,
            outputDesc, inputs, outputs, workspace, stream);
    }
    sync_check_cuda_error();
    TLLM_LOG_TRACE("Attention plugin stop at layer %d", mLayerIdx);
    return 0;
}
template <typename T, typename AttentionOutT>
mlaParams<T> GPTAttentionPlugin::enqueueMLAPreprocess(int32_t localNbSeq, int32_t localNbTokens,
    nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
    void const* const* inputs, void* const* outputs, void*& workspace, bool is_context, cudaStream_t stream)
{
    // Builds the mlaParams<T> descriptor used by the MLA attention path from the plugin's input/output tensors.
    //
    // localNbSeq    -> mla_params.batch_size
    // localNbTokens -> mla_params.acc_q_len
    // Returns the populated descriptor; any member not assigned below is value-initialized instead of being
    // left indeterminate.
    //
    // NOTE(review): inputDesc, outputDesc, workspace, is_context and stream are not read here, and `workspace`
    // is not modified even though it is passed by reference.
    mlaParams<T> mla_params{};
    mla_params.fused_a_input = static_cast<T const*>(inputs[getIdx(IdxEntry::QKV_TENSOR)]);
    mla_params.fused_q_proj = static_cast<T const*>(inputs[getIdx(IdxEntry::MLA_FUSED_Q_PROJ_TENSOR)]);
    mla_params.q_b_proj = static_cast<T const*>(inputs[getIdx(IdxEntry::MLA_Q_B_PROJ_TENSOR)]);
    mla_params.kv_b_proj = static_cast<T const*>(inputs[getIdx(IdxEntry::MLA_KV_B_PROJ_TENSOR)]);
    mla_params.cos_sin_cache = static_cast<float2 const*>(inputs[getIdx(IdxEntry::ROTARY_COS_SIN)]);
    // NOTE(review): outputs[0] holds AttentionOutT elements but is exposed as T*; confirm this reinterpretation
    // is intended when AttentionOutT differs from T.
    mla_params.context_buf = reinterpret_cast<T*>(static_cast<AttentionOutT*>(outputs[0]));
    mla_params.batch_size = localNbSeq;
    mla_params.acc_q_len = localNbTokens;
    mla_params.head_num = mNumHeads;
    mla_params.meta = mMLAParams;
    return mla_params;
}
template <typename T, typename AttentionOutT, typename KVCacheBuffer>
int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32_t tokenIdxBeg, int32_t localNbTokens,
nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream)
{
// relative_attention_bias [head_num, max_seq_len, max_seq_len] (optional in relative position)
// or [head_num, num_buckets] (optional in implicit relative attention)
// cross_kv [batch_size, seq_len, 2 * local_hidden_size] or [num_tokens, 2 * local_hidden_size]
// when enable remove_input_padding (optional in cross attention mode)
// cross_kv_length [int] max encoder input context length (optional in cross attention mode)
// encoder_input_lengths [batch_size] raw sequence lengths (optional in cross attention mode)
using runtime::RequestType;
auto const* const reqTypeInBatchPtr
= static_cast<RequestType const*>(inputs[getIdx(IdxEntry::REQUEST_TYPES)]) + seqIdxBeg;
bool const is_context = (reqTypeInBatchPtr[0] == RequestType::kCONTEXT);
T const* attention_input = static_cast<T const*>(inputs[getIdx(IdxEntry::QKV_TENSOR)])
+ inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)]
* size_t(tokenIdxBeg);
bool changeSpecDecodingMode = false;
if (mIsSpecDecodingEnabled)
{
bool useSpecDecoding
= static_cast<bool>(reinterpret_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_USE)])[0]);
changeSpecDecodingMode = mUseSpecDecoding != useSpecDecoding;
mUseSpecDecoding = useSpecDecoding;
// change mMultiBlockMode to default
mMultiBlockMode = mUseSpecDecoding ? false : true;
}
[[maybe_unused]] mlaParams<T> mla_params;
if (mIsMLAEnabled)
{
// In MLA, attention_input will be the ptr of workspace, and workspace value will be updated in
// enqueueMLAPreprocess
mla_params = enqueueMLAPreprocess<T, AttentionOutT>(
localNbSeq, localNbTokens, inputDesc, outputDesc, inputs, outputs, workspace, is_context, stream);
size_t const size_per_head = is_context
? (2 * (mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim) + mMLAParams.v_head_dim)
: mMLAParams.kv_lora_rank + mMLAParams.qk_rope_head_dim;
size_t const total_size = sizeof(T) * mla_params.acc_q_len * mNumHeads * size_per_head;
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
size_t offset = 0;
T* attention_input_qkv = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, total_size));
workspace = reinterpret_cast<void*>(workspace_byte_ptr + offset);
mla_params.attention_input_buf = attention_input_qkv;
mla_params.workspace = workspace;
attention_input = attention_input_qkv;
}
T const* qkv_bias = nullptr;
if (mQKVBiasEnabled)
{
qkv_bias = reinterpret_cast<T const*>(inputs[getIdx(IdxEntry::QKV_BIAS_TENSOR)]);
}
// Note we still need context length during generation for MMHA optimization.
int32_t const max_context_q_len = [&]()
{
if (!mRemovePadding)
{
return static_cast<int>(inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[1]);
}
auto const host_context_lengths
= static_cast<int32_t const*>(inputs[getIdx(IdxEntry::HOST_CONTEXT_LENGTH)]) + seqIdxBeg;
return *std::max_element(host_context_lengths, host_context_lengths + localNbSeq);
}();
// Rotary inv_freq, cos_sin cache to avoid re-computing.
float const* rotary_inv_freq = nullptr;
float2 const* rotary_cos_sin = nullptr;
bool const useLongRoPECache = isLongRoPE() && max_context_q_len > mRotaryEmbeddingOriginalMaxPositions;
if (isRoPE())
{
auto inputName = useLongRoPECache ? IdxEntry::LONG_ROPE_ROTARY_INV_FREQ : IdxEntry::ROTARY_INV_FREQ;
rotary_inv_freq = reinterpret_cast<float const*>(inputs[getIdx(inputName)]);
}
if (isRoPE() || mIsMLAEnabled)
{
auto inputName = useLongRoPECache ? IdxEntry::LONG_ROPE_ROTARY_COS_SIN : IdxEntry::ROTARY_COS_SIN;
rotary_cos_sin = reinterpret_cast<float2 const*>(inputs[getIdx(inputName)]);
}
auto const mrope_rotary_cos_sin
= isMRoPE() ? reinterpret_cast<float2 const*>(inputs[getIdx(IdxEntry::MROPE_ROTARY_COS_SIN)]) : nullptr;
auto const mrope_position_deltas
= isMRoPE() ? reinterpret_cast<int32_t const*>(inputs[getIdx(IdxEntry::MROPE_POSITION_DELTAS)]) : nullptr;
if (mUnfuseQkvGemm)
{
int const max_seqlen = inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[mRemovePadding ? 0 : 1];
int const batch_size = mRemovePadding ? 1 : inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0];
T const* attention_input_q = static_cast<T const*>(inputs[getIdx(IdxEntry::QKV_TENSOR)]);
T const* attention_input_k = static_cast<T const*>(inputs[getIdx(IdxEntry::K_TENSOR)]);
T const* attention_input_v = static_cast<T const*>(inputs[getIdx(IdxEntry::V_TENSOR)]);
size_t const hidden_units_q
= inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)];
size_t const hidden_units_kv
= inputDesc[getIdx(IdxEntry::K_TENSOR)].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)];
size_t const hidden_units = hidden_units_q + 2 * hidden_units_kv;
size_t const size_qkv = sizeof(T) * hidden_units;
size_t const size_q = sizeof(T) * hidden_units_q;
size_t const size_kv = sizeof(T) * hidden_units_kv;
size_t const total_size = size_qkv * batch_size * max_seqlen;
int8_t* workspace_byte_ptr = reinterpret_cast<int8_t*>(workspace);
size_t offset = 0;
T* attention_input_qkv = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, total_size));
workspace = reinterpret_cast<void*>(workspace_byte_ptr + offset);
cudaMemcpy2DAsync(attention_input_qkv, size_qkv, attention_input_q, size_q, size_q, batch_size * max_seqlen,
cudaMemcpyDeviceToDevice, stream);
cudaMemcpy2DAsync(attention_input_qkv + hidden_units_q, size_qkv, attention_input_k, size_kv, size_kv,
batch_size * max_seqlen, cudaMemcpyDeviceToDevice, stream);
cudaMemcpy2DAsync(attention_input_qkv + hidden_units_q + hidden_units_kv, size_qkv, attention_input_v, size_kv,
size_kv, batch_size * max_seqlen, cudaMemcpyDeviceToDevice, stream);
attention_input = attention_input_qkv + hidden_units * tokenIdxBeg;
}
int const* context_q_lengths = reinterpret_cast<int const*>(inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)]) + seqIdxBeg;
int const* sequence_kv_length = useKVCache()
? static_cast<int const*>(inputs[getIdx(IdxEntry::SEQUENCE_LENGTH)]) + seqIdxBeg
: context_q_lengths;
int max_encoder_context_len = isCrossAttention() ? inputDesc[getIdx(IdxEntry::CROSS_KV_LENGTH)].dims.d[0] : 0;
// for enc-dec model, since decoder_input_ids could be longer than 1,
// such model has an encoder context (for cross attn) and an decoder context (for self attn)
// clarify 3 lens:
// -- max_context_q_len: len of decoder input. No "max" concept, it's what it is given.
// Also called (decoder_)input_seq_length, normally 1 for encoder-decoder start token
// -- max_seq_len: max allowed len of decoder output, i.e. final results
// -- max_encoder_context_len: len of encoder input (in cross attn). Also called encoder_input_seq_length
int const beamWidth
= isCrossAttention() ? 1 : (useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[1] : 1);
// Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same
// unless each layer has different attention window sizes.
// the kv_cache capacity.
int const max_attention_window_size = isCrossAttention()
? max_encoder_context_len
: (useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0);
// The cyclic_attention_window_size will determine the cyclic kv cache position of new tokens.
// Note that this cyclic_attention_window_size might be smaller than the actual kv cache capactity.
int const* cyclic_attention_window_sizes
= reinterpret_cast<int const*>(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)]);
int const cyclic_attention_window_size
= isCrossAttention() ? max_encoder_context_len : cyclic_attention_window_sizes[mLayerIdx];
int const sink_token_length = reinterpret_cast<int const*>(inputs[getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)])[0];
int const num_attn_layer = inputDesc[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)].dims.d[0];
int const max_cyclic_attention_window_size = isCrossAttention()
? max_encoder_context_len
: *std::max_element(cyclic_attention_window_sizes, cyclic_attention_window_sizes + num_attn_layer);
bool const can_use_one_more_block = beamWidth > 1;
float const* kv_scale_orig_quant = nullptr;
float const* kv_scale_quant_orig = nullptr;
if (useKVCache() && mKVCacheQuantMode.hasKvCacheQuant())
{
assert(inputDesc[getIdx(IdxEntry::KV_CACHE_QUANTIZATION_SCALE)].type == nvinfer1::DataType::kFLOAT);
assert(inputDesc[getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)].type == nvinfer1::DataType::kFLOAT);
kv_scale_orig_quant = reinterpret_cast<float const*>(inputs[getIdx(IdxEntry::KV_CACHE_QUANTIZATION_SCALE)]);
kv_scale_quant_orig = reinterpret_cast<float const*>(inputs[getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)]);
}
float const* attention_output_orig_quant = nullptr;
if (mFP8ContextFMHA)
{
assert(inputDesc[getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE)].type == nvinfer1::DataType::kFLOAT);
attention_output_orig_quant
= reinterpret_cast<float const*>(inputs[getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE)]);
}
uint32_t const* attention_packed_mask = nullptr;
if (useCustomMask())
{
assert(inputDesc[getIdx(IdxEntry::ATTENTION_PACKED_MASK)].type == nvinfer1::DataType::kINT32);
attention_packed_mask = reinterpret_cast<uint32_t const*>(inputs[getIdx(IdxEntry::ATTENTION_PACKED_MASK)]);
}
bool const* attention_mask = nullptr;
int attention_mask_stride = 0;
if (useFullCustomMask())
{
attention_mask_stride = static_cast<int>(inputDesc[getIdx(IdxEntry::ATTENTION_MASK)].dims.d[1]);
attention_mask = reinterpret_cast<bool const*>(inputs[getIdx(IdxEntry::ATTENTION_MASK)])
+ attention_mask_stride * static_cast<size_t>(tokenIdxBeg);
}
int max_blocks_per_sequence = 0;
kernels::KVBlockArray::DataType* block_offsets = nullptr;
kernels::KVBlockArray::DataType* host_block_offsets = nullptr;
void* host_primary_pool_pointer = nullptr;
void* host_secondary_pool_pointer = nullptr;
if (useKVCache() && mPagedKVCache)
{
auto const& kvCacheBlockOffsets = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)];
auto const& kvCacheBlockOffsetsShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims;
max_blocks_per_sequence = kvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1];
std::int32_t const* host_pool_mapping
= static_cast<std::int32_t const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING)]);
const int32_t layerToPool = host_pool_mapping[mLayerIdx];
auto const seqStride = getStride(kvCacheBlockOffsetsShape, 1);
auto const poolStride = getStride(kvCacheBlockOffsetsShape, 0);
auto const seqOffset = seqIdxBeg * seqStride;
auto const poolOffset = layerToPool * poolStride;
block_offsets
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)])
+ poolOffset + seqOffset;
host_block_offsets
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)])
+ poolOffset + seqOffset;
auto const* const typed_host_pool_pointers
= static_cast<char* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]);
auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
auto const blockSize = mTokensPerBlock * mNumKVHeads * mHeadSize;
auto const bytesPerBlock = blockSize * cacheElemSize;
auto const layerOffset = mLayerIdxInCachePool * 2 * bytesPerBlock;
host_primary_pool_pointer = reinterpret_cast<void*>(typed_host_pool_pointers[layerToPool * 2] + layerOffset);
host_secondary_pool_pointer
= reinterpret_cast<void*>(typed_host_pool_pointers[layerToPool * 2 + 1] + layerOffset);
}
AttentionOutT* context_buf_ = static_cast<AttentionOutT*>(outputs[0])
+ outputDesc[0].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)] * tokenIdxBeg;
void* key_value_cache = nullptr;
if (useKVCache() && !mPagedKVCache)
{
auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
key_value_cache
= static_cast<std::byte*>(outputs[1]) + cacheElemSize * getStride(outputDesc[1].dims, 0) * seqIdxBeg;
void const* past_key_value_cache = inputs[getIdx(IdxEntry::PAST_KEY_VALUE)];
if (past_key_value_cache != outputs[1])
{
auto shape = outputDesc[1].dims;
auto const size
= cacheElemSize * std::accumulate(shape.d, shape.d + shape.nbDims, 1, std::multiplies<size_t>{});
cudaMemcpyAsync(outputs[1], past_key_value_cache, size, cudaMemcpyDeviceToDevice, stream);
}
}
T const* alibi_slopes = isALiBi() ? static_cast<T const*>(inputs[getIdx(IdxEntry::ALIBI_SLOPES)]) : nullptr;
int const* spec_decoding_packed_mask = nullptr;
int const* spec_decoding_position_offsets = nullptr;
int const* spec_decoding_generation_lengths = nullptr;
int num_decoding_draft_tokens = 0;
if (mIsSpecDecodingEnabled && mUseSpecDecoding)
{
// Second dimension of spec_decoding_position_offsets is num_decoding_draft_tokens + 1.
// [batch_size, num_decoding_draft_tokens + 1]
num_decoding_draft_tokens = inputDesc[getIdx(IdxEntry::SPEC_DECODING_POSITION_OFFSETS)].dims.d[1] - 1;
if (num_decoding_draft_tokens > 0)
{
// spec_decoding_* tensors are not filled for context requests. Hence, always strting from 0th index
int32_t constexpr genSeqIdx = 0;
spec_decoding_packed_mask = static_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_PACKED_MASK)])
+ genSeqIdx * getStride(inputDesc[getIdx(IdxEntry::SPEC_DECODING_PACKED_MASK)].dims, 0);
// Packed as [num_tokens, packed_mask_size]
// Use seqIdxBeg * (num_decoding_draft_tokens + 1) here as only generation tokens have the packed_mask
// buffer.
// TODO: support variable sequence length based on generationTokenIdxBeg.
spec_decoding_packed_mask = static_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_PACKED_MASK)])
+ genSeqIdx * (num_decoding_draft_tokens + 1)
* getStride(inputDesc[getIdx(IdxEntry::SPEC_DECODING_PACKED_MASK)].dims, 0);
spec_decoding_position_offsets
= static_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_POSITION_OFFSETS)])
+ genSeqIdx * getStride(inputDesc[getIdx(IdxEntry::SPEC_DECODING_POSITION_OFFSETS)].dims, 0);
spec_decoding_generation_lengths
= static_cast<int const*>(inputs[getIdx(IdxEntry::SPEC_DECODING_GENERATION_LENGTHS)]) + genSeqIdx;
}
}
int32_t const* max_context_kv_len_list = useKVCache()
? static_cast<int const*>(inputs[getIdx(IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS)]) + seqIdxBeg
: nullptr;
int32_t const max_context_kv_len = useKVCache()
? *std::max_element(max_context_kv_len_list, max_context_kv_len_list + localNbSeq)
: max_context_q_len;