-
Notifications
You must be signed in to change notification settings - Fork 8
/
kcontractor.cr
284 lines (236 loc) · 8.01 KB
/
kcontractor.cr
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
require "./visitor"
require "./llvm_binding_extensions"
require "./kast"
module Kazoo
  # Walks a Kazoo AST and emits LLVM IR into a single module ("Kazoo JIT")
  # that is compiled and executed through an LLVM JIT.
  #
  # Inside the `on` handlers, unqualified IR-building calls (`alloca`, `store`,
  # `load`, `fadd`, `br`, `cond`, `phi`, `position_at_end`, ...) go to a
  # builder. Nested blocks are switched to a different builder via
  # `with_builder`. NOTE(review): the default receiver of those calls is
  # presumably `@env`, provided by `CLTK::Visitor` — confirm against ./visitor.
  class Contractor
    include CLTK::Visitor(LLVM::Builder)

    getter :main_module

    @fpm : LLVM::FunctionPassManager
    @ctx : LLVM::Context
    @main_module : LLVM::Module
    @builder : LLVM::Builder

    # The constant 0.0. Conditions are compared against it, and `for` loops
    # return it as their value.
    def zero
      @ctx.float.const_double(0.0)
    end

    # Creates the LLVM context, module, default builder and JIT, and registers
    # the function-level optimization passes.
    def initialize()
      LLVM.init_x86
      @ctx = LLVM::Context.new
      @main_module = @ctx.new_module("Kazoo JIT")
      @builder = @ctx.new_builder
      # Symbol table: variable name => stack slot (alloca) holding its value.
      @st = {} of String => LLVM::Value
      # Builder currently used for emitting instructions.
      @env = @builder
      # Execution Engine
      @engine = LLVM::JITCompiler.new(@main_module)
      @fpm = @main_module.new_function_pass_manager
      # Add passes to the Function Pass Manager.
      LibLLVM.add_instruction_combining_pass(@fpm)
      LibLLVM.add_reassociate_pass(@fpm)
      LibLLVM.add_gvn_pass(@fpm)
      LibLLVM.add_cfg_simplification_pass(@fpm)
      LibLLVM.add_promote_memory_to_register_pass(@fpm)
    end

    # Runs the JIT-compiled `func` with `args` and returns the resulting
    # `LLVM::GenericValue`.
    def execute(func, args = [] of LLVM::GenericValue)
      # Forward the arguments; previously they were silently dropped.
      @engine.run_function(func, args, @ctx)
    end

    # Runs the registered function passes over `func` and returns it.
    def optimize(func : LLVM::Function)
      @fpm.run { |runner| runner.run(func) }
      func
    end

    # Makes `builder` the current builder (`@env`) while the block runs,
    # with the block's receiver set to `builder`. Returns the block's value.
    # The previous builder is restored even if the block raises.
    def with_builder(builder)
      old_env = @env
      @env = builder
      begin
        with builder yield
      ensure
        @env = old_env
      end
    end

    # Compiles a top-level node and returns the resulting LLVM function.
    # A bare expression is wrapped in an anonymous (empty-named), zero-argument
    # function so that it can be executed.
    def add(ast)
      case ast
      when Function, Prototype then visit ast
      when Expression then visit Function.new(
        proto: Kazoo::Prototype.new(name: "", arg_names: [] of String),
        body: ast)
      else raise "Attempting to add an unhandled node type to the JIT."
      end.as(LLVM::Function)
    end

    # `name = expr`: evaluates the right-hand side and stores it in the
    # variable's stack slot, allocating that slot on first assignment.
    # Evaluates to the assigned value.
    on Assign do |node|
      right = visit node.right
      loc =
        if @st.has_key?(node.name)
          @st[node.name]
        else
          @st[node.name] = alloca @ctx.float, node.name
        end
      store(right, loc)
      right
    end

    # Variable reference: loads the current value from the variable's slot.
    on Variable do |node|
      if @st[node.name]?
        load @st[node.name], node.name
      else
        raise "Unitialized variable \"#{node.name}\"."
      end
    end

    # Function call: looks the callee up in the module, checks the arity and
    # emits the call.
    on Call do |node|
      # `functions[name]` would raise its own error, so use `[]?` to make the
      # "unknown function" message below actually reachable.
      callee = @main_module.functions[node.name]? || raise "Unknown function referenced."
      if callee.params.size != node.args.size
        raise "Function #{node.name} expected #{callee.params.size} argument(s) but was called with #{node.args.size}."
      end
      call callee,
        node.args.map { |arg| (visit arg).as(LLVM::Value) },
        "calltmp"
    end

    # Prototype: returns the named function, declaring it if necessary.
    # All parameters and the return value are floats.
    # Raises if a function of that name already has a body, or was declared
    # with a different number of parameters.
    on Prototype do |node|
      func =
        if existing = @main_module.functions[node.name]?
          # A plain `rescue` used to swallow these errors and silently add a
          # second, renamed function instead.
          if LibLLVM.count_basic_blocks(existing) != 0
            raise "Redefinition of function #{node.name}."
          elsif existing.params.size != node.arg_names.size
            raise "Redefinition of function #{node.name} with different number of arguments."
          end
          existing
        else
          @main_module.functions.add(node.name, Array.new(node.arg_names.size, @ctx.float), @ctx.float)
        end
      # Name each of the function parameters.
      node.arg_names.each_with_index do |name, i|
        func.params[i].name = name
      end
      func
    end

    # Function definition: emits the prototype, then an "entry" block that
    # spills the parameters to stack slots and evaluates the body. The value
    # of the last expression is returned.
    on Function do |node|
      # NOTE(review): the symbol table is not reset between functions, so names
      # bound earlier (including other functions' parameters) remain visible
      # here even though their allocas live in other LLVM functions — confirm
      # this is intended.
      # @st.clear

      # Translate the function's prototype.
      func = visit node.proto.as(Prototype)

      func.basic_blocks.append("entry") do |builder|
        with_builder(builder) do
          # Allocate and initialise one stack slot per parameter inside the
          # entry block (previously emitted via the main builder, outside the
          # function).
          func.params.to_a.each do |param|
            @st[param.name] = alloca @ctx.float, param.name
            store param, @st[param.name]
          end

          body = node.body
          case body
          when ExpressionList
            expressions = body.expressions
            expressions.each_with_index do |expression, index|
              if index < (expressions.size - 1)
                visit expression
              else
                ret visit(expression)
              end
            end
          else
            ret visit(body)
          end
        end
      end

      # Verify the function; the returned status is currently not inspected.
      LibLLVM.verify_function(func, LLVM::VerifierFailureAction::ReturnStatusAction)
      func
    end

    # `for` loop: initialises the loop variable, then repeats
    # "test cond -> body -> var += step" while cond != 0.
    # The loop variable shadows any outer binding of the same name until the
    # loop ends. Evaluates to 0.0.
    on For do |node|
      ph_bb = insert_block
      func = LLVM::Function.new LibLLVM.get_basic_block_parent(ph_bb)
      loop_cond_bb = func.basic_blocks.append("loop_cond")

      loc = alloca @ctx.float, node.var
      store (visit node.init), loc

      # Remember the outer binding (or fall back to the loop slot) so it can be
      # restored after the loop.
      old_var = @st[node.var]? || loc
      @st[node.var] = loc

      br loop_cond_bb
      position_at_end(loop_cond_bb)
      branch_body = visit(node.cond)
      end_cond = fcmp LLVM::RealPredicate::ONE, branch_body, zero, "loopcond"

      loop_bb = func.basic_blocks.append("loop") do |builder|
        with_builder(builder) do
          visit node.body
          step_val = visit node.step
          var = load loc, node.var
          next_var = fadd var, step_val, "nextvar"
          store next_var, loc
          br loop_cond_bb
        end
      end

      after_bb = func.basic_blocks.append("afterloop")

      # The current builder is still positioned at the end of loop_cond_bb.
      cond end_cond, loop_bb, after_bb
      position_at_end(after_bb)

      @st[node.var] = old_var
      zero
    end

    # `if` expression: evaluates both arms in their own blocks and merges
    # their values with a phi node in "merge".
    #
    # NOTE(review): the block labelled "then" evaluates `node.elseExp` and the
    # "else" block evaluates `node.thenExp`. Presumably the AST/parser field
    # names are swapped to compensate — verify against kast.cr and the parser
    # before changing.
    # NOTE(review): this uses UGT (> 0.0, or unordered), while `for` uses ONE
    # (!= 0.0), so a negative condition selects the second block here.
    on If do |node|
      # IF
      cond_val = fcmp LLVM::RealPredicate::UGT, (visit node.cond), zero, "ifcond"
      table = LLVM::PhiTable.new
      start_bb = insert_block
      func = LLVM::Function.new LibLLVM.get_basic_block_parent(start_bb)
      ## THEN
      # Nested control flow may leave the arm in a different block, so the
      # phi edge comes from the builder's final insert block.
      new_then_bb = nil
      then_bb = func.basic_blocks.append("then") do |builder|
        then_val, new_then_bb = with_builder(builder) do
          {visit(node.elseExp), builder.insert_block}
        end
        table.add(new_then_bb, then_val.as(LLVM::Value))
      end
      ## ELSE
      new_else_bb = nil
      else_bb = func.basic_blocks.append("else") do |builder|
        else_val, new_else_bb = with_builder(builder) do
          {visit(node.thenExp), builder.insert_block}
        end
        table.add(new_else_bb, else_val.as(LLVM::Value))
      end
      merge_bb = func.basic_blocks.append("merge")
      position_at_end(merge_bb)
      phi_inst = phi @ctx.float, table, "iftmp"
      position_at_end(start_bb)
      cond cond_val, then_bb, else_bb
      position_at_end(new_then_bb.not_nil!)
      br merge_bb
      position_at_end(new_else_bb.not_nil!)
      br merge_bb
      phi_inst.tap { position_at_end merge_bb }
    end

    # Binary operators on floats. Comparisons and logical operators yield
    # 1.0 or 0.0; `Or`/`And` treat any value != 0.0 as true.
    on Binary do |node|
      left = visit node.left
      right = visit node.right
      case node
      when Add then fadd(left, right, "addtmp")
      when Sub then fsub(left, right, "subtmp")
      when Mul then fmul(left, right, "multmp")
      when Div then fdiv(left, right, "divtmp")
      when LT  then ui2fp(fcmp(LLVM::RealPredicate::ULT, left, right, "cmptmp"), @ctx.float, "lttmp")
      when GT  then ui2fp(fcmp(LLVM::RealPredicate::UGT, left, right, "cmptmp"), @ctx.float, "gttmp")
      when Eql then ui2fp(fcmp(LLVM::RealPredicate::UEQ, left, right, "cmptmp"), @ctx.float, "eqtmp")
      when Or
        left = fcmp LLVM::RealPredicate::UNE, left, zero, "lefttmp"
        right = fcmp LLVM::RealPredicate::UNE, right, zero, "righttmp"
        ui2fp (@env.or left, right, "ortmp"), @ctx.float, "orltmp"
      when And
        left = fcmp LLVM::RealPredicate::UNE, left, zero, "lefttmp"
        right = fcmp LLVM::RealPredicate::UNE, right, zero, "rightmp"
        ui2fp (@env.and left, right, "andtmp"), @ctx.float, "andtmp"
      else
        # Unrecognised binary node: falls back to the right operand's value.
        right
      end
    end

    # Numeric literal: a float constant.
    on ANumber do |node|
      @ctx.float.const_double(node.value)
    end
  end
end