Skip to content

Commit bd978af

Browse files
committed
save more tests
1 parent c709489 commit bd978af

File tree

2 files changed

+70
-27
lines changed

2 files changed

+70
-27
lines changed

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ TargetRanksInfo TargetRanksInfoForDP(
114114
int mDomainCPSize = 1;
115115
int peerCPRankStart = 0;
116116
int peerCPRankEnd = 0;
117+
for (auto val : {peerCPNum, selfCPNum})
118+
{
119+
TLLM_CHECK(isPowerOfTwo(val));
120+
}
117121
if (selfCPNum <= peerCPNum)
118122
{
119123
mDomainCPSize = peerCPNum / selfCPNum;

cpp/tests/batch_manager/cacheTransceiverTest.cpp

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,7 +1440,7 @@ TEST(targetTest, CacheStateNODP)
14401440

14411441
auto const verifyContext
14421442
= [&](int contextRank, tr::WorldConfig const& contextWC, tr::WorldConfig const& genWC,
1443-
std::vector<int> const& expectRanks, int expectPPDomain, int expectTPDomain, bool expectNeedSend)
1443+
std::vector<int> const& expectRanks, int expectPPDomain, int expectTPDomain, int expectCPDomain, bool expectNeedSend)
14441444
{
14451445
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
14461446
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
@@ -1457,6 +1457,7 @@ TEST(targetTest, CacheStateNODP)
14571457
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
14581458
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
14591459
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
1460+
EXPECT_EQ(expectCPDomain, contextTargetInfo.mDomainCPSize);
14601461
EXPECT_EQ(expectNeedSend, MLACacheFormatter::needSendCache(contextCache, genCache, contextRank));
14611462
};
14621463

@@ -1466,28 +1467,28 @@ TEST(targetTest, CacheStateNODP)
14661467
tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1};
14671468
verifyContext(
14681469
/*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0}, /*expectPPDomain*/ 1,
1469-
/*expectTPDomain*/ 1, /*expectNeedSend*/ true);
1470+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
14701471
verifyContext(
14711472
/*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0}, /*expectPPDomain*/ 1,
1472-
/*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1473+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ false);
14731474
verifyContext(
14741475
/*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1}, /*expectPPDomain*/ 1,
1475-
/*expectTPDomain*/ 1, /*expectNeedSend*/ true);
1476+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
14761477
verifyContext(
14771478
/*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1}, /*expectPPDomain*/ 1,
1478-
/*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1479+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ false);
14791480
verifyContext(
14801481
/*contextRank*/ 4, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2}, /*expectPPDomain*/ 1,
1481-
/*expectTPDomain*/ 1, /*expectNeedSend*/ true);
1482+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
14821483
verifyContext(
14831484
/*contextRank*/ 5, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2}, /*expectPPDomain*/ 1,
1484-
/*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1485+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ false);
14851486
verifyContext(
14861487
/*contextRank*/ 6, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {3}, /*expectPPDomain*/ 1,
1487-
/*expectTPDomain*/ 1, /*expectNeedSend*/ true);
1488+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
14881489
verifyContext(
14891490
/*contextRank*/ 7, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {3}, /*expectPPDomain*/ 1,
1490-
/*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1491+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ false);
14911492
}
14921493

14931494
// TP grows from context to generation.
@@ -1496,16 +1497,16 @@ TEST(targetTest, CacheStateNODP)
14961497
tr::WorldConfig const genWC{/*tpSize*/ 4, /*ppSize*/ 2, /*cpSize*/ 1};
14971498
verifyContext(
14981499
/*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1,
1499-
/*expectTPDomain*/ 2, /*expectNeedSend*/ true);
1500+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
15001501
verifyContext(
15011502
/*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1,
1502-
/*expectTPDomain*/ 2, /*expectNeedSend*/ true);
1503+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
15031504
verifyContext(
15041505
/*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 5}, /*expectPPDomain*/ 1,
1505-
/*expectTPDomain*/ 2, /*expectNeedSend*/ true);
1506+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
15061507
verifyContext(
15071508
/*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {6, 7}, /*expectPPDomain*/ 1,
1508-
/*expectTPDomain*/ 2, /*expectNeedSend*/ true);
1509+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
15091510
}
15101511

15111512
// TP as well as PP grow from context to generation.
@@ -1514,21 +1515,46 @@ TEST(targetTest, CacheStateNODP)
15141515
tr::WorldConfig const genWC{/*tpSize*/ 4, /*ppSize*/ 2, /*cpSize*/ 1};
15151516
verifyContext(
15161517
/*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5},
1517-
/*expectPPDomain*/ 2,
1518-
/*expectTPDomain*/ 2, /*expectNeedSend*/ true);
1518+
/*expectPPDomain*/ 2, /*expectTPDomain*/ 2, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
15191519
verifyContext(
15201520
/*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 6, 3, 7},
1521-
/*expectPPDomain*/ 2,
1522-
/*expectTPDomain*/ 2, /*expectNeedSend*/ true);
1521+
/*expectPPDomain*/ 2, /*expectTPDomain*/ 2, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
15231522
}
15241523

15251524
// CP grows from context to generation.
15261525
{
1527-
tr::WorldConfig const contextWC{/*tpSize*/ 1, /*ppSize*/ 1, /*cpSize*/ 1};
1528-
tr::WorldConfig const genWC{/*tpSize*/ 1, /*ppSize*/ 1, /*cpSize*/ 2};
1526+
tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1};
1527+
tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 2};
1528+
verifyContext(
1529+
/*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 2},
1530+
/*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1531+
verifyContext(
1532+
/*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 3},
1533+
/*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1534+
verifyContext(
1535+
/*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {4, 6},
1536+
/*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1537+
verifyContext(
1538+
/*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {5, 7},
1539+
/*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1540+
}
1541+
1542+
// TP as well as CP grow from context to generation.
1543+
{
1544+
tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1};
1545+
tr::WorldConfig const genWC{/*tpSize*/ 4, /*ppSize*/ 2, /*cpSize*/ 2};
1546+
verifyContext(
1547+
/*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5}, /*expectPPDomain*/ 1,
1548+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1549+
verifyContext(
1550+
/*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {2, 6, 3, 7}, /*expectPPDomain*/ 1,
1551+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1552+
verifyContext(
1553+
/*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {8, 12, 9, 13}, /*expectPPDomain*/ 1,
1554+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
15291555
verifyContext(
1530-
/*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1},
1531-
/*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
1556+
/*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {10, 14, 11, 15}, /*expectPPDomain*/ 1,
1557+
/*expectTPDomain*/ 2, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
15321558
}
15331559

15341560
// // TP shrinks while CP grows from context to generation.
@@ -1537,17 +1563,30 @@ TEST(targetTest, CacheStateNODP)
15371563
// tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 1, /*cpSize*/ 2};
15381564
// }
15391565

1540-
// // TP grows while CP shrinks from context to generation.
1541-
// {
1542-
// tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1};
1543-
// tr::WorldConfig const genWC{/*tpSize*/ 4, /*ppSize*/ 2, /*cpSize*/ 4};
1544-
// }
1566+
// PP as well as CP grow from context to generation.
1567+
{
1568+
tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 2, /*cpSize*/ 1};
1569+
tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 4, /*cpSize*/ 2};
1570+
verifyContext(
1571+
/*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 2, 6}, /*expectPPDomain*/ 2,
1572+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1573+
verifyContext(
1574+
/*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {1, 5, 3, 7}, /*expectPPDomain*/ 2,
1575+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1576+
verifyContext(
1577+
/*contextRank*/ 2, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {8, 12, 10, 14}, /*expectPPDomain*/ 2,
1578+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1579+
verifyContext(
1580+
/*contextRank*/ 3, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {9, 13, 11, 15}, /*expectPPDomain*/ 2,
1581+
/*expectTPDomain*/ 1, /*expectCPDomain*/ 2, /*expectNeedSend*/ true);
1582+
}
15451583

1546-
// // TP shrinks while CP grows from context to generation.
1584+
// // PP shrinks while CP grows from context to generation.
15471585
// {
15481586
// tr::WorldConfig const contextWC{/*tpSize*/ 4, /*ppSize*/ 1, /*cpSize*/ 1};
15491587
// tr::WorldConfig const genWC{/*tpSize*/ 2, /*ppSize*/ 1, /*cpSize*/ 2};
15501588
// }
1589+
15511590
}
15521591

15531592
TEST(targetTest, CacheStateContextDP)

0 commit comments

Comments
 (0)