IRTools generates inefficient IR #21

Closed
Roger-luo opened this issue Aug 14, 2019 · 16 comments

@Roger-luo
Contributor

Roger-luo commented Aug 14, 2019

IRTools generates different lowered Julia code from exactly the same IR. I found this while trying to generate my own spmd function. You can find the implementation here:

https://gist.github.com/Roger-luo/a7cddd8cf5902c5b43e09ea16acc74da

Julia version: 1.1.1

A=VecArray(Matrix{Float64}, rand(2, 2, 1000))

generated IR

julia> @code_ir spmd_lane(tr, 1, A)
1: (%1, %2, %3, %4)
  %5 = $(Expr(:meta, :inline))
  %6 = (Base.getfield)(%4, 1)
  %7 = (Main.spmd_lane)(LinearAlgebra.checksquare, %3, %6)
  %8 = (LinearAlgebra.zero)($(QuoteNode(Float64)))
  %9 = 1:%7
  %10 = (Base.iterate)(%9)
  %11 = %10 === nothing
  %12 = (Base.not_int)(%11)
  br 3 (%8) unless %12
  br 2 (%10, %8)
2: (%13, %14)
  %15 = (Core.getfield)(%13, 1)
  %16 = (Core.getfield)(%13, 2)
  %17 = (Main.spmd_lane)(Base.getindex, %3, %6, %15, %15)
  %18 = %14 + %17
  %19 = (Base.iterate)(%9, %16)
  %20 = %19 === nothing
  %21 = (Base.not_int)(%20)
  br 3 (%18) unless %21
  br 2 (%19, %18)
3: (%22)
  return %22

manual implementation

julia> @code_ir m_spmd_lane(tr, 1, A)
1: (%1, %2, %3, %4)
  %5 = LinearAlgebra.checksquare
  %6 = (Main.spmd_lane)(%5, %3, %4)
  %7 = (Main.zero)($(QuoteNode(Float64)))
  %8 = 1:%6
  %9 = (Base.iterate)(%8)
  %10 = %9 === nothing
  %11 = (Base.not_int)(%10)
  br 3 (%7) unless %11
  br 2 (%9, %7)
2: (%12, %13)
  %14 = (Core.getfield)(%12, 1)
  %15 = (Core.getfield)(%12, 2)
  %16 = Base.getindex
  %17 = (Main.spmd_lane)(%16, %3, %4, %14, %14)
  %18 = %13 + %17
  %19 = (Base.iterate)(%8, %15)
  %20 = %19 === nothing
  %21 = (Base.not_int)(%20)
  br 3 (%18) unless %21
  br 2 (%19, %18)
3: (%22)
  return %22

Note that these two look the same in code_ir, but they differ in code_lowered:

julia> @code_lowered spmd_lane(tr, 1, A)
CodeInfo(
1 ─       $(Expr(:meta, :inline))
│   %2  = (Base.getfield)(xs, 1)
│   %3  = (Main.spmd_lane)(LinearAlgebra.checksquare, k, %2)
│   %4  = (LinearAlgebra.zero)($(QuoteNode(Float64)))
│   %5  = 1:%3
│   %6  = (Base.iterate)(%5)
│   %7  = %6 === nothing
│   %8  = (Base.not_int)(%7)
│         spat_3_1 = %4
│         spat_2_1 = %6
│         spat_2_2 = %4
└──       goto #5 if not %8
2 ─       goto #3
3 ─ %14 = (Core.getfield)(spat_2_1, 1)
│   %15 = (Core.getfield)(spat_2_1, 2)
│   %16 = (Main.spmd_lane)(Base.getindex, k, %2, %14, %14)
│   %17 = spat_2_2 + %16
│   %18 = (Base.iterate)(%5, %15)
│   %19 = %18 === nothing
│   %20 = (Base.not_int)(%19)
│         spat_3_1 = %17
│         spat_2_1 = %18
│         spat_2_2 = %17
└──       goto #5 if not %20
4 ─       goto #3
5 ─       return spat_3_1
)

manual:

julia> @code_lowered m_spmd_lane(tr, 1, A)
CodeInfo(
1 ─ %1  = LinearAlgebra.checksquare
│         n = (Main.spmd_lane)(%1, k, A)
│         out = (Main.zero)($(Expr(:static_parameter, 1)))
│   %4  = 1:n
│         #temp# = (Base.iterate)(%4)
│   %6  = #temp# === nothing
│   %7  = (Base.not_int)(%6)
└──       goto #4 if not %7
2 ─ %9  = #temp#
│         i = (Core.getfield)(%9, 1)
│   %11 = (Core.getfield)(%9, 2)
│   %12 = out
│   %13 = Base.getindex
│   %14 = i
│   %15 = (Main.spmd_lane)(%13, k, A, %14, i)
│         out = %12 + %15
│         #temp# = (Base.iterate)(%4, %11)
│   %18 = #temp# === nothing
│   %19 = (Base.not_int)(%18)
└──       goto #4 if not %19
3 ─       goto #2
4 ─       return out
)

I'm not sure whether this is why the generated one is much slower, but check the benchmark:

julia> @benchmark m_spmd_lane(tr, 1, $A)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.824 ns (0.00% GC)
  median time:      8.124 ns (0.00% GC)
  mean time:        8.961 ns (0.00% GC)
  maximum time:     89.727 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     999

julia> @benchmark spmd_lane(tr, 1, $A)
BenchmarkTools.Trial: 
  memory estimate:  144 bytes
  allocs estimate:  5
  --------------
  minimum time:     87.729 ns (0.00% GC)
  median time:      93.754 ns (0.00% GC)
  mean time:        125.178 ns (21.86% GC)
  maximum time:     65.466 μs (99.84% GC)
  --------------
  samples:          10000
  evals/sample:     961

The generated one is about 10x slower, and it allocates (144 bytes in 5 allocations, versus none for the manual version). Did I miss something, or is there a bug here?
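As a lighter-weight cross-check of the allocation numbers in the benchmarks above, Base's @allocated macro can be wrapped in a small helper. This is only a sketch: allocated_after_warmup is a hypothetical name (not part of IRTools or BenchmarkTools), and the warm-up call is there so that compilation is not counted.

```julia
# Hypothetical helper (not from IRTools or BenchmarkTools): report the bytes
# allocated by one call of `f(args...)`, after a warm-up call so that
# compilation is not included in the measurement.
# The type parameters F and N force specialization on `f` and on the
# argument tuple, so the helper itself should not add dynamic-dispatch overhead.
function allocated_after_warmup(f::F, args::Vararg{Any,N}) where {F,N}
    f(args...)                      # warm-up call: triggers compilation
    return @allocated f(args...)    # bytes allocated by the measured call
end

# Usage, assuming the definitions from the gist are loaded:
#   allocated_after_warmup(m_spmd_lane, tr, 1, A)  # manual version
#   allocated_after_warmup(spmd_lane, tr, 1, A)    # generated version
```

A nonzero result for the generated version but not the manual one would point at the same difference BenchmarkTools reports, without needing a full benchmark run.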

@Roger-luo
Contributor Author

Roger-luo commented Aug 14, 2019

Update: I guess this might also be why the generated adjoint of tr (in Zygote) is much slower than the manual one?

@Roger-luo
Contributor Author

I found a simpler MWE:

I simply convert the manually implemented version to IR and then back to CodeInfo with IRTools; judging by the performance, this round trip is not equivalent.

The definition of VecArray from my gist is still required.

function m_spmd_lane(::typeof(tr), k, A::VecArray{T}) where T
    n = spmd_lane(LinearAlgebra.checksquare, k, A)
    out = zero(T)
    for i in 1:n
        out += spmd_lane(Base.getindex, k, A, i, i)
    end
    return out
end

@generated function test_m_spmd(::typeof(tr), k, A::VecArray)
    T = Tuple{typeof(m_spmd_lane), typeof(tr), k, A}
    m = IRTools.meta(T)
    m === nothing && return
    return IRTools.update!(m.code, IR(m))
end
julia> @benchmark test_m_spmd(tr, 1, A) setup=(A=VecArray(Matrix{Float64}, rand(2, 2, 1000)))
BenchmarkTools.Trial: 
  memory estimate:  128 bytes
  allocs estimate:  4
  --------------
  minimum time:     66.702 ns (0.00% GC)
  median time:      71.824 ns (0.00% GC)
  mean time:        88.557 ns (15.74% GC)
  maximum time:     50.190 μs (99.81% GC)
  --------------
  samples:          10000
  evals/sample:     973

julia> @benchmark m_spmd_lane(tr, 1, A) setup=(A=VecArray(Matrix{Float64}, rand(2, 2, 1000)))
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     8.055 ns (0.00% GC)
  median time:      8.407 ns (0.00% GC)
  mean time:        9.308 ns (0.00% GC)
  maximum time:     60.661 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     999

@MikeInnes
Member

How does the type-inferred code look in each case?

Could any of the difference be due to inlining being forced on the original function?

@Roger-luo
Contributor Author

code_ir

julia> @code_ir test_m_spmd(tr, 1, A)
1: (%1, %2, %3, %4)
  %5 = LinearAlgebra.checksquare
  %6 = (Main.spmd_lane)(%5, %3, %4)
  %7 = (Main.zero)($(QuoteNode(Float64)))
  %8 = 1:%6
  %9 = (Base.iterate)(%8)
  %10 = %9 === nothing
  %11 = (Base.not_int)(%10)
  br 3 (%7) unless %11
  br 2 (%9, %7)
2: (%12, %13)
  %14 = (Core.getfield)(%12, 1)
  %15 = (Core.getfield)(%12, 2)
  %16 = Base.getindex
  %17 = (Main.spmd_lane)(%16, %3, %4, %14, %14)
  %18 = %13 + %17
  %19 = (Base.iterate)(%8, %15)
  %20 = %19 === nothing
  %21 = (Base.not_int)(%20)
  br 3 (%18) unless %21
  br 2 (%19, %18)
3: (%22)
  return %22


julia> @code_ir m_spmd_lane(tr, 1, A)
1: (%1, %2, %3, %4)
  %5 = LinearAlgebra.checksquare
  %6 = (Main.spmd_lane)(%5, %3, %4)
  %7 = (Main.zero)($(QuoteNode(Float64)))
  %8 = 1:%6
  %9 = (Base.iterate)(%8)
  %10 = %9 === nothing
  %11 = (Base.not_int)(%10)
  br 3 (%7) unless %11
  br 2 (%9, %7)
2: (%12, %13)
  %14 = (Core.getfield)(%12, 1)
  %15 = (Core.getfield)(%12, 2)
  %16 = Base.getindex
  %17 = (Main.spmd_lane)(%16, %3, %4, %14, %14)
  %18 = %13 + %17
  %19 = (Base.iterate)(%8, %15)
  %20 = %19 === nothing
  %21 = (Base.not_int)(%20)
  br 3 (%18) unless %21
  br 2 (%19, %18)
3: (%22)
  return %22

code_lowered

julia> @code_lowered test_m_spmd(tr, 1, A)
CodeInfo(
1 ─ %1  = LinearAlgebra.checksquare
│   %2  = (Main.spmd_lane)(%1, k, A)
│   %3  = (Main.zero)($(QuoteNode(Float64)))
│   %4  = 1:%2
│   %5  = (Base.iterate)(%4)
│   %6  = %5 === nothing
│   %7  = (Base.not_int)(%6)
│         spat_3_1 = %3
│         spat_2_1 = %5
│         spat_2_2 = %3
└──       goto #5 if not %7
2 ─       goto #3
3 ─ %13 = (Core.getfield)(spat_2_1, 1)
│   %14 = (Core.getfield)(spat_2_1, 2)
│   %15 = Base.getindex
│   %16 = (Main.spmd_lane)(%15, k, A, %13, %13)
│   %17 = spat_2_2 + %16
│   %18 = (Base.iterate)(%4, %14)
│   %19 = %18 === nothing
│   %20 = (Base.not_int)(%19)
│         spat_3_1 = %17
│         spat_2_1 = %18
│         spat_2_2 = %17
└──       goto #5 if not %20
4 ─       goto #3
5 ─       return spat_3_1
)

julia> @code_lowered m_spmd_lane(tr, 1, A)
CodeInfo(
1 ─ %1  = LinearAlgebra.checksquare
│         n = (Main.spmd_lane)(%1, k, A)
│         out = (Main.zero)($(Expr(:static_parameter, 1)))
│   %4  = 1:n
│         #temp# = (Base.iterate)(%4)
│   %6  = #temp# === nothing
│   %7  = (Base.not_int)(%6)
└──       goto #4 if not %7
2 ─ %9  = #temp#
│         i = (Core.getfield)(%9, 1)
│   %11 = (Core.getfield)(%9, 2)
│   %12 = out
│   %13 = Base.getindex
│   %14 = i
│   %15 = (Main.spmd_lane)(%13, k, A, %14, i)
│         out = %12 + %15
│         #temp# = (Base.iterate)(%4, %11)
│   %18 = #temp# === nothing
│   %19 = (Base.not_int)(%18)
└──       goto #4 if not %19
3 ─       goto #2
4 ─       return out
)

code_typed

julia> @code_typed test_m_spmd(tr, 1, A)
CodeInfo(
1 ── %1  = (Base.getfield)(A, :storage)::Array{Float64,3}
│    %2  = (Base.arraysize)(%1, 1)::Int64
│    %3  = (Base.arraysize)(%1, 2)::Int64
│          (Base.arraysize)(%1, 3)::Int64
│    %5  = (Base.slt_int)(%2, 0)::Bool
│    %6  = (Base.ifelse)(%5, 0, %2)::Int64
│    %7  = %new(OneTo{Int64}, %6)::OneTo{Int64}
│    %8  = (Base.slt_int)(%3, 0)::Bool
│    %9  = (Base.ifelse)(%8, 0, %3)::Int64
│    %10 = %new(OneTo{Int64}, %9)::OneTo{Int64}
│    %11 = %new(Slice{OneTo{Int64}}, %7)::Slice{OneTo{Int64}}
│    %12 = %new(Slice{OneTo{Int64}}, %10)::Slice{OneTo{Int64}}
└───       goto #6 if not true
2 ── %14 = (Core.tuple)(%11, %12, k)::Tuple{Slice{OneTo{Int64}},Slice{OneTo{Int64}},Int64}
│          (Base.arraysize)(%1, 1)::Int64
│          (Base.arraysize)(%1, 2)::Int64
│    %17 = (Base.arraysize)(%1, 3)::Int64
│    %18 = (Base.slt_int)(%17, 0)::Bool
│    %19 = (Base.ifelse)(%18, 0, %17)::Int64
│    %20 = (Base.sle_int)(1, k)::Bool
│    %21 = (Base.sle_int)(k, %19)::Bool
│    %22 = (Base.and_int)(%20, %21)::Bool
│    %23 = (Base.and_int)(%22, true)::Bool
│    %24 = (Base.and_int)(true, %23)::Bool
│    %25 = (Base.and_int)(true, %24)::Bool
└───       goto #4 if not %25
3 ──       goto #5
4 ──       invoke Base.throw_boundserror(%1::Array{Float64,3}, %14::Tuple{Base.Slice{Base.OneTo{Int64}},Base.Slice{Base.OneTo{Int64}},Int64})::Union{}
└───       $(Expr(:unreachable))::Union{}
5 ┄─       nothing::Nothing
6 ┄─       (Base.arraysize)(%1, 1)::Int64
│          (Base.arraysize)(%1, 2)::Int64
│          (Base.arraysize)(%1, 3)::Int64
│          (Base.arraysize)(%1, 1)::Int64
│          (Base.arraysize)(%1, 2)::Int64
│          (Base.arraysize)(%1, 3)::Int64
└───       goto #7
7 ── %38 = (Base.sub_int)(%6, 0)::Int64
│    %39 = (Base.sub_int)(%9, 0)::Int64
│    %40 = (%38 === %39)::Bool
└───       goto #9 if not %40
8 ──       goto #10
9 ── %43 = (Base.sub_int)(%6, 0)::Int64
│    %44 = (Base.sub_int)(%9, 0)::Int64
│    %45 = (Core.tuple)(%43, %44)::Tuple{Int64,Int64}
│    %46 = invoke Base.print_to_string("matrix is not square: dimensions are "::String, %45::Vararg{Any,N} where N)::String
│    %47 = %new(Base.DimensionMismatch, %46)::DimensionMismatch
│          (LinearAlgebra.throw)(%47)::Union{}
└───       $(Expr(:unreachable))::Union{}
10 ┄       goto #11
11 ─ %51 = (Base.sle_int)(1, %38)::Bool
│    %52 = (Base.ifelse)(%51, %38, 0)::Int64
│    %53 = (Base.slt_int)(%52, 1)::Bool
└───       goto #13 if not %53
12 ─ %55 = Base.nothing::Const(nothing, false)
└───       goto #14
13 ─ %57 = (Core.tuple)(1, 1)::Tuple{Int64,Int64}
└───       goto #14
14 ─ %59 = φ (#12 => true, #13 => false)::Bool
│    %60 = φ (#12 => %55, #13 => %57)::Union{Nothing, Tuple{Int64,Int64}}
│    %61 = (Base.not_int)(%59)::Bool
└───       goto #21 if not %61
15 ─       nothing::Nothing
16 ─ %64 = φ (#15 => %60, #20 => %79)::Union{Nothing, Tuple{Int64,Int64}}
│    %65 = φ (#15 => 0.0, #20 => %70)::Float64
│    %66 = (Core.getfield)(%64, 1)::Int64
│    %67 = (Core.getfield)(%64, 2)::Int64
│    %68 = (Base.getfield)(A, :storage)::Array{Float64,3}
│    %69 = (Base.arrayref)(true, %68, %66, %66, k)::Float64
│    %70 = (Base.add_float)(%65, %69)::Float64
│    %71 = (%67 === %52)::Bool
└───       goto #18 if not %71
17 ─ %73 = Base.nothing::Const(nothing, false)
└───       goto #19
18 ─ %75 = (Base.add_int)(%67, 1)::Int64
│    %76 = (Core.tuple)(%75, %75)::Tuple{Int64,Int64}
└───       goto #19
19 ─ %78 = φ (#17 => true, #18 => false)::Bool
│    %79 = φ (#17 => %73, #18 => %76)::Union{Nothing, Tuple{Int64,Int64}}
│    %80 = (Base.not_int)(%78)::Bool
└───       goto #21 if not %80
20 ─       goto #16
21 ─ %83 = φ (#19 => %70, #14 => 0.0)::Float64
└───       return %83
) => Float64

julia> @code_typed m_spmd_lane(tr, 1, A)
CodeInfo(
1 ── %1  = (Base.getfield)(A, :storage)::Array{Float64,3}
│    %2  = (Base.arraysize)(%1, 1)::Int64
│    %3  = (Base.arraysize)(%1, 2)::Int64
│          (Base.arraysize)(%1, 3)::Int64
│    %5  = (Base.slt_int)(%2, 0)::Bool
│    %6  = (Base.ifelse)(%5, 0, %2)::Int64
│    %7  = %new(OneTo{Int64}, %6)::OneTo{Int64}
│    %8  = (Base.slt_int)(%3, 0)::Bool
│    %9  = (Base.ifelse)(%8, 0, %3)::Int64
│    %10 = %new(OneTo{Int64}, %9)::OneTo{Int64}
│    %11 = %new(Slice{OneTo{Int64}}, %7)::Slice{OneTo{Int64}}
│    %12 = %new(Slice{OneTo{Int64}}, %10)::Slice{OneTo{Int64}}
└───       goto #6 if not true
2 ── %14 = (Core.tuple)(%11, %12, k)::Tuple{Slice{OneTo{Int64}},Slice{OneTo{Int64}},Int64}
│          (Base.arraysize)(%1, 1)::Int64
│          (Base.arraysize)(%1, 2)::Int64
│    %17 = (Base.arraysize)(%1, 3)::Int64
│    %18 = (Base.slt_int)(%17, 0)::Bool
│    %19 = (Base.ifelse)(%18, 0, %17)::Int64
│    %20 = (Base.sle_int)(1, k)::Bool
│    %21 = (Base.sle_int)(k, %19)::Bool
│    %22 = (Base.and_int)(%20, %21)::Bool
│    %23 = (Base.and_int)(%22, true)::Bool
│    %24 = (Base.and_int)(true, %23)::Bool
│    %25 = (Base.and_int)(true, %24)::Bool
└───       goto #4 if not %25
3 ──       goto #5
4 ──       invoke Base.throw_boundserror(%1::Array{Float64,3}, %14::Tuple{Base.Slice{Base.OneTo{Int64}},Base.Slice{Base.OneTo{Int64}},Int64})::Union{}
└───       $(Expr(:unreachable))::Union{}
5 ┄─       nothing::Nothing
6 ┄─       (Base.arraysize)(%1, 1)::Int64
│          (Base.arraysize)(%1, 2)::Int64
│          (Base.arraysize)(%1, 3)::Int64
│          (Base.arraysize)(%1, 1)::Int64
│          (Base.arraysize)(%1, 2)::Int64
│          (Base.arraysize)(%1, 3)::Int64
└───       goto #7
7 ── %38 = (Base.sub_int)(%6, 0)::Int64
│    %39 = (Base.sub_int)(%9, 0)::Int64
│    %40 = (%38 === %39)::Bool
└───       goto #9 if not %40
8 ──       goto #10
9 ── %43 = (Base.sub_int)(%6, 0)::Int64
│    %44 = (Base.sub_int)(%9, 0)::Int64
│    %45 = (Core.tuple)(%43, %44)::Tuple{Int64,Int64}
│    %46 = invoke Base.print_to_string("matrix is not square: dimensions are "::String, %45::Vararg{Any,N} where N)::String
│    %47 = %new(Base.DimensionMismatch, %46)::DimensionMismatch
│          (LinearAlgebra.throw)(%47)::Union{}
└───       $(Expr(:unreachable))::Union{}
10 ┄       goto #11
11 ─ %51 = (Base.sle_int)(1, %38)::Bool
│    %52 = (Base.ifelse)(%51, %38, 0)::Int64
│    %53 = (Base.slt_int)(%52, 1)::Bool
└───       goto #13 if not %53
12 ─       goto #14
13 ─       goto #14
14 ─ %57 = φ (#12 => true, #13 => false)::Bool
│    %58 = φ (#13 => 1)::Int64
│    %59 = φ (#13 => 1)::Int64
│    %60 = (Base.not_int)(%57)::Bool
└───       goto #20 if not %60
15 ─ %62 = φ (#14 => 0.0, #19 => %67)::Float64
│    %63 = φ (#14 => %58, #19 => %73)::Int64
│    %64 = φ (#14 => %59, #19 => %74)::Int64
│    %65 = (Base.getfield)(A, :storage)::Array{Float64,3}
│    %66 = (Base.arrayref)(true, %65, %63, %63, k)::Float64
│    %67 = (Base.add_float)(%62, %66)::Float64
│    %68 = (%64 === %52)::Bool
└───       goto #17 if not %68
16 ─       goto #18
17 ─ %71 = (Base.add_int)(%64, 1)::Int64
└───       goto #18
18 ─ %73 = φ (#17 => %71)::Int64
│    %74 = φ (#17 => %71)::Int64
│    %75 = φ (#16 => true, #17 => false)::Bool
│    %76 = (Base.not_int)(%75)::Bool
└───       goto #20 if not %76
19 ─       goto #15
20 ─ %79 = φ (#18 => %67, #14 => 0.0)::Float64
└───       return %79
) => Float64

@MikeInnes
Member

They look identical; are there any differences you can spot?

@Roger-luo
Contributor Author

Is this because there is an extra branch generated by IRTools?

@Roger-luo
Contributor Author

Roger-luo commented Aug 14, 2019

But the benchmark shows some extra allocation (128 bytes), which doesn't make sense if the only difference is an extra branch.

@Roger-luo
Contributor Author

I checked again, and it seems the only difference is the extra branch IRTools generates... but I don't know enough to tell why that would matter for performance. @MikeInnes, do you have time to run the code and check the IR?

@MikeInnes
Member

I guess this refers to block 15, which just has nothing in it. I'd be surprised if that were the issue; LLVM can simplify that kind of thing pretty easily. It's worth checking inlining, but if it's not that, I can try to find some time to play around with it.

@Roger-luo
Contributor Author

Do you mean calling inlineable! on the IR? I tried that; it didn't make any difference.

@Roger-luo
Contributor Author

Roger-luo commented Sep 25, 2019

I did some more investigation. It seems to be because IRTools always adds an extra goto when the original code contains a for loop. I now have a much simpler MWE; @MikeInnes, you should be able to reproduce this by just copying it.

using LinearAlgebra, IRTools

@generated function ir_tr(A)
    T = Tuple{typeof(tr), A}
    m = IRTools.meta(T)
    m === nothing && return
    return IRTools.update!(m.code, IRTools.IR(m))
end

A = rand(100, 100);

using BenchmarkTools
@benchmark tr($A)
@benchmark ir_tr($A)


@code_ir tr(A)
@code_ir ir_tr(A)

@code_lowered tr(A)
@code_lowered ir_tr(A)
julia> @benchmark ir_tr($A)
BenchmarkTools.Trial: 
  memory estimate:  6.25 KiB
  allocs estimate:  200
  --------------
  minimum time:     3.158 μs (0.00% GC)
  median time:      3.342 μs (0.00% GC)
  mean time:        4.734 μs (24.18% GC)
  maximum time:     6.142 ms (99.85% GC)
  --------------
  samples:          10000
  evals/sample:     8

julia> @benchmark tr($A)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     98.826 ns (0.00% GC)
  median time:      104.615 ns (0.00% GC)
  mean time:        110.573 ns (0.00% GC)
  maximum time:     792.313 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     946

julia> @code_ir tr(A)
1: (%1, %2)
  %3 = (LinearAlgebra.checksquare)(%2)
  %4 = (LinearAlgebra.zero)($(QuoteNode(Float64)))
  %5 = 1:%3
  %6 = (Base.iterate)(%5)
  %7 = %6 === nothing
  %8 = (Base.not_int)(%7)
  br 3 (%4) unless %8
  br 2 (%6, %4)
2: (%9, %10)
  %11 = (Core.getfield)(%9, 1)
  %12 = (Core.getfield)(%9, 2)
  %13 = (Base.getindex)(%2, %11, %11)
  %14 = %10 + %13
  %15 = (Base.iterate)(%5, %12)
  %16 = %15 === nothing
  %17 = (Base.not_int)(%16)
  br 3 (%14) unless %17
  br 2 (%15, %14)
3: (%18)
  return %18

julia> @code_ir ir_tr(A)
1: (%1, %2)
  %3 = (LinearAlgebra.checksquare)(%2)
  %4 = (LinearAlgebra.zero)($(QuoteNode(Float64)))
  %5 = 1:%3
  %6 = (Base.iterate)(%5)
  %7 = %6 === nothing
  %8 = (Base.not_int)(%7)
  br 3 (%4) unless %8
  br 2 (%6, %4)
2: (%9, %10)
  %11 = (Core.getfield)(%9, 1)
  %12 = (Core.getfield)(%9, 2)
  %13 = (Base.getindex)(%2, %11, %11)
  %14 = %10 + %13
  %15 = (Base.iterate)(%5, %12)
  %16 = %15 === nothing
  %17 = (Base.not_int)(%16)
  br 3 (%14) unless %17
  br 2 (%15, %14)
3: (%18)
  return %18

julia> @code_lowered tr(A)
CodeInfo(
1 ─       n = (LinearAlgebra.checksquare)(A)
│         t = (LinearAlgebra.zero)($(Expr(:static_parameter, 1)))
│   %3  = 1:n
│         #temp# = (Base.iterate)(%3)
│   %5  = #temp# === nothing
│   %6  = (Base.not_int)(%5)
└──       goto #4 if not %6
2 ─ %8  = #temp#
│         i = (Core.getfield)(%8, 1)
│   %10 = (Core.getfield)(%8, 2)
│   %11 = t
│   %12 = (Base.getindex)(A, i, i)
│         t = %11 + %12
│         #temp# = (Base.iterate)(%3, %10)
│   %15 = #temp# === nothing
│   %16 = (Base.not_int)(%15)
└──       goto #4 if not %16
3 ─       goto #2
4 ─       return t
)

julia> @code_lowered ir_tr(A)
CodeInfo(
1 ─ %1  = (LinearAlgebra.checksquare)(A)
│   %2  = (LinearAlgebra.zero)($(QuoteNode(Float64)))
│   %3  = 1:%1
│   %4  = (Base.iterate)(%3)
│   %5  = %4 === nothing
│   %6  = (Base.not_int)(%5)
│         spat_3_1 = %2
│         spat_2_1 = %4
│         spat_2_2 = %2
└──       goto #5 if not %6
2 ─       goto #3
3 ─ %12 = (Core.getfield)(spat_2_1, 1)
│   %13 = (Core.getfield)(spat_2_1, 2)
│   %14 = (Base.getindex)(A, %12, %12)
│   %15 = spat_2_2 + %14
│   %16 = (Base.iterate)(%3, %13)
│   %17 = %16 === nothing
│   %18 = (Base.not_int)(%17)
│         spat_3_1 = %15
│         spat_2_1 = %16
│         spat_2_2 = %15
└──       goto #5 if not %18
4 ─       goto #3
5 ─       return spat_3_1
)

@MikeInnes
Member

It would be great if you could try this again; we just added a patch that makes intrinsics significantly faster. If it's still a problem, we can look into doing IR simplification to avoid redundant blocks etc.

@Roger-luo
Contributor Author

Roger-luo commented Feb 2, 2020

It segfaults now. I'm not sure whether the API has changed, but this is what I tried:

using LinearAlgebra, IRTools

@generated function ir_tr(A)
    T = Tuple{typeof(tr), A}
    m = IRTools.meta(T)
    m === nothing && return
    ir = IRTools.IR(m)
    return IRTools.Inner.update!(m.code, ir)
end

A = rand(100, 100);

@code_ir ir_tr(A)
ir_tr(A)

@code_ir and @code_lowered work fine, but ir_tr itself still segfaults. I can't tell what is wrong, however. This is on the current IRTools release (0.3.1) and Julia 1.3.0.

@MikeInnes
Member

Can you try with an @dynamo? It's pretty easy to go wrong somewhere with the raw generated functions, but that should work.

@Roger-luo
Contributor Author

Still no change.

using LinearAlgebra, IRTools
using IRTools: IR
using IRTools: @dynamo

@dynamo function spmd(f, x)
    ir = IR(f, x)
    return ir
end

and the benchmark

A = rand(100, 100)


julia> @benchmark tr($A)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     100.501 ns (0.00% GC)
  median time:      107.775 ns (0.00% GC)
  mean time:        111.301 ns (0.00% GC)
  maximum time:     350.092 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     942

the same function, but generated using IRTools

julia> @benchmark spmd($tr, $A)
BenchmarkTools.Trial: 
  memory estimate:  6.25 KiB
  allocs estimate:  200
  --------------
  minimum time:     3.325 μs (0.00% GC)
  median time:      3.471 μs (0.00% GC)
  mean time:        3.882 μs (7.28% GC)
  maximum time:     289.762 μs (98.75% GC)
  --------------
  samples:          10000
  evals/sample:     8

@MikeInnes
Member

Typed lowered code for both:

julia> @code_typed optimize=false tr(A)
CodeInfo(
1 ─       (n = LinearAlgebra.checksquare(A))::Int64
│         (t = LinearAlgebra.zero($(Expr(:static_parameter, 1))))::Core.Compiler.Const(0.0, false)
│   %3  = (1:n)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│         (@_5 = Base.iterate(%3))::Union{Nothing, Tuple{Int64,Int64}}
│   %5  = (@_5 === nothing)::Bool
│   %6  = Base.not_int(%5)::Bool
└──       goto #4 if not %6
2 ─ %8  = @_5::Tuple{Int64,Int64}::Tuple{Int64,Int64}
│         (i = Core.getfield(%8, 1))::Int64
│   %10 = Core.getfield(%8, 2)::Int64
│   %11 = t::Float64
│   %12 = Base.getindex(A, i, i)::Float64
│         (t = %11 + %12)::Float64
│         (@_5 = Base.iterate(%3, %10))::Union{Nothing, Tuple{Int64,Int64}}
│   %15 = (@_5 === nothing)::Bool
│   %16 = Base.not_int(%15)::Bool
└──       goto #4 if not %16
3 ─       goto #2
4 ─       return t
) => Float64

julia> @code_typed optimize=false spmd(tr, A)
CodeInfo(
1 ─       Base.getfield(args, 1)::Core.Compiler.Const(LinearAlgebra.tr, false)
│   %2  = Base.getfield(args, 2)::Array{Float64,2}
│   %3  = LinearAlgebra.checksquare(%2)::Int64
│   %4  = LinearAlgebra.zero($(QuoteNode(Float64)))::Core.Compiler.Const(0.0, false)
│   %5  = (1:%3)::Core.Compiler.PartialStruct(UnitRange{Int64}, Any[Core.Compiler.Const(1, false), Int64])
│   %6  = Base.iterate(%5)::Union{Nothing, Tuple{Int64,Int64}}
│   %7  = (%6 === nothing)::Bool
│   %8  = Base.not_int(%7)::Bool
│         (phi_3_1 = %4)::Core.Compiler.Const(0.0, false)
│         (phi_2_1 = %6)::Union{Nothing, Tuple{Int64,Int64}}
│         (phi_2_2 = %4)::Core.Compiler.Const(0.0, false)
└──       goto #5 if not %8
2 ─       goto #3
3 ─ %14 = Core.getfield(phi_2_1, 1)::Int64
│   %15 = Core.getfield(phi_2_1, 2)::Int64
│   %16 = Base.getindex(%2, %14, %14)::Float64
│   %17 = (phi_2_2 + %16)::Float64
│   %18 = Base.iterate(%5, %15)::Union{Nothing, Tuple{Int64,Int64}}
│   %19 = (%18 === nothing)::Bool
│   %20 = Base.not_int(%19)::Bool
│         (phi_3_1 = %17)::Float64
│         (phi_2_1 = %18)::Union{Nothing, Tuple{Int64,Int64}}
│         (phi_2_2 = %17)::Float64
└──       goto #5 if not %20
4 ─       goto #3
5 ─       return phi_3_1
) => Float64

Zooming in a bit:

# tr(A)
│         (@_5 = Base.iterate(%3, %10))::Union{Nothing, Tuple{Int64,Int64}}
│   %15 = (@_5 === nothing)::Bool
│   %16 = Base.not_int(%15)::Bool
└──       goto #4 if not %16

# spmd(tr, A)
│   %18 = Base.iterate(%5, %15)::Union{Nothing, Tuple{Int64,Int64}}
│   %19 = (%18 === nothing)::Bool
│   %20 = Base.not_int(%19)::Bool
│         (phi_3_1 = %17)::Float64
│         (phi_2_1 = %18)::Union{Nothing, Tuple{Int64,Int64}}
│         (phi_2_2 = %17)::Float64
└──       goto #5 if not %20

What's happening here is that when we compare @_5 to nothing, Julia can use this information to narrow the type of @_5 at the start of the loop (we know the condition was true, so @_5 can't be nothing). In the second case the nothing check is made on the SSA value %18, while the loop body reads the slot phi_2_1 that %18 was copied into. Narrowing is still possible in principle but more complex, and it foils type inference.
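The same effect can be explored in plain Julia, without IRTools. The sketch below hand-desugars the trace loop twice (trace_slot and trace_copy are hypothetical names). In trace_slot, the variable tested against nothing is the same one the loop body destructures, mirroring the lowering of tr. In trace_copy, the iterate result is tested in a temporary and then copied into a separate loop-carried variable, loosely mimicking the IRTools output. Both return the same value; how well inference narrows the copied variable may depend on the Julia version, so compare the two with @code_warntype rather than treating the comments as guarantees.

```julia
# Hand-desugared `for i in 1:n; t += A[i, i]; end`.
# The value compared against `nothing` is the variable the loop body reads,
# so after the check inference can treat `nxt` as a Tuple{Int,Int}.
function trace_slot(A::AbstractMatrix)
    r = 1:size(A, 1)
    t = zero(eltype(A))
    nxt = iterate(r)
    while true
        nxt === nothing && break
        i, state = nxt            # narrowed: `nxt` cannot be `nothing` here
        t += A[i, i]
        nxt = iterate(r, state)
    end
    return t
end

# Same loop, but the check is on a temporary that is then copied into the
# loop-carried variable, similar to `%18 = ...; phi_2_1 = %18`.
function trace_copy(A::AbstractMatrix)
    r = 1:size(A, 1)
    t = zero(eltype(A))
    tmp = iterate(r)
    carried = tmp                 # loop-carried copy
    while true
        tmp === nothing && break  # the check is on `tmp`...
        i, state = carried        # ...so inference may not narrow `carried`
        t += A[i, i]
        tmp = iterate(r, state)
        carried = tmp
    end
    return t
end
```

For example, both trace_slot([1.0 2.0; 3.0 4.0]) and trace_copy([1.0 2.0; 3.0 4.0]) return 5.0.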

This should be pretty easy to fix: if %18 is defined in the current block, we should emit %18 = phi_2_1 = ... instead of %18 = ...; phi_2_1 = %18. We may also need to replace uses of %18 with the slot.
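To make the proposed rewrite concrete, here is a toy version of it that works on a made-up statement list rather than on IRTools' real data structures (SSA, Stmt, mentions, subst and fuse_slot_copies are all hypothetical names). For each copy slot = %k where %k is defined earlier in the same block, it assigns the defining expression straight into the slot and rewrites later uses of %k to read the slot, so the === nothing test ends up on the slot. Unlike the suggested %18 = phi_2_1 = ..., the sketch drops the SSA name entirely, which is only safe here because every later use in the block is redirected to the slot; it also assumes each slot is assigned at most once per block.

```julia
# Toy IR for illustration only; these are NOT IRTools' types.
struct SSA
    id::Int
end

# One statement `lhs = rhs`: lhs is an SSA value or a slot (Symbol);
# rhs is a call (Expr), an SSA value, a slot, or a literal.
const Stmt = Pair{Union{SSA,Symbol},Any}

# Does `x` mention `v` anywhere?
mentions(x, v) = x == v
mentions(ex::Expr, v) = any(a -> mentions(a, v), ex.args)

# Replace every occurrence of `from` in `x` with `to`.
subst(x, from, to) = x == from ? to : x
subst(ex::Expr, from, to) = Expr(ex.head, map(a -> subst(a, from, to), ex.args)...)

# Rewrite `%k = expr; ...; slot = %k` (within one block) into
# `slot = expr; ...`, with later uses of `%k` reading `slot` instead.
# A copy is left alone if `slot` is read between the definition and the copy.
function fuse_slot_copies(block::Vector{Stmt})
    out = copy(block)
    for j in eachindex(out)
        lhs, rhs = out[j]
        (lhs isa Symbol && rhs isa SSA) || continue
        i = findfirst(s -> s.first == rhs, out)
        i === nothing && continue                 # %k comes from another block
        any(k -> mentions(out[k].second, lhs), i+1:j-1) && continue
        out[i] = lhs => out[i].second             # define straight into the slot
        for k in i+1:length(out)
            out[k] = out[k].first => subst(out[k].second, rhs, lhs)
        end
    end
    # Fused copies have become `slot = slot`; drop them.
    return filter(s -> s.first != s.second, out)
end
```

Applied to the loop block of the spmd(tr, A) code (statements %17 to %20 plus the three phi copies), this toy pass yields phi_3_1 = phi_2_2 + %16, phi_2_1 = iterate(...), %19 = phi_2_1 === nothing, %20 = not_int(%19), phi_2_2 = phi_3_1, so the nothing check is made on the loop-carried slot.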

2 participants