@@ -199,25 +199,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
199199 EXPECT_EQ (validResult.getDimension (), 2u );
200200}
201201
202- // Helper to create a minimal function and embedder for getter tests
203- struct GetterTestEnv {
204- Vocab V = {};
202+ // Fixture for IR2Vec tests requiring IR setup and weight management.
203+ class IR2VecTestFixture : public ::testing::Test {
204+ protected:
205+ Vocab V;
205206 LLVMContext Ctx;
206- std::unique_ptr<Module> M = nullptr ;
207+ std::unique_ptr<Module> M;
207208 Function *F = nullptr ;
208209 BasicBlock *BB = nullptr ;
209- Instruction *Add = nullptr ;
210- Instruction *Ret = nullptr ;
211- std::unique_ptr<Embedder> Emb = nullptr ;
210+ Instruction *AddInst = nullptr ;
211+ Instruction *RetInst = nullptr ;
212212
213- GetterTestEnv () {
213+ float OriginalOpcWeight = ::OpcWeight;
214+ float OriginalTypeWeight = ::TypeWeight;
215+ float OriginalArgWeight = ::ArgWeight;
216+
217+ void SetUp () override {
214218 V = {{" add" , {1.0 , 2.0 }},
215219 {" integerTy" , {0.5 , 0.5 }},
216220 {" constant" , {0.2 , 0.3 }},
217221 {" variable" , {0.0 , 0.0 }},
218222 {" unknownTy" , {0.0 , 0.0 }}};
219223
220- M = std::make_unique<Module>(" M" , Ctx);
224+ // Setup IR
225+ M = std::make_unique<Module>(" TestM" , Ctx);
221226 FunctionType *FTy = FunctionType::get (
222227 Type::getInt32Ty (Ctx), {Type::getInt32Ty (Ctx), Type::getInt32Ty (Ctx)},
223228 false );
@@ -226,61 +231,82 @@ struct GetterTestEnv {
226231 Argument *Arg = F->getArg (0 );
227232 llvm::Value *Const = ConstantInt::get (Type::getInt32Ty (Ctx), 42 );
228233
229- Add = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
230- Ret = ReturnInst::Create (Ctx, Add, BB);
234+ AddInst = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
235+ RetInst = ReturnInst::Create (Ctx, AddInst, BB);
236+ }
237+
238+ void setWeights (float OpcWeight, float TypeWeight, float ArgWeight) {
239+ ::OpcWeight = OpcWeight;
240+ ::TypeWeight = TypeWeight;
241+ ::ArgWeight = ArgWeight;
242+ }
231243
232- auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
233- EXPECT_TRUE (static_cast <bool >(Result));
234- Emb = std::move (*Result);
244+ void TearDown () override {
245+ // Restore original global weights
246+ ::OpcWeight = OriginalOpcWeight;
247+ ::TypeWeight = OriginalTypeWeight;
248+ ::ArgWeight = OriginalArgWeight;
235249 }
236250};
237251
238- TEST (IR2VecTest, GetInstVecMap) {
239- GetterTestEnv Env;
240- const auto &InstMap = Env.Emb ->getInstVecMap ();
252+ TEST_F (IR2VecTestFixture, GetInstVecMap) {
253+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
254+ ASSERT_TRUE (static_cast <bool >(Result));
255+ auto Emb = std::move (*Result);
256+
257+ const auto &InstMap = Emb->getInstVecMap ();
241258
242259 EXPECT_EQ (InstMap.size (), 2u );
243- EXPECT_TRUE (InstMap.count (Env. Add ));
244- EXPECT_TRUE (InstMap.count (Env. Ret ));
260+ EXPECT_TRUE (InstMap.count (AddInst ));
261+ EXPECT_TRUE (InstMap.count (RetInst ));
245262
246- EXPECT_EQ (InstMap.at (Env. Add ).size (), 2u );
247- EXPECT_EQ (InstMap.at (Env. Ret ).size (), 2u );
263+ EXPECT_EQ (InstMap.at (AddInst ).size (), 2u );
264+ EXPECT_EQ (InstMap.at (RetInst ).size (), 2u );
248265
249266 // Check values for add: {1.29, 2.31}
250- EXPECT_THAT (InstMap.at (Env. Add ),
267+ EXPECT_THAT (InstMap.at (AddInst ),
251268 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
252269
253270 // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
254271 // vocab
255- EXPECT_THAT (InstMap.at (Env. Ret ), ElementsAre (0.0 , 0.0 ));
272+ EXPECT_THAT (InstMap.at (RetInst ), ElementsAre (0.0 , 0.0 ));
256273}
257274
258- TEST (IR2VecTest, GetBBVecMap) {
259- GetterTestEnv Env;
260- const auto &BBMap = Env.Emb ->getBBVecMap ();
275+ TEST_F (IR2VecTestFixture, GetBBVecMap) {
276+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
277+ ASSERT_TRUE (static_cast <bool >(Result));
278+ auto Emb = std::move (*Result);
279+
280+ const auto &BBMap = Emb->getBBVecMap ();
261281
262282 EXPECT_EQ (BBMap.size (), 1u );
263- EXPECT_TRUE (BBMap.count (Env. BB ));
264- EXPECT_EQ (BBMap.at (Env. BB ).size (), 2u );
283+ EXPECT_TRUE (BBMap.count (BB));
284+ EXPECT_EQ (BBMap.at (BB).size (), 2u );
265285
266286 // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
267287 // {1.29, 2.31}
268- EXPECT_THAT (BBMap.at (Env. BB ),
288+ EXPECT_THAT (BBMap.at (BB),
269289 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
270290}
271291
272- TEST (IR2VecTest, GetBBVector) {
273- GetterTestEnv Env;
274- const auto &BBVec = Env.Emb ->getBBVector (*Env.BB );
292+ TEST_F (IR2VecTestFixture, GetBBVector) {
293+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
294+ ASSERT_TRUE (static_cast <bool >(Result));
295+ auto Emb = std::move (*Result);
296+
297+ const auto &BBVec = Emb->getBBVector (*BB);
275298
276299 EXPECT_EQ (BBVec.size (), 2u );
277300 EXPECT_THAT (BBVec,
278301 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
279302}
280303
281- TEST (IR2VecTest, GetFunctionVector) {
282- GetterTestEnv Env;
283- const auto &FuncVec = Env.Emb ->getFunctionVector ();
304+ TEST_F (IR2VecTestFixture, GetFunctionVector) {
305+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
306+ ASSERT_TRUE (static_cast <bool >(Result));
307+ auto Emb = std::move (*Result);
308+
309+ const auto &FuncVec = Emb->getFunctionVector ();
284310
285311 EXPECT_EQ (FuncVec.size (), 2u );
286312
@@ -289,4 +315,45 @@ TEST(IR2VecTest, GetFunctionVector) {
289315 ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
290316}
291317
318+ TEST_F (IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
319+ setWeights (1.0 , 1.0 , 1.0 );
320+
321+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
322+ ASSERT_TRUE (static_cast <bool >(Result));
323+ auto Emb = std::move (*Result);
324+
325+ const auto &FuncVec = Emb->getFunctionVector ();
326+
327+ EXPECT_EQ (FuncVec.size (), 2u );
328+
329+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
330+ // 0.3] + [0.0 0.0])
331+ EXPECT_THAT (FuncVec,
332+ ElementsAre (DoubleNear (1.7 , 1e-6 ), DoubleNear (2.8 , 1e-6 )));
333+ }
334+
335+ TEST (IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
336+ Vocab InitialVocab = {{" key1" , {1.1 , 2.2 }}, {" key2" , {3.3 , 4.4 }}};
337+ Vocab ExpectedVocab = InitialVocab;
338+ unsigned ExpectedDim = InitialVocab.begin ()->second .size ();
339+
340+ IR2VecVocabAnalysis VocabAnalysis (std::move (InitialVocab));
341+
342+ LLVMContext TestCtx;
343+ Module TestMod (" TestModuleForVocabAnalysis" , TestCtx);
344+ ModuleAnalysisManager MAM;
345+ IR2VecVocabResult Result = VocabAnalysis.run (TestMod, MAM);
346+
347+ EXPECT_TRUE (Result.isValid ());
348+ ASSERT_FALSE (Result.getVocabulary ().empty ());
349+ EXPECT_EQ (Result.getDimension (), ExpectedDim);
350+
351+ const auto &ResultVocab = Result.getVocabulary ();
352+ EXPECT_EQ (ResultVocab.size (), ExpectedVocab.size ());
353+ for (const auto &pair : ExpectedVocab) {
354+ EXPECT_TRUE (ResultVocab.count (pair.first ));
355+ EXPECT_THAT (ResultVocab.at (pair.first ), ElementsAreArray (pair.second ));
356+ }
357+ }
358+
292359} // end anonymous namespace
0 commit comments