Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update toy brutus #34

Merged
merged 5 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 69 additions & 49 deletions examples/brutus.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,55 @@
"""
Brutus is a toy implementation of a Julia typed IR to MLIR conversion, the name
is a reference to the [brutus](https://github.com/JuliaLabs/brutus) project from
the MIT JuliaLabs which performs a similar conversion (with a lot more language constructs supported)
but from C++.
"""
module Brutus

import LLVM
using MLIR.IR
using MLIR.Dialects: arith, func, cf, std
using MLIR.Dialects: arith, func, cf
using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode

const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64}

module Predicates
const eq = 0
const ne = 1
const slt = 2
const sle = 3
const sgt = 4
const sge = 5
const ult = 6
const ule = 7
const ugt = 8
const uge = 9
end

function cmpi_pred(predicate)
function(ops; loc=Location())
arith.cmpi(predicate, ops; loc)
function (ops...; location = Location())
arith.cmpi(ops...; result=IR.MLIRType(Bool), predicate, location)
end
end

function single_op_wrapper(fop)
(block::Block, args::Vector{Value}; loc=Location()) -> push!(block, fop(args; loc))
(block::Block, args::Vector{Value}; location=Location()) -> push!(block, fop(args...; location))
end

const intrinsics_to_mlir = Dict([
Base.add_int => single_op_wrapper(arith.addi),
Base.sle_int => single_op_wrapper(cmpi_pred(arith.Predicates.sle)),
Base.slt_int => single_op_wrapper(cmpi_pred(arith.Predicates.slt)),
Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)),
Base.sle_int => single_op_wrapper(cmpi_pred(Predicates.sle)),
Base.slt_int => single_op_wrapper(cmpi_pred(Predicates.slt)),
Base.:(===) => single_op_wrapper(cmpi_pred(Predicates.eq)),
Base.mul_int => single_op_wrapper(arith.muli),
Base.mul_float => single_op_wrapper(arith.mulf),
Base.not_int => function(block, args; loc=Location())
Base.not_int => function (block, args; location=Location())
arg = only(args)
ones = push!(block, arith.constant(-1, IR.get_type(arg); loc)) |> IR.get_result
push!(block, arith.xori(Value[arg, ones]; loc))
mT = IR.get_type(arg)
T = IR.julia_type(mT)
ones = push!(block, arith.constant(value=typemax(UInt64) % T;
result=mT, location)) |> IR.get_result
push!(block, arith.xori(arg, ones; location))
end,
])

Expand Down Expand Up @@ -85,7 +107,7 @@ function code_mlir(f, types)

values = Vector{Value}(undef, length(ir.stmts))

for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",))
for dialect in ("func", "cf")
IR.get_or_load_dialect!(dialect)
end

Expand All @@ -107,7 +129,7 @@ function code_mlir(f, types)
elseif x isa Core.Argument
IR.get_argument(entry_block, x.n - 1)
elseif x isa BrutusScalar
IR.get_result(push!(current_block, arith.constant(x)))
IR.get_result(push!(current_block, arith.constant(;value=x)))
else
error("could not use value $x inside MLIR")
end
Expand Down Expand Up @@ -137,8 +159,8 @@ function code_mlir(f, types)
fop! = intrinsics_to_mlir[called_func]
args = get_value.(@view inst.args[begin+1:end])

loc = Location(string(line.file), line.line, 0)
res = IR.get_result(fop!(current_block, args; loc))
location = Location(string(line.file), line.line, 0)
res = IR.get_result(fop!(current_block, args; location))

values[sidx] = res
elseif inst isa PhiNode
Expand All @@ -148,9 +170,8 @@ function code_mlir(f, types)
elseif inst isa GotoNode
args = get_value.(collect_value_arguments(ir, block_id, inst.label))
dest = blocks[inst.label]
loc = Location(string(line.file), line.line, 0)
brop = LLVM.version() >= v"15" ? cf.br : std.br
push!(current_block, brop(dest, args; loc))
location = Location(string(line.file), line.line, 0)
push!(current_block, cf.br(args; dest, location))
elseif inst isa GotoIfNot
false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest))
cond = get_value(inst.cond)
Expand All @@ -160,15 +181,13 @@ function code_mlir(f, types)
other_dest = blocks[other_dest]
dest = blocks[inst.dest]

loc = Location(string(line.file), line.line, 0)
cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br
cond_br = cond_brop(cond, other_dest, dest, true_args, false_args; loc)
location = Location(string(line.file), line.line, 0)
cond_br = cf.cond_br(cond, true_args, false_args; trueDest=other_dest, falseDest=dest, location)
push!(current_block, cond_br)
elseif inst isa ReturnNode
line = ir.linetable[stmt[:line]]
retop = LLVM.version() >= v"15" ? func.return_ : std.return_
loc = Location(string(line.file), line.line, 0)
push!(current_block, retop([get_value(inst.val)]; loc))
location = Location(string(line.file), line.line, 0)
push!(current_block, func.return_([get_value(inst.val)]; location))
elseif Meta.isexpr(inst, :code_coverage_effect)
# Skip
else
Expand All @@ -194,13 +213,13 @@ function code_mlir(f, types)

ftype = MLIRType(input_types => result_types)
op = IR.create_operation(
LLVM15 ? "func.func" : "builtin.func",
"func.func",
Location();
attributes = [
attributes=[
NamedAttribute("sym_name", IR.Attribute(string(func_name))),
NamedAttribute(LLVM15 ? "function_type" : "type", IR.Attribute(ftype)),
NamedAttribute("function_type", IR.Attribute(ftype)),
],
owned_regions = Region[region],
owned_regions=Region[region],
result_inference=false,
)

Expand All @@ -227,16 +246,18 @@ macro code_mlir(call)
end
end

export code_mlir, @code_mlir

end # module Brutus

# ---

function pow(x::F, n) where {F}
p = one(F)
for _ in 1:n
p *= x
end
p
p = one(F)
for _ in 1:n
p *= x
end
p
end

function f(x)
Expand All @@ -252,29 +273,28 @@ end
using Test
using MLIR.IR, MLIR

ctx = Context()
# IR.enable_multithreading!(ctx, false)
fptr = IR.context!(IR.Context()) do
op = Brutus.code_mlir(pow, Tuple{Int,Int})

op = Brutus.code_mlir(pow, Tuple{Int, Int})
mod = MModule(Location())
body = IR.get_body(mod)
push!(body, op)

mod = MModule(Location())
body = IR.get_body(mod)
push!(body, op)
pm = IR.PassManager()
opm = IR.OpPassManager(pm)

pm = IR.PassManager()
opm = IR.OpPassManager(pm)
# IR.enable_ir_printing!(pm)
IR.enable_verifier!(pm, true)

# IR.enable_ir_printing!(pm)
IR.enable_verifier!(pm, true)
MLIR.API.mlirRegisterAllPasses()
MLIR.API.mlirRegisterAllLLVMTranslations(IR.context())
IR.add_pipeline!(opm, "convert-arith-to-llvm,convert-func-to-llvm")

MLIR.API.mlirRegisterAllPasses()
MLIR.API.mlirRegisterAllLLVMTranslations(ctx)
IR.add_pipeline!(opm, Brutus.LLVM.version() >= v"15" ? "convert-arith-to-llvm,convert-func-to-llvm" : "convert-std-to-llvm")
IR.run!(pm, mod)

IR.run!(pm, mod)

jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL)
fptr = MLIR.API.mlirExecutionEngineLookup(jit, "pow")
jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL)
MLIR.API.mlirExecutionEngineLookup(jit, "pow")
end

x, y = 3, 4

Expand Down
27 changes: 2 additions & 25 deletions src/Dialects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,7 @@ function namedattribute(name, val::NamedAttribute)
return val
end

operandsegmentsizes(segments) = namedattribute(
"operand_segment_sizes",
Attribute(API.mlirDenseI32ArrayGet(
context().context,
length(segments),
Int32.(segments)
)))
operandsegmentsizes(segments) = namedattribute("operand_segment_sizes", Attribute(Int32.(segments)))

let
ver = string(LLVM.version().major)
Expand All @@ -27,26 +21,9 @@ let
You might need a newer version of MLIR.jl for this version of Julia.""")
end

for path in readdir(joinpath(@__DIR__, "Dialects"); join=true)
for path in readdir(dir; join=true)
include(path)
end
end

# module arith

# module Predicates
# const eq = 0
# const ne = 1
# const slt = 2
# const sle = 3
# const sgt = 4
# const sge = 5
# const ult = 6
# const ule = 7
# const ugt = 8
# const uge = 9
# end

# end # module arith

end # module Dialects
26 changes: 13 additions & 13 deletions src/Dialects/14/ArmNeon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,20 @@ corresponding entry of `a`:
res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
```
"""
function 2d_sdot(a::Value, b::Value, c::Value; res::MLIRType, location=Location())
results = MLIRType[res, ]
operands = Value[a, b, c, ]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
# function 2d_sdot(a::Value, b::Value, c::Value; res::MLIRType, location=Location())
# results = MLIRType[res, ]
# operands = Value[a, b, c, ]
# owned_regions = Region[]
# successors = Block[]
# attributes = NamedAttribute[]

create_operation(
"arm_neon.2d.sdot", location;
operands, owned_regions, successors, attributes,
results=results,
result_inference=false
)
end
# create_operation(
# "arm_neon.2d.sdot", location;
# operands, owned_regions, successors, attributes,
# results=results,
# result_inference=false
# )
# end

"""
`intr_sdot`
Expand Down
2 changes: 1 addition & 1 deletion src/Dialects/14/GPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ module attributes {gpu.container_module} {
}
```
"""
function launch_func(asyncDependencies::Vector{Value}, gridSizeX::Value, gridSizeY::Value, gridSizeZ::Value, blockSizeX::Value, blockSizeY::Value, blockSizeZ::Value, dynamicSharedMemorySize=nothing::Union{Nothing, Value}, operands::Vector{Value}; asyncToken=nothing::Union{Nothing, MLIRType}, kernel, location=Location())
function launch_func(asyncDependencies::Vector{Value}, gridSizeX::Value, gridSizeY::Value, gridSizeZ::Value, blockSizeX::Value, blockSizeY::Value, blockSizeZ::Value, dynamicSharedMemorySize=nothing::Union{Nothing, Value}; operands::Vector{Value}, asyncToken=nothing::Union{Nothing, MLIRType}, kernel, location=Location())
results = MLIRType[]
operands = Value[asyncDependencies..., gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ, operands..., ]
owned_regions = Region[]
Expand Down
Loading
Loading