@@ -266,25 +266,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
266266 EXPECT_EQ (validResult.getDimension (), 2u );
267267}
268268
269- // Helper to create a minimal function and embedder for getter tests
270- struct GetterTestEnv {
271- Vocab V = {};
269+ // Fixture for IR2Vec tests requiring IR setup and weight management.
270+ class IR2VecTestFixture : public ::testing::Test {
271+ protected:
272+ Vocab V;
272273 LLVMContext Ctx;
273- std::unique_ptr<Module> M = nullptr ;
274+ std::unique_ptr<Module> M;
274275 Function *F = nullptr ;
275276 BasicBlock *BB = nullptr ;
276- Instruction *Add = nullptr ;
277- Instruction *Ret = nullptr ;
278- std::unique_ptr<Embedder> Emb = nullptr ;
277+ Instruction *AddInst = nullptr ;
278+ Instruction *RetInst = nullptr ;
279279
280- GetterTestEnv () {
280+ float OriginalOpcWeight = ::OpcWeight;
281+ float OriginalTypeWeight = ::TypeWeight;
282+ float OriginalArgWeight = ::ArgWeight;
283+
284+ void SetUp () override {
281285 V = {{" add" , {1.0 , 2.0 }},
282286 {" integerTy" , {0.5 , 0.5 }},
283287 {" constant" , {0.2 , 0.3 }},
284288 {" variable" , {0.0 , 0.0 }},
285289 {" unknownTy" , {0.0 , 0.0 }}};
286290
287- M = std::make_unique<Module>(" M" , Ctx);
291+ // Setup IR
292+ M = std::make_unique<Module>(" TestM" , Ctx);
288293 FunctionType *FTy = FunctionType::get (
289294 Type::getInt32Ty (Ctx), {Type::getInt32Ty (Ctx), Type::getInt32Ty (Ctx)},
290295 false );
@@ -293,61 +298,82 @@ struct GetterTestEnv {
293298 Argument *Arg = F->getArg (0 );
294299 llvm::Value *Const = ConstantInt::get (Type::getInt32Ty (Ctx), 42 );
295300
296- Add = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
297- Ret = ReturnInst::Create (Ctx, Add, BB);
301+ AddInst = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
302+ RetInst = ReturnInst::Create (Ctx, AddInst, BB);
303+ }
304+
305+ void setWeights (float OpcWeight, float TypeWeight, float ArgWeight) {
306+ ::OpcWeight = OpcWeight;
307+ ::TypeWeight = TypeWeight;
308+ ::ArgWeight = ArgWeight;
309+ }
298310
299- auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
300- EXPECT_TRUE (static_cast <bool >(Result));
301- Emb = std::move (*Result);
311+ void TearDown () override {
312+ // Restore original global weights
313+ ::OpcWeight = OriginalOpcWeight;
314+ ::TypeWeight = OriginalTypeWeight;
315+ ::ArgWeight = OriginalArgWeight;
302316 }
303317};
304318
305- TEST (IR2VecTest, GetInstVecMap) {
306- GetterTestEnv Env;
307- const auto &InstMap = Env.Emb ->getInstVecMap ();
319+ TEST_F (IR2VecTestFixture, GetInstVecMap) {
320+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
321+ ASSERT_TRUE (static_cast <bool >(Result));
322+ auto Emb = std::move (*Result);
323+
324+ const auto &InstMap = Emb->getInstVecMap ();
308325
309326 EXPECT_EQ (InstMap.size (), 2u );
310- EXPECT_TRUE (InstMap.count (Env. Add ));
311- EXPECT_TRUE (InstMap.count (Env. Ret ));
327+ EXPECT_TRUE (InstMap.count (AddInst ));
328+ EXPECT_TRUE (InstMap.count (RetInst ));
312329
313- EXPECT_EQ (InstMap.at (Env. Add ).size (), 2u );
314- EXPECT_EQ (InstMap.at (Env. Ret ).size (), 2u );
330+ EXPECT_EQ (InstMap.at (AddInst ).size (), 2u );
331+ EXPECT_EQ (InstMap.at (RetInst ).size (), 2u );
315332
316333 // Check values for add: {1.29, 2.31}
317- EXPECT_THAT (InstMap.at (Env. Add ),
334+ EXPECT_THAT (InstMap.at (AddInst ),
318335 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
319336
320337 // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
321338 // vocab
322- EXPECT_THAT (InstMap.at (Env. Ret ), ElementsAre (0.0 , 0.0 ));
339+ EXPECT_THAT (InstMap.at (RetInst ), ElementsAre (0.0 , 0.0 ));
323340}
324341
325- TEST (IR2VecTest, GetBBVecMap) {
326- GetterTestEnv Env;
327- const auto &BBMap = Env.Emb ->getBBVecMap ();
342+ TEST_F (IR2VecTestFixture, GetBBVecMap) {
343+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
344+ ASSERT_TRUE (static_cast <bool >(Result));
345+ auto Emb = std::move (*Result);
346+
347+ const auto &BBMap = Emb->getBBVecMap ();
328348
329349 EXPECT_EQ (BBMap.size (), 1u );
330- EXPECT_TRUE (BBMap.count (Env. BB ));
331- EXPECT_EQ (BBMap.at (Env. BB ).size (), 2u );
350+ EXPECT_TRUE (BBMap.count (BB));
351+ EXPECT_EQ (BBMap.at (BB).size (), 2u );
332352
333353 // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
334354 // {1.29, 2.31}
335- EXPECT_THAT (BBMap.at (Env. BB ),
355+ EXPECT_THAT (BBMap.at (BB),
336356 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
337357}
338358
339- TEST (IR2VecTest, GetBBVector) {
340- GetterTestEnv Env;
341- const auto &BBVec = Env.Emb ->getBBVector (*Env.BB );
359+ TEST_F (IR2VecTestFixture, GetBBVector) {
360+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
361+ ASSERT_TRUE (static_cast <bool >(Result));
362+ auto Emb = std::move (*Result);
363+
364+ const auto &BBVec = Emb->getBBVector (*BB);
342365
343366 EXPECT_EQ (BBVec.size (), 2u );
344367 EXPECT_THAT (BBVec,
345368 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
346369}
347370
348- TEST (IR2VecTest, GetFunctionVector) {
349- GetterTestEnv Env;
350- const auto &FuncVec = Env.Emb ->getFunctionVector ();
371+ TEST_F (IR2VecTestFixture, GetFunctionVector) {
372+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
373+ ASSERT_TRUE (static_cast <bool >(Result));
374+ auto Emb = std::move (*Result);
375+
376+ const auto &FuncVec = Emb->getFunctionVector ();
351377
352378 EXPECT_EQ (FuncVec.size (), 2u );
353379
@@ -356,4 +382,45 @@ TEST(IR2VecTest, GetFunctionVector) {
356382 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
357383}
358384
385+ TEST_F (IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
386+ setWeights (1.0 , 1.0 , 1.0 );
387+
388+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
389+ ASSERT_TRUE (static_cast <bool >(Result));
390+ auto Emb = std::move (*Result);
391+
392+ const auto &FuncVec = Emb->getFunctionVector ();
393+
394+ EXPECT_EQ (FuncVec.size (), 2u );
395+
396+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
397+ // 0.3] + [0.0 0.0])
398+ EXPECT_THAT (FuncVec,
399+ ElementsAre (DoubleNear (1.7 , 1e-6 ), DoubleNear (2.8 , 1e-6 )));
400+ }
401+
402+ TEST (IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
403+ Vocab InitialVocab = {{" key1" , {1.1 , 2.2 }}, {" key2" , {3.3 , 4.4 }}};
404+ Vocab ExpectedVocab = InitialVocab;
405+ unsigned ExpectedDim = InitialVocab.begin ()->second .size ();
406+
407+ IR2VecVocabAnalysis VocabAnalysis (std::move (InitialVocab));
408+
409+ LLVMContext TestCtx;
410+ Module TestMod (" TestModuleForVocabAnalysis" , TestCtx);
411+ ModuleAnalysisManager MAM;
412+ IR2VecVocabResult Result = VocabAnalysis.run (TestMod, MAM);
413+
414+ EXPECT_TRUE (Result.isValid ());
415+ ASSERT_FALSE (Result.getVocabulary ().empty ());
416+ EXPECT_EQ (Result.getDimension (), ExpectedDim);
417+
418+ const auto &ResultVocab = Result.getVocabulary ();
419+ EXPECT_EQ (ResultVocab.size (), ExpectedVocab.size ());
420+ for (const auto &pair : ExpectedVocab) {
421+ EXPECT_TRUE (ResultVocab.count (pair.first ));
422+ EXPECT_THAT (ResultVocab.at (pair.first ), ElementsAreArray (pair.second ));
423+ }
424+ }
425+
359426} // end anonymous namespace
0 commit comments