3131
3232#include " llvm/ADT/DenseMap.h"
3333#include " llvm/IR/PassManager.h"
34+ #include " llvm/IR/Type.h"
3435#include " llvm/Support/CommandLine.h"
3536#include " llvm/Support/Compiler.h"
3637#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
4344class BasicBlock ;
4445class Instruction ;
4546class Function ;
46- class Type ;
4747class Value ;
4848class raw_ostream ;
4949class LLVMContext ;
50+ class IR2VecVocabAnalysis ;
5051
5152// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
5253// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -125,9 +126,73 @@ struct Embedding {
125126
126127using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
127128using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
128- // FIXME: Current the keys are strings. This can be changed to
129- // use integers for cheaper lookups.
130- using Vocab = std::map<std::string, Embedding>;
129+
130+ // / Class for storing and accessing the IR2Vec vocabulary.
131+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
132+ class Vocabulary {
133+ friend class llvm ::IR2VecVocabAnalysis;
134+ using VocabVector = std::vector<ir2vec::Embedding>;
135+ VocabVector Vocab;
136+ bool Valid = false ;
137+
138+ // / Operand kinds supported by IR2Vec Vocabulary
139+ #define OPERAND_KINDS \
140+ OPERAND_KIND (FunctionID, " Function" ) \
141+ OPERAND_KIND (PointerID, " Pointer" ) \
142+ OPERAND_KIND (ConstantID, " Constant" ) \
143+ OPERAND_KIND (VariableID, " Variable" )
144+
145+ enum class OperandKind : unsigned {
146+ #define OPERAND_KIND (Name, Str ) Name,
147+ OPERAND_KINDS
148+ #undef OPERAND_KIND
149+ MaxOperandKind
150+ };
151+
152+ #undef OPERAND_KINDS
153+
154+ // / Vocabulary layout constants
155+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
156+ #include " llvm/IR/Instruction.def"
157+ #undef LAST_OTHER_INST
158+
159+ static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1 ;
160+ static constexpr unsigned MaxOperandKinds =
161+ static_cast <unsigned >(OperandKind::MaxOperandKind);
162+
163+ // / Helper function to get vocabulary key for a given OperandKind
164+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
165+
166+ // / Helper function to classify an operand into OperandKind
167+ static OperandKind getOperandKind (const Value *Op);
168+
169+ // / Helper function to get vocabulary key for a given TypeID
170+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
171+
172+ public:
173+ Vocabulary () = default ;
174+ Vocabulary (VocabVector &&Vocab);
175+
176+ bool isValid () const ;
177+ unsigned getDimension () const ;
178+ unsigned size () const ;
179+
180+ const ir2vec::Embedding &at (unsigned Position) const ;
181+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
182+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
183+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
184+
185+ // / Returns the string key for a given index position in the vocabulary.
186+ // / This is useful for debugging or printing the vocabulary. Do not use this
187+ // / for embedding generation as string based lookups are inefficient.
188+ static StringRef getStringKey (unsigned Pos);
189+
190+ // / Create a dummy vocabulary for testing purposes.
191+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
192+
193+ bool invalidate (Module &M, const PreservedAnalyses &PA,
194+ ModuleAnalysisManager::Invalidator &Inv) const ;
195+ };
131196
132197// / Embedder provides the interface to generate embeddings (vector
133198// / representations) for instructions, basic blocks, and functions. The
@@ -138,7 +203,7 @@ using Vocab = std::map<std::string, Embedding>;
138203class Embedder {
139204protected:
140205 const Function &F;
141- const Vocab &Vocabulary ;
206+ const Vocabulary &Vocab ;
142207
143208 // / Dimension of the vector representation; captured from the input vocabulary
144209 const unsigned Dimension;
@@ -153,7 +218,7 @@ class Embedder {
153218 mutable BBEmbeddingsMap BBVecMap;
154219 mutable InstEmbeddingsMap InstVecMap;
155220
156- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
221+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
157222
158223 // / Helper function to compute embeddings. It generates embeddings for all
159224 // / the instructions and basic blocks in the function F. Logic of computing
@@ -164,16 +229,12 @@ class Embedder {
164229 // / Specific to the kind of embeddings being computed.
165230 virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
166231
167- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
168- // / zero vector.
169- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
170-
171232public:
172233 virtual ~Embedder () = default ;
173234
174235 // / Factory method to create an Embedder object.
175236 LLVM_ABI static std::unique_ptr<Embedder>
176- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
237+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
177238
178239 // / Returns a map containing instructions and the corresponding embeddings for
179240 // / the function F if it has been computed. If not, it computes the embeddings
@@ -199,56 +260,40 @@ class Embedder {
199260// / representations obtained from the Vocabulary.
200261class LLVM_ABI SymbolicEmbedder : public Embedder {
201262private:
202- // / Utility function to compute the embedding for a given type.
203- Embedding getTypeEmbedding (const Type *Ty) const ;
204-
205- // / Utility function to compute the embedding for a given operand.
206- Embedding getOperandEmbedding (const Value *Op) const ;
207-
208263 void computeEmbeddings () const override ;
209264 void computeEmbeddings (const BasicBlock &BB) const override ;
210265
211266public:
212- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
213- : Embedder(F, Vocabulary ) {
267+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
268+ : Embedder(F, Vocab ) {
214269 FuncVector = Embedding (Dimension, 0 );
215270 }
216271};
217272
218273} // namespace ir2vec
219274
220- // / Class for storing the result of the IR2VecVocabAnalysis.
221- class IR2VecVocabResult {
222- ir2vec::Vocab Vocabulary;
223- bool Valid = false ;
224-
225- public:
226- IR2VecVocabResult () = default ;
227- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
228-
229- bool isValid () const { return Valid; }
230- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
231- LLVM_ABI unsigned getDimension () const ;
232- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
233- ModuleAnalysisManager::Invalidator &Inv) const ;
234- };
235-
236275// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
237276// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
238277// / its corresponding embedding.
239278class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
240- ir2vec::Vocab Vocabulary;
279+ using VocabVector = std::vector<ir2vec::Embedding>;
280+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
281+ VocabMap OpcVocab, TypeVocab, ArgVocab;
282+ VocabVector Vocab;
283+
284+ unsigned Dim = 0 ;
241285 Error readVocabulary ();
242286 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
243- ir2vec::Vocab &TargetVocab, unsigned &Dim);
287+ VocabMap &TargetVocab, unsigned &Dim);
288+ void generateNumMappedVocab ();
244289 void emitError (Error Err, LLVMContext &Ctx);
245290
246291public:
247292 LLVM_ABI static AnalysisKey Key;
248293 IR2VecVocabAnalysis () = default ;
249- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
250- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
251- using Result = IR2VecVocabResult ;
294+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
295+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
296+ using Result = ir2vec::Vocabulary ;
252297 LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
253298};
254299
0 commit comments