6 changes: 3 additions & 3 deletions cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
@@ -72,20 +72,20 @@ class CacheTransceiver : public BaseCacheTransceiver
 public:
     CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager,
         executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig,
-        nvinfer1::DataType dataType,
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         executor::kv_cache::CacheState::AttentionType attentionType
         = executor::kv_cache::CacheState::AttentionType::kDEFAULT,
         std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt);
 
     CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, std::vector<SizeType32> numKvHeadsPerLayer,
         SizeType32 sizePerHead, SizeType32 tokensPerBlock, runtime::WorldConfig const& worldConfig,
-        nvinfer1::DataType dataType,
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         executor::kv_cache::CacheState::AttentionType attentionType
         = executor::kv_cache::CacheState::AttentionType::kDEFAULT,
         std::optional<executor::CacheTransceiverConfig> cacheTransceiverConfig = std::nullopt)
         : CacheTransceiver(cacheManager,
             executor::kv_cache::CacheState::ModelConfig{numKvHeadsPerLayer, sizePerHead, tokensPerBlock}, worldConfig,
-            dataType, attentionType, cacheTransceiverConfig)
+            attentionLayerNumPerPP, dataType, attentionType, cacheTransceiverConfig)
     {
     }
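Below is a minimal caller sketch of the updated delegating constructor. The surrounding names (cacheManager, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, worldConfig) and the even four-stage layer split are assumptions for illustration, not part of this PR; the point is simply that attentionLayerNumPerPP now slots in between worldConfig and dataType.

    // Sketch only (assumed context): 4 pipeline stages, 9 attention layers each.
    std::vector<SizeType32> attentionLayerNumPerPP{9, 9, 9, 9};
    CacheTransceiver transceiver(cacheManager, numKvHeadsPerLayer, sizePerHead,
        tokensPerBlock, worldConfig, attentionLayerNumPerPP, nvinfer1::DataType::kHALF);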
25 changes: 16 additions & 9 deletions cpp/include/tensorrt_llm/executor/dataTransceiverState.h
@@ -48,34 +48,39 @@ class CacheState final
         kMLA = 1,
     };
 
-    CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType,
+    CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig,
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
             worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
-            worldConfig.getTensorParallelism()}
+            worldConfig.getTensorParallelism(), attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }
 
     CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
-        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
-        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
+        int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
+            attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
     }
 
     CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
-        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
-        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
+        std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
+        int DPrank = 0, int DPsize = 0)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
-        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
+        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
+            attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
@@ -108,12 +113,14 @@ class CacheState final
         bool mEnableAttentionDP;
         SizeType32 mDPrank;
         SizeType32 mDPsize;
+        std::vector<SizeType32> mAttentionLayerNumPerPP;
 
         [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
         {
             return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
                 && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
-                && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
+                && mDPrank == other.mDPrank && mDPsize == other.mDPsize
+                && mAttentionLayerNumPerPP == other.mAttentionLayerNumPerPP;
         }
     };
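For the third CacheState constructor above, a hedged usage sketch (all concrete values are illustrative assumptions): attentionLayerNumPerPP presumably carries one entry per pipeline-parallel rank, and, as the operator== change shows, two cache states now compare unequal when their per-stage layer splits differ.

    // Sketch only: a hypothetical 36-layer model split across 3 pipeline stages.
    // One entry per PP rank; entries presumably sum to nbAttentionLayers.
    std::vector<SizeType32> attentionLayerNumPerPP{12, 12, 12};
    executor::kv_cache::CacheState state(
        /*nbAttentionLayers=*/36, /*nbKvHeads=*/8, /*sizePerHead=*/128,
        /*tokensPerBlock=*/64, /*tensorParallelism=*/2, /*pipelineParallelism=*/3,
        /*contextParallelism=*/1, attentionLayerNumPerPP, nvinfer1::DataType::kHALF);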
