@@ -52,29 +52,30 @@ class CacheState final
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
-              worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), worldConfig.getTensorParallelism()}
+              worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
+              worldConfig.getTensorParallelism()}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }
 
     CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
-        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0)
+        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
+        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
+        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }
 
     CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
-        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
-        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0)
+        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
+        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
+        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
@@ -83,7 +84,7 @@ class CacheState final
     [[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept
     {
         return mModelConfig == other.mModelConfig && mParallelConfig == other.mParallelConfig
-            && mDataType == other.mDataType;
+            && mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType;
     }
 
     struct ModelConfig
@@ -103,15 +104,16 @@ class CacheState final
     {
         SizeType32 mTensorParallelism;
         SizeType32 mPipelineParallelism;
+        SizeType32 mContextParallelism;
         bool mEnableAttentionDP;
         SizeType32 mDPrank;
         SizeType32 mDPsize;
 
         [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
         {
             return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
-                && mEnableAttentionDP == other.mEnableAttentionDP && mDPrank == other.mDPrank
-                && mDPsize == other.mDPsize;
+                && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
+                && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
         }
     };
 
@@ -125,6 +127,11 @@ class CacheState final
     {
     }
 
+    [[nodiscard]] bool operator==(AttentionConfig const& other) const noexcept
+    {
+        return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor;
+    }
+
     // attentionType;
     AttentionType mAttentionType;
     int mKvFactor;
@@ -162,6 +169,7 @@ class CacheState final
         sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n";
         sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n";
         sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n";
+        sstring << "cp:" << mParallelConfig.mContextParallelism << "\n";
         sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n";
         sstring << "datatype:" << static_cast<int32_t>(mDataType) << "\n";
         sstring << "attentionType:" << static_cast<int32_t>(mAttentionConfig.mAttentionType) << "\n";
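The practical effect of the new field on equality checks can be illustrated with a stripped-down, self-contained mock of the struct touched above. The struct and member names mirror the diff, but this mock is a sketch for illustration only, not the real TensorRT-LLM header:

```cpp
#include <cassert>
#include <cstdint>

using SizeType32 = std::int32_t;

// Illustrative mirror of ParallelConfig after this change; the real struct
// is nested inside CacheState in the TensorRT-LLM executor headers.
struct ParallelConfig
{
    SizeType32 mTensorParallelism;
    SizeType32 mPipelineParallelism;
    SizeType32 mContextParallelism; // field added by this diff
    bool mEnableAttentionDP;
    SizeType32 mDPrank;
    SizeType32 mDPsize;

    [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
    {
        return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
            && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
            && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
    }
};

int main()
{
    // Two configs identical except for context parallelism: before this
    // change the field did not exist, so they would have compared equal;
    // after it they compare unequal.
    ParallelConfig a{/*tp*/ 2, /*pp*/ 1, /*cp*/ 1, /*dp*/ false, 0, 0};
    ParallelConfig b{/*tp*/ 2, /*pp*/ 1, /*cp*/ 2, /*dp*/ false, 0, 0};
    assert(!(a == b));
    return 0;
}
```

The same reasoning applies to the new AttentionConfig::operator==: two CacheState instances differing only in attention type or kvFactor previously compared equal, since CacheState::operator== ignored mAttentionConfig; with this diff they no longer do.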