Skip to content

Commit 644cea9

Browse files
committed
add header.h
1 parent d58f4be commit 644cea9

File tree

2 files changed

+138
-20
lines changed

2 files changed

+138
-20
lines changed

kaleidoscope/header.h

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
2+
#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
3+
4+
#include "llvm/ADT/STLExtras.h"
5+
#include "llvm/ExecutionEngine/ExecutionEngine.h"
6+
#include "llvm/ExecutionEngine/JITSymbol.h"
7+
#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
8+
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
9+
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
10+
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
11+
#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
12+
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
13+
#include "llvm/IR/DataLayout.h"
14+
#include "llvm/IR/Mangler.h"
15+
#include "llvm/Support/DynamicLibrary.h"
16+
#include "llvm/Support/raw_ostream.h"
17+
#include "llvm/Target/TargetMachine.h"
18+
#include <algorithm>
19+
#include <memory>
20+
#include <string>
21+
#include <vector>
22+
23+
namespace llvm {
24+
namespace orc {
25+
26+
class KaleidoscopeJIT {
27+
private:
28+
std::unique_ptr<TargetMachine> TM;
29+
const DataLayout DL;
30+
RTDyldObjectLinkingLayer ObjectLayer;
31+
IRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
32+
33+
public:
34+
using ModuleHandle = decltype(CompileLayer)::ModuleHandleT;
35+
36+
KaleidoscopeJIT()
37+
: TM(EngineBuilder().selectTarget()), DL(TM->createDataLayout()),
38+
ObjectLayer([]() { return std::make_shared<SectionMemoryManager>(); }),
39+
CompileLayer(ObjectLayer, SimpleCompiler(*TM)) {
40+
llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
41+
}
42+
43+
TargetMachine &getTargetMachine() { return *TM; }
44+
45+
ModuleHandle addModule(std::unique_ptr<Module> M) {
46+
// Build our symbol resolver:
47+
// Lambda 1: Look back into the JIT itself to find symbols that are part of
48+
// the same "logical dylib".
49+
// Lambda 2: Search for external symbols in the host process.
50+
auto Resolver = createLambdaResolver(
51+
[&](const std::string &Name) {
52+
if (auto Sym = CompileLayer.findSymbol(Name, false))
53+
return Sym;
54+
return JITSymbol(nullptr);
55+
},
56+
[](const std::string &Name) {
57+
if (auto SymAddr =
58+
RTDyldMemoryManager::getSymbolAddressInProcess(Name))
59+
return JITSymbol(SymAddr, JITSymbolFlags::Exported);
60+
return JITSymbol(nullptr);
61+
});
62+
63+
// Add the set to the JIT with the resolver we created above and a newly
64+
// created SectionMemoryManager.
65+
return cantFail(CompileLayer.addModule(std::move(M),
66+
std::move(Resolver)));
67+
}
68+
69+
JITSymbol findSymbol(const std::string Name) {
70+
std::string MangledName;
71+
raw_string_ostream MangledNameStream(MangledName);
72+
Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
73+
return CompileLayer.findSymbol(MangledNameStream.str(), true);
74+
}
75+
76+
JITTargetAddress getSymbolAddress(const std::string Name) {
77+
return cantFail(findSymbol(Name).getAddress());
78+
}
79+
80+
void removeModule(ModuleHandle H) {
81+
cantFail(CompileLayer.removeModule(H));
82+
}
83+
};
84+
85+
} // end namespace orc
86+
} // end namespace llvm
87+
88+
#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H

kaleidoscope/main.cpp

+50-20
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <llvm/IR/LegacyPassManager.h>
2020
#include <llvm/Transforms/Scalar.h>
2121

22+
#include "header.h"
23+
2224
namespace Singleton {
2325
llvm::LLVMContext& context () {
2426
static llvm::LLVMContext context_;
@@ -47,6 +49,17 @@ namespace Singleton {
4749
std::make_unique<llvm::legacy::FunctionPassManager>(module_ptr.get());
4850
return fpm_;
4951
}
52+
53+
std::unique_ptr<llvm::orc::KaleidoscopeJIT>& jit_ptr () {
54+
static auto jit_ =
55+
std::make_unique<llvm::orc::KaleidoscopeJIT>();
56+
return jit_;
57+
}
58+
}
59+
60+
void init_module () {
61+
Singleton::module_ptr()->setDataLayout(
62+
Singleton::jit_ptr()->getTargetMachine().createDataLayout());
5063
}
5164

5265
void init_function_pass_manager () {
@@ -722,7 +735,7 @@ std::unique_ptr<PrototypeAST> Parser::parse_extern () {
722735
std::unique_ptr<FunctionAST> Parser::parse_top_level_expression () {
723736
auto expr = parse_expression();
724737
if (expr) {
725-
auto proto = std::make_unique<PrototypeAST>("anonymous", std::vector<std::string>());
738+
auto proto = std::make_unique<PrototypeAST>("__anonymous", std::vector<std::string>());
726739
return std::make_unique<FunctionAST>(std::move(proto), std::move(expr));
727740
} else {
728741
return nullptr;
@@ -739,15 +752,15 @@ std::unique_ptr<ExprAST> Parser::parse () {
739752
case Token::Extern:
740753
return parse_extern();
741754
default:
742-
return parse_expression();
755+
return parse_top_level_expression();
743756
}
744757
}
745758

746759
void parse (std::string& code) {
747760
char* ptr = const_cast<char*>(code.data());
748761
uint32_t len = code.size();
749762

750-
printf("Code: %s\n", code.c_str());
763+
printf("Code:\n%s\n", code.c_str());
751764

752765
// lexer
753766
#if 0
@@ -767,23 +780,29 @@ void parse (std::string& code) {
767780
#endif
768781

769782
printf("Parse:\n");
783+
printf("========================\n");
770784
// parse
771-
{
772-
Lexer lexer(ptr, len);
773-
Parser parser(lexer);
774-
auto res = parser.parse();
785+
Lexer lexer(ptr, len);
786+
Parser parser(lexer);
787+
while (auto res = parser.parse()) {
788+
//auto res = parser.parse();
775789
res->print();
776790
printf("\n");
777791
printf("IR:\n");
778792
auto* ir = res->codegen();
779793
ir->print(llvm::errs());
780794
printf("\n");
795+
printf("========================\n");
781796
}
782-
printf("========================\n");
783797
}
784798

785799
int main () {
786800

801+
LLVMInitializeNativeTarget();
802+
LLVMInitializeNativeAsmPrinter();
803+
LLVMInitializeNativeAsmParser();
804+
805+
init_module();
787806
init_function_pass_manager();
788807

789808
#if 0
@@ -797,20 +816,31 @@ int main () {
797816
std::string("# This expression will compute the 40th number. \n") +
798817
std::string(" fib(40)\n");
799818
#else
800-
std::vector<std::string> codes = {
801-
std::string("4+5"),
802-
std::string("def foo(a b) a*a + 2*a*b + b*b"),
803-
std::string("def bar(a) foo(a, 4.0) + bar(31337)"),
804-
std::string("extern cos(x)"),
805-
std::string("cos(1.234)"),
806-
std::string("def test(x) 1+2+x"),
807-
std::string("def testToOpt(x) (1+2+x)*(x+(1+2))")
808-
};
819+
//std::vector<std::string> codes = {
820+
// std::string("4+5"),
821+
// std::string("def foo(a b) a*a + 2*a*b + b*b"),
822+
// std::string("def bar(a) foo(a, 4.0) + bar(31337)"),
823+
// std::string("extern cos(x)"),
824+
// std::string("cos(1.234)"),
825+
// std::string("def test(x) 1+2+x"),
826+
// std::string("def testToOpt(x) (1+2+x)*(x+(1+2))")
827+
//};
828+
std::string codes =
829+
//std::string("4+5\n") +
830+
//std::string("def foo(a b) a*a + 2*a*b + b*b\n") +
831+
//std::string("def bar(a) foo(a, 4.0) + bar(31337)\n") +
832+
//std::string("extern cos(x)\n") +
833+
//std::string("cos(1.234)\n") +
834+
//std::string("def test(x) 1+2+x\n") +
835+
std::string("def testToOpt(x) (1+2+x)*(x+(1+2))\n") +
836+
std::string("testToOpt(1)\n");
809837
#endif
810838

811-
for (auto& str : codes) {
812-
parse(str);
813-
}
839+
//for (auto& str : codes) {
840+
// parse(str);
841+
//}
842+
843+
parse(codes);
814844

815845

816846
return 0;

0 commit comments

Comments
 (0)