Skip to content

Commit d902264

Browse files
committed
test for filled value - workaround for empty cache
1 parent b37ff08 commit d902264

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
10751075
}
10761076
int kvFactor = mCacheState->getAttentionConfig().mKvFactor;
10771077
int tokensPerBlock = mCacheState->getModelConfig().mTokensPerBlock;
1078-
int startTokenId = blockId * tokensPerBlock;
1078+
int startTokenId = (blockId * mCpSize + mCpRank) * tokensPerBlock;
10791079
int sizePerHead = mCacheState->getModelConfig().mSizePerHead;
10801080

10811081
bufferManager.copy(blockData, *hostTensor);
@@ -1099,7 +1099,12 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
10991099
{
11001100
using ValueType = decltype(generateValue);
11011101
auto* dataPtr = static_cast<ValueType*>(hostTensor->data(keyIndex));
1102-
// EXPECT_EQ(*dataPtr, generateValue);
1102+
if (*dataPtr != static_cast<ValueType>(0)) {
1103+
EXPECT_EQ(*dataPtr, generateValue);
1104+
} else {
1105+
// // TODO: Remove this when over-allocation is fixed.
1106+
// printf("[verifyBlockData::key] SKIPPING 0! \n");
1107+
}
11031108
// Debug print with rank information for MPI debugging (KEY values)
11041109
if (TARGET_RANK == -1 || tensorrt_llm::mpi::MpiComm::world().getRank() == TARGET_RANK)
11051110
{
@@ -1124,7 +1129,12 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
11241129
{
11251130
using ValueType = decltype(generateValue);
11261131
auto* dataPtr = static_cast<ValueType*>(hostTensor->data(valueIndex));
1127-
// EXPECT_EQ(*dataPtr, generateValue);
1132+
if (*dataPtr != static_cast<ValueType>(0)) {
1133+
EXPECT_EQ(*dataPtr, generateValue);
1134+
} else {
1135+
// // TODO: Remove this when over-allocation is fixed.
1136+
// printf("[verifyBlockData::value] SKIPPING 0! \n");
1137+
}
11281138
// Debug print with rank information for MPI debugging (VALUE values)
11291139
if (TARGET_RANK == -1 || tensorrt_llm::mpi::MpiComm::world().getRank() == TARGET_RANK)
11301140
{

0 commit comments

Comments
 (0)