@@ -52,29 +52,30 @@ class CacheState final
5252        AttentionType attentionType = AttentionType::kDEFAULT , int  kvFactor = 2 )
5353        : mModelConfig (std::move(modelConfig))
5454        , mParallelConfig {worldConfig.getTensorParallelism (), worldConfig.getPipelineParallelism (),
55-               worldConfig.enableAttentionDP (), worldConfig.getTensorParallelRank (), worldConfig.getTensorParallelism ()}
55+               worldConfig.getContextParallelism (), worldConfig.enableAttentionDP (), worldConfig.getTensorParallelRank (),
56+               worldConfig.getTensorParallelism ()}
5657        , mDataType {dataType}
5758        , mAttentionConfig (attentionType, kvFactor)
5859    {
5960    }
6061
6162    CacheState (std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
62-         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType ,
63-         AttentionType attentionType = AttentionType::kDEFAULT , int  kvFactor = 2 ,  bool  enableAttentionDP =  false ,
64-         int  DPrank = 0 , int  DPsize = 0 )
63+         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism ,
64+         nvinfer1::DataType dataType,  AttentionType attentionType = AttentionType::kDEFAULT , int  kvFactor = 2 ,
65+         bool  enableAttentionDP =  false ,  int  DPrank = 0 , int  DPsize = 0 )
6566        : mModelConfig {std::move (nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
66-         , mParallelConfig {tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
67+         , mParallelConfig {tensorParallelism, pipelineParallelism, contextParallelism,  enableAttentionDP, DPrank, DPsize}
6768        , mDataType {dataType}
6869        , mAttentionConfig (attentionType, kvFactor)
6970    {
7071    }
7172
7273    CacheState (SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
73-         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType ,
74-         AttentionType attentionType = AttentionType::kDEFAULT , int  kvFactor = 2 ,  bool  enableAttentionDP =  false ,
75-         int  DPrank = 0 , int  DPsize = 0 )
74+         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism ,
75+         nvinfer1::DataType dataType,  AttentionType attentionType = AttentionType::kDEFAULT , int  kvFactor = 2 ,
76+         bool  enableAttentionDP =  false ,  int  DPrank = 0 , int  DPsize = 0 )
7677        : mModelConfig {std::vector (nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
77-         , mParallelConfig {tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
78+         , mParallelConfig {tensorParallelism, pipelineParallelism, contextParallelism,  enableAttentionDP, DPrank, DPsize}
7879        , mDataType {dataType}
7980        , mAttentionConfig (attentionType, kvFactor)
8081    {
@@ -83,7 +84,7 @@ class CacheState final
8384    [[nodiscard]] bool  operator ==(kv_cache::CacheState const & other) const  noexcept 
8485    {
8586        return  mModelConfig  == other.mModelConfig  && mParallelConfig  == other.mParallelConfig 
86-             && mDataType  == other.mDataType ;
87+             && mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType;
8788    }
8889
8990    struct  ModelConfig 
@@ -103,15 +104,16 @@ class CacheState final
103104    {
104105        SizeType32 mTensorParallelism ;
105106        SizeType32 mPipelineParallelism ;
107+         SizeType32 mContextParallelism ;
106108        bool  mEnableAttentionDP ;
107109        SizeType32 mDPrank ;
108110        SizeType32 mDPsize ;
109111
110112        [[nodiscard]] bool  operator ==(ParallelConfig const & other) const  noexcept 
111113        {
112114            return  mTensorParallelism  == other.mTensorParallelism  && mPipelineParallelism  == other.mPipelineParallelism 
113-                 && mEnableAttentionDP  == other.mEnableAttentionDP  && mDPrank  == other.mDPrank 
114-                 && mDPsize  == other.mDPsize ;
115+                 && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
116+                 && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
115117        }
116118    };
117119
@@ -125,6 +127,11 @@ class CacheState final
125127        {
126128        }
127129
130+         [[nodiscard]] bool  operator ==(AttentionConfig const & other) const  noexcept 
131+         {
132+             return  mAttentionType  == other.mAttentionType  && mKvFactor  == other.mKvFactor ;
133+         }
134+ 
128135        //  attentionType ;
129136        AttentionType mAttentionType ;
130137        int  mKvFactor ;
@@ -162,6 +169,7 @@ class CacheState final
162169        sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n";
163170        sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n";
164171        sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n";
172+         sstring << "cp:" << mParallelConfig.mContextParallelism << "\n";
165173        sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n";
166174        sstring << "datatype:" << static_cast<int32_t>(mDataType) << "\n";
167175        sstring << "attentionType:" << static_cast<int32_t>(mAttentionConfig.mAttentionType) << "\n";
0 commit comments