@@ -1438,9 +1438,9 @@ TEST(targetTest, CacheStateNODP)
14381438    bool  const  isMLA = true ;
14391439    int  const  kvFactor = 2 ;
14401440
1441-     auto  const  verifyContext
1442-         = [&]( int  contextRank, tr::WorldConfig  const & contextWC, tr::WorldConfig  const & genWC ,
1443-               std::vector< int >  const & expectRanks,  int  expectPPDomain,  int  expectTPDomain,  int  expectCPDomain, bool  expectNeedSend)
1441+     auto  const  verifyContext = [&]( int  contextRank, tr::WorldConfig  const & contextWC, tr::WorldConfig  const & genWC, 
1442+                                    std::vector< int >  const & expectRanks,  int  expectPPDomain,  int  expectTPDomain ,
1443+                                     int  expectCPDomain, bool  expectNeedSend)
14441444    {
14451445        auto  attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA 
14461446                                   : texec::kv_cache::CacheState::AttentionType::kDEFAULT ;
@@ -1526,11 +1526,13 @@ TEST(targetTest, CacheStateNODP)
15261526    //      tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 1, /*cpSize*/ 1};
15271527    //      tr::WorldConfig const genWC{/*tpSize*/ 1, /*ppSize*/ 2, /*cpSize*/ 1};
15281528    //      verifyContext(
1529-     //          /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 2,
1529+     //          /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/
1530+     //          2,
15301531    //          /*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ true);
15311532    //      // TODO: Figure why needSendCache is false here.
15321533    //      verifyContext(
1533-     //          /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 2,
1534+     //          /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 1}, /*expectPPDomain*/
1535+     //          2,
15341536    //          /*expectTPDomain*/ 1, /*expectCPDomain*/ 1, /*expectNeedSend*/ false);
15351537    //  }
15361538
@@ -1557,16 +1559,20 @@ TEST(targetTest, CacheStateNODP)
15571559        tr::WorldConfig const  contextWC{/* tpSize*/ 2 , /* ppSize*/ 2 , /* cpSize*/ 1 };
15581560        tr::WorldConfig const  genWC{/* tpSize*/ 4 , /* ppSize*/ 2 , /* cpSize*/ 2 };
15591561        verifyContext (
1560-             /* contextRank*/ 0 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 1 , 5 }, /* expectPPDomain*/ 1 ,
1562+             /* contextRank*/ 0 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 1 , 5 },
1563+             /* expectPPDomain*/ 1 ,
15611564            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
15621565        verifyContext (
1563-             /* contextRank*/ 1 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 2 , 6 , 3 , 7 }, /* expectPPDomain*/ 1 ,
1566+             /* contextRank*/ 1 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 2 , 6 , 3 , 7 },
1567+             /* expectPPDomain*/ 1 ,
15641568            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
15651569        verifyContext (
1566-             /* contextRank*/ 2 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 8 , 12 , 9 , 13 }, /* expectPPDomain*/ 1 ,
1570+             /* contextRank*/ 2 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 8 , 12 , 9 , 13 },
1571+             /* expectPPDomain*/ 1 ,
15671572            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
15681573        verifyContext (
1569-             /* contextRank*/ 3 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 10 , 14 , 11 , 15 }, /* expectPPDomain*/ 1 ,
1574+             /* contextRank*/ 3 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 10 , 14 , 11 , 15 },
1575+             /* expectPPDomain*/ 1 ,
15701576            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
15711577    }
15721578
@@ -1593,16 +1599,20 @@ TEST(targetTest, CacheStateNODP)
15931599        tr::WorldConfig const  contextWC{/* tpSize*/ 2 , /* ppSize*/ 2 , /* cpSize*/ 1 };
15941600        tr::WorldConfig const  genWC{/* tpSize*/ 2 , /* ppSize*/ 4 , /* cpSize*/ 2 };
15951601        verifyContext (
1596-             /* contextRank*/ 0 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 2 , 6 }, /* expectPPDomain*/ 2 ,
1602+             /* contextRank*/ 0 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 2 , 6 },
1603+             /* expectPPDomain*/ 2 ,
15971604            /* expectTPDomain*/ 1 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
15981605        verifyContext (
1599-             /* contextRank*/ 1 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 1 , 5 , 3 , 7 }, /* expectPPDomain*/ 2 ,
1606+             /* contextRank*/ 1 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 1 , 5 , 3 , 7 },
1607+             /* expectPPDomain*/ 2 ,
16001608            /* expectTPDomain*/ 1 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
16011609        verifyContext (
1602-             /* contextRank*/ 2 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 8 , 12 , 10 , 14 }, /* expectPPDomain*/ 2 ,
1610+             /* contextRank*/ 2 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 8 , 12 , 10 , 14 },
1611+             /* expectPPDomain*/ 2 ,
16031612            /* expectTPDomain*/ 1 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
16041613        verifyContext (
1605-             /* contextRank*/ 3 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 9 , 13 , 11 , 15 }, /* expectPPDomain*/ 2 ,
1614+             /* contextRank*/ 3 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 9 , 13 , 11 , 15 },
1615+             /* expectPPDomain*/ 2 ,
16061616            /* expectTPDomain*/ 1 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
16071617    }
16081618
@@ -1671,16 +1681,20 @@ TEST(targetTest, CacheStateNODP)
16711681        tr::WorldConfig const  contextWC{/* tpSize*/ 2 , /* ppSize*/ 2 , /* cpSize*/ 1 };
16721682        tr::WorldConfig const  genWC{/* tpSize*/ 4 , /* ppSize*/ 1 , /* cpSize*/ 2 };
16731683        verifyContext (
1674-             /* contextRank*/ 0 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 1 , 5 }, /* expectPPDomain*/ 1 ,
1684+             /* contextRank*/ 0 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 1 , 5 },
1685+             /* expectPPDomain*/ 1 ,
16751686            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
16761687        verifyContext (
1677-             /* contextRank*/ 1 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 2 , 6 , 3 , 7 }, /* expectPPDomain*/ 1 ,
1688+             /* contextRank*/ 1 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 2 , 6 , 3 , 7 },
1689+             /* expectPPDomain*/ 1 ,
16781690            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
16791691        verifyContext (
1680-             /* contextRank*/ 2 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 1 , 5 }, /* expectPPDomain*/ 1 ,
1692+             /* contextRank*/ 2 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 0 , 4 , 1 , 5 },
1693+             /* expectPPDomain*/ 1 ,
16811694            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
16821695        verifyContext (
1683-             /* contextRank*/ 3 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 2 , 6 , 3 , 7 }, /* expectPPDomain*/ 1 ,
1696+             /* contextRank*/ 3 , /* contextWC*/ /* genWC*/ /* expectRanks*/ 2 , 6 , 3 , 7 },
1697+             /* expectPPDomain*/ 1 ,
16841698            /* expectTPDomain*/ 2 , /* expectCPDomain*/ 2 , /* expectNeedSend*/ true );
16851699    }
16861700
@@ -1689,11 +1703,13 @@ TEST(targetTest, CacheStateNODP)
16891703    //      tr::WorldConfig const contextWC{/*tpSize*/ 2, /*ppSize*/ 1, /*cpSize*/ 1};
16901704    //      tr::WorldConfig const genWC{/*tpSize*/ 1, /*ppSize*/ 2, /*cpSize*/ 4};
16911705    //      verifyContext(
1692-     //          /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5, 2, 6, 3, 7}, /*expectPPDomain*/ 2,
1706+     //          /*contextRank*/ 0, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5, 2, 6, 3, 7},
1707+     //          /*expectPPDomain*/ 2,
16931708    //          /*expectTPDomain*/ 1, /*expectCPDomain*/ 4, /*expectNeedSend*/ true);
16941709    //      // TODO: Figure why needSendCache is false here.
16951710    //      verifyContext(
1696-     //          /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5, 2, 6, 3, 7}, /*expectPPDomain*/ 2,
1711+     //          /*contextRank*/ 1, /*contextWC*/ contextWC, /*genWC*/ genWC, /*expectRanks*/ {0, 4, 1, 5, 2, 6, 3, 7},
1712+     //          /*expectPPDomain*/ 2,
16971713    //          /*expectTPDomain*/ 1, /*expectCPDomain*/ 4, /*expectNeedSend*/ false);
16981714    //  }
16991715}
0 commit comments