
Commit 9a2b44d

[None][chore] No-op changes to support context parallelism in disaggregated serving later (#7063)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 90bfc8c commit 9a2b44d

7 files changed: +157 −120 lines changed


cpp/include/tensorrt_llm/executor/dataTransceiverState.h

Lines changed: 20 additions & 12 deletions
@@ -52,29 +52,30 @@ class CacheState final
        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
        : mModelConfig(std::move(modelConfig))
        , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
-             worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), worldConfig.getTensorParallelism()}
+             worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
+             worldConfig.getTensorParallelism()}
        , mDataType{dataType}
        , mAttentionConfig(attentionType, kvFactor)
    {
    }

    CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
-        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0)
+        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
+        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
+        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
        : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
        , mDataType{dataType}
        , mAttentionConfig(attentionType, kvFactor)
    {
    }

    CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
-        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0)
+        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
+        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
+        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
        : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
        , mDataType{dataType}
        , mAttentionConfig(attentionType, kvFactor)
    {
@@ -83,7 +84,7 @@ class CacheState final
    [[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept
    {
        return mModelConfig == other.mModelConfig && mParallelConfig == other.mParallelConfig
-            && mDataType == other.mDataType;
+            && mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType;
    }

    struct ModelConfig
@@ -103,15 +104,16 @@ class CacheState final
    {
        SizeType32 mTensorParallelism;
        SizeType32 mPipelineParallelism;
+        SizeType32 mContextParallelism;
        bool mEnableAttentionDP;
        SizeType32 mDPrank;
        SizeType32 mDPsize;

        [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
        {
            return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
-                && mEnableAttentionDP == other.mEnableAttentionDP && mDPrank == other.mDPrank
-                && mDPsize == other.mDPsize;
+                && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
+                && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
        }
    };

@@ -125,6 +127,11 @@ class CacheState final
        {
        }

+        [[nodiscard]] bool operator==(AttentionConfig const& other) const noexcept
+        {
+            return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor;
+        }
+
        // attentionType ;
        AttentionType mAttentionType;
        int mKvFactor;
@@ -162,6 +169,7 @@ class CacheState final
        sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n";
        sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n";
        sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n";
+        sstring << "cp:" << mParallelConfig.mContextParallelism << "\n";
        sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n";
        sstring << "datatype:" << static_cast<int32_t>(mDataType) << "\n";
        sstring << "attentionType:" << static_cast<int32_t>(mAttentionConfig.mAttentionType) << "\n";
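
For orientation, a minimal construction sketch against the updated signature. All model dimensions and parallelism values below are made up, and the header path/namespace are assumed from this repo's layout; only the parameter order, with contextParallelism now sitting between pipelineParallelism and dataType, is taken from this diff.

// Sketch only: header path, namespace, and all values assumed for illustration.
#include "tensorrt_llm/executor/dataTransceiverState.h"

#include <vector>

namespace kvc = tensorrt_llm::executor::kv_cache;
using tensorrt_llm::executor::SizeType32;

kvc::CacheState makeExampleCacheState()
{
    // Hypothetical model: 32 attention layers with 8 KV heads each.
    std::vector<SizeType32> nbKvHeadPerLayer(32, 8);
    return kvc::CacheState{nbKvHeadPerLayer, /*sizePerHead=*/128, /*tokensPerBlock=*/32,
        /*tensorParallelism=*/4, /*pipelineParallelism=*/1, /*contextParallelism=*/1,
        nvinfer1::DataType::kHALF, kvc::CacheState::AttentionType::kDEFAULT, /*kvFactor=*/2,
        /*enableAttentionDP=*/false, /*DPrank=*/0, /*DPsize=*/0};
}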

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 8 additions & 0 deletions
@@ -822,6 +822,14 @@ void CacheFormatter::unformat(TransferSession& session)
        TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support non-MLA");
        return false;
    }
+    if (selfConfig.getParallelConfig().mContextParallelism != 1
+        || destConfig.getParallelConfig().mContextParallelism != 1)
+    {
+        TLLM_LOG_WARNING(
+            "CacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
+            selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
+        return false;
+    }

    std::unordered_set<int> setVecDest{
        destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
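
Stripped of the formatter plumbing, the new guard boils down to a simple predicate over the two parallel configurations. A standalone sketch follows; the struct and function names are stand-ins, not the real CacheState types.

// Simplified stand-in mirroring the relevant field of CacheState::ParallelConfig.
struct ParallelConfigLite
{
    int mContextParallelism = 1;
};

// Disaggregated KV-cache transfer is only supported when both sides run with CP == 1.
bool contextParallelismSupported(ParallelConfigLite const& self, ParallelConfigLite const& dest)
{
    return self.mContextParallelism == 1 && dest.mContextParallelism == 1;
}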

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 8 additions & 6 deletions
@@ -558,18 +558,20 @@ void MLACacheFormatter::unformat(TransferSession& session)
        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA");
        return false;
    }
-
-    if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor)
-    {
-        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor");
-        return false;
-    }
    if (selfConfig.getParallelConfig().mEnableAttentionDP
        && (selfConfig.getParallelConfig().mTensorParallelism % selfConfig.getParallelConfig().mDPsize != 0))
    {
        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
        return false;
    }
+    if (selfConfig.getParallelConfig().mContextParallelism != 1
+        || destConfig.getParallelConfig().mContextParallelism != 1)
+    {
+        TLLM_LOG_WARNING(
+            "MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
+            selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
+        return false;
+    }
    if (destConfig.getParallelConfig().mEnableAttentionDP
        && (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
    {
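
The explicit kvFactor comparison removed above matches the check that the header change exposes through the new AttentionConfig::operator==; one plausible reading of this commit is that a kv-factor mismatch is now caught by comparing the attention configs (or the full CacheState objects) directly rather than by a dedicated formatter check. A simplified sketch of that equality, using stand-in types rather than the real header:

// Stand-in for CacheState::AttentionConfig; the enumerators are assumed for illustration.
enum class AttentionType { kDEFAULT, kMLA };

struct AttentionConfigLite
{
    AttentionType mAttentionType = AttentionType::kDEFAULT;
    int mKvFactor = 2;

    bool operator==(AttentionConfigLite const& other) const noexcept
    {
        return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor;
    }
};

// With this operator in place, configs with different kv factors already compare unequal,
// which would make a separate kvFactor check in the formatter redundant (assumption drawn from the diff).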

cpp/tensorrt_llm/executor/serialization.cpp

Lines changed: 5 additions & 2 deletions
@@ -531,14 +531,15 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is)
    auto tokensPerBlock = su::deserialize<decltype(CacheState::ModelConfig::mTokensPerBlock)>(is);
    auto tensorParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mTensorParallelism)>(is);
    auto pipelineParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mPipelineParallelism)>(is);
+    auto contextParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mContextParallelism)>(is);
    auto enableAttentionDP = su::deserialize<decltype(CacheState::ParallelConfig::mEnableAttentionDP)>(is);
    auto DPrank = su::deserialize<decltype(CacheState::ParallelConfig::mDPrank)>(is);
    auto DPsize = su::deserialize<decltype(CacheState::ParallelConfig::mDPsize)>(is);
    auto dataType = su::deserialize<decltype(CacheState::mDataType)>(is);
    auto attentionType = su::deserialize<decltype(CacheState::AttentionConfig::mAttentionType)>(is);
    auto kvFactor = su::deserialize<decltype(CacheState::AttentionConfig::mKvFactor)>(is);
-    return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, dataType,
-        attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
+    return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism,
+        contextParallelism, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
}

void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
@@ -548,6 +549,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
    su::serialize(state.mModelConfig.mTokensPerBlock, os);
    su::serialize(state.mParallelConfig.mTensorParallelism, os);
    su::serialize(state.mParallelConfig.mPipelineParallelism, os);
+    su::serialize(state.mParallelConfig.mContextParallelism, os);
    su::serialize(state.mParallelConfig.mEnableAttentionDP, os);
    su::serialize(state.mParallelConfig.mDPrank, os);
    su::serialize(state.mParallelConfig.mDPsize, os);
@@ -564,6 +566,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state)
    totalSize += su::serializedSize(state.mModelConfig.mTokensPerBlock);
    totalSize += su::serializedSize(state.mParallelConfig.mTensorParallelism);
    totalSize += su::serializedSize(state.mParallelConfig.mPipelineParallelism);
+    totalSize += su::serializedSize(state.mParallelConfig.mContextParallelism);
    totalSize += su::serializedSize(state.mParallelConfig.mEnableAttentionDP);
    totalSize += su::serializedSize(state.mParallelConfig.mDPrank);
    totalSize += su::serializedSize(state.mParallelConfig.mDPsize);
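
serialize, deserializeCacheState, and serializedSize must stay in lock-step on field order, which is why all three gain the mContextParallelism line at the same position. A minimal round-trip check along these lines would exercise the new field; this is a sketch with an assumed header path and namespace, not part of this commit.

// Sketch only: header path and namespace assumed from this repo's layout.
#include "tensorrt_llm/executor/serialization.h"

#include <cassert>
#include <sstream>

namespace exec = tensorrt_llm::executor;

// Serialize a CacheState, read it back, and confirm nothing was lost or reordered,
// including the newly serialized mContextParallelism field.
void roundTripCheck(exec::kv_cache::CacheState const& state)
{
    std::ostringstream os;
    exec::Serialization::serialize(state, os);
    assert(os.str().size() == exec::Serialization::serializedSize(state));

    std::istringstream is(os.str());
    auto restored = exec::Serialization::deserializeCacheState(is);
    assert(restored == state); // relies on CacheState::operator==, extended in this commit
}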
