# lint long chains of int/float additions/multiplications #1243

---
Should this be smart and notice add-assign too? i.e.:
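Something along these lines, perhaps (a hypothetical reconstruction; the original snippet didn't survive the formatting):

```rust
// Hypothetical example: an add-assign chain that is semantically the same
// float chain as `a + b + c + d`, just spelled with `+=`.
fn total(a: f32, b: f32, c: f32, d: f32) -> f32 {
    let mut acc = a;
    acc += b;
    acc += c;
    acc += d;
    acc
}
```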
Obviously an esoteric example. But if the floats all come from different functions, it might not be as noticeable.

---
Isn't this optimized already by the compiler for integer arithmetic? Integer addition is (save for overflow errors) commutative and associative. Float addition and subtraction, however, are not associative (and mul/div are, but only sometimes), so this lint would still be a good help for floating point code.

What would the threshold for the chains be? Should it be fixed, or configurable? Should it be the same for addition and multiplication?
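A quick illustration of that non-associativity (the values here are picked for demonstration and are not from the thread):

```rust
fn main() {
    let (a, b, c) = (1e8_f32, -1e8_f32, 1.0_f32);
    println!("{}", (a + b) + c); // 1: the large terms cancel first, so 1.0 survives
    println!("{}", a + (b + c)); // 0: 1.0 is absorbed into -1e8 by rounding, then lost
}
```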
---

This may be; I'm not sure. It's worth playing with http://rust.godbolt.org/ to see if that's the case. Sometimes lints that provide "optimizations" also provide clarity of code, so even if the compiler is smart enough we still have them. But in this case the parens make the code worse, so we should only add this if it really has a perf benefit.

---
## Floating point arithmetic

This is where things got interesting! Let's first do 4 parameters.

### n = 4

Here's the addition code:

```rust
pub fn add_4(a: f32, b: f32, c: f32, d: f32) -> f32 {
    a + b + c + d
}

pub fn add_4_p(a: f32, b: f32, c: f32, d: f32) -> f32 {
    (a + b) + (c + d)
}
```

The generated IR was different!

```llvm
define float @_ZN7example5add_417hbd9627e19913a4a5E(float %a, float %b, float %c, float %d) unnamed_addr #0 !dbg !5 {
%0 = fadd float %a, %b, !dbg !8
%1 = fadd float %0, %c, !dbg !8
%2 = fadd float %1, %d, !dbg !8
ret float %2, !dbg !9
}
```

Here, without parentheses, we can see a linear progression through the chain.

```llvm
define float @_ZN7example7add_4_p17h45fdde80c80ae9f4E(float %a, float %b, float %c, float %d) unnamed_addr #0 !dbg !10 {
%0 = fadd float %a, %b, !dbg !11
%1 = fadd float %c, %d, !dbg !12
%2 = fadd float %0, %1, !dbg !11
ret float %2, !dbg !13
}
```

With parentheses we can see a tree approach to the operations! Which makes a lot of sense: LLVM shouldn't be "normalizing" the order of operations if doing so changes the end result, and for floating point it definitely would.

The multiplication code's results are strikingly similar:

```rust
pub fn mul_4(a: f32, b: f32, c: f32, d: f32) -> f32 {
    a * b * c * d
}

pub fn mul_4_p(a: f32, b: f32, c: f32, d: f32) -> f32 {
    (a * b) * (c * d)
}
```

```llvm
define float @_ZN7example5mul_417h978c4d41e342ced3E(float %a, float %b, float %c, float %d) unnamed_addr #0 !dbg !22 {
%0 = fmul float %a, %b, !dbg !23
%1 = fmul float %0, %c, !dbg !23
%2 = fmul float %1, %d, !dbg !23
ret float %2, !dbg !24
}
define float @_ZN7example7mul_4_p17h539935c5f42320aeE(float %a, float %b, float %c, float %d) unnamed_addr #0 !dbg !25 {
%0 = fmul float %a, %b, !dbg !26
%1 = fmul float %c, %d, !dbg !27
%2 = fmul float %0, %1, !dbg !26
ret float %2, !dbg !28
}
```

### n = 8

Now things got really interesting. Here are the addition operations:

```rust
pub fn add_8(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32) -> f32 {
    a + b + c + d + e + f + g + h
}

pub fn add_8_p(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32) -> f32 {
    ((a + b) + (c + d)) + ((e + f) + (g + h))
}
```

And the generated IR:

```llvm
define float @_ZN7example5add_817hef9e967ef9fb2f7eE(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h) unnamed_addr #0 !dbg !14 {
%0 = fadd float %a, %b, !dbg !15
%1 = fadd float %0, %c, !dbg !15
%2 = fadd float %1, %d, !dbg !15
%3 = fadd float %2, %e, !dbg !15
%4 = fadd float %3, %f, !dbg !15
%5 = fadd float %4, %g, !dbg !15
%6 = fadd float %5, %h, !dbg !15
ret float %6, !dbg !16
}
define float @_ZN7example7add_8_p17hf411fbb0fd42e0a7E(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h) unnamed_addr #0 !dbg !17 {
%0 = insertelement <2 x float> undef, float %a, i32 0, !dbg !18
%1 = insertelement <2 x float> %0, float %e, i32 1, !dbg !18
%2 = insertelement <2 x float> undef, float %b, i32 0, !dbg !18
%3 = insertelement <2 x float> %2, float %f, i32 1, !dbg !18
%4 = fadd <2 x float> %1, %3, !dbg !18
%5 = insertelement <2 x float> undef, float %c, i32 0, !dbg !19
%6 = insertelement <2 x float> %5, float %g, i32 1, !dbg !19
%7 = insertelement <2 x float> undef, float %d, i32 0, !dbg !19
%8 = insertelement <2 x float> %7, float %h, i32 1, !dbg !19
%9 = fadd <2 x float> %6, %8, !dbg !19
%10 = fadd <2 x float> %4, %9, !dbg !20
%11 = extractelement <2 x float> %10, i32 0, !dbg !20
%12 = extractelement <2 x float> %10, i32 1, !dbg !20
%13 = fadd float %11, %12, !dbg !20
ret float %13, !dbg !21
}
```

What LLVM is doing in the second case is vector operations. I've checked the generated code (please feel free to verify it as well), and it does indeed do exactly what the parentheses say. The cool thing here is that it only does 4 addition operations instead of 7.

We can see an exact parallel between the addition code and the multiplication code's behavior when it comes to vectorization:

```rust
pub fn mul_8(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32) -> f32 {
    a * b * c * d * e * f * g * h
}

pub fn mul_8_p(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32) -> f32 {
    ((a * b) * (c * d)) * ((e * f) * (g * h))
}
```

```llvm
define float @_ZN7example5mul_817he6344433229b5d87E(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h) unnamed_addr #0 !dbg !29 {
%0 = fmul float %a, %b, !dbg !30
%1 = fmul float %0, %c, !dbg !30
%2 = fmul float %1, %d, !dbg !30
%3 = fmul float %2, %e, !dbg !30
%4 = fmul float %3, %f, !dbg !30
%5 = fmul float %4, %g, !dbg !30
%6 = fmul float %5, %h, !dbg !30
ret float %6, !dbg !31
}
define float @_ZN7example7mul_8_p17h7be0422ddf0675e3E(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h) unnamed_addr #0 !dbg !32 {
%0 = insertelement <2 x float> undef, float %a, i32 0, !dbg !33
%1 = insertelement <2 x float> %0, float %e, i32 1, !dbg !33
%2 = insertelement <2 x float> undef, float %b, i32 0, !dbg !33
%3 = insertelement <2 x float> %2, float %f, i32 1, !dbg !33
%4 = fmul <2 x float> %1, %3, !dbg !33
%5 = insertelement <2 x float> undef, float %c, i32 0, !dbg !34
%6 = insertelement <2 x float> %5, float %g, i32 1, !dbg !34
%7 = insertelement <2 x float> undef, float %d, i32 0, !dbg !34
%8 = insertelement <2 x float> %7, float %h, i32 1, !dbg !34
%9 = fmul <2 x float> %6, %8, !dbg !34
%10 = fmul <2 x float> %4, %9, !dbg !35
%11 = extractelement <2 x float> %10, i32 0, !dbg !35
%12 = extractelement <2 x float> %10, i32 1, !dbg !35
%13 = fmul float %11, %12, !dbg !35
ret float %13, !dbg !36
}
```

## Conclusions

This is very interesting. Since the n=4 case didn't vectorize, and knowing how performance-focused LLVM is, I think we can assume that somewhere between n=8 and n=4 the costs of vectorization outweigh the benefits. Of course, this asks for two things: benchmarking, and checking what the codegen is for n=16. Guess which one I'm gonna do now! :D
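On the benchmarking side, here is a minimal sketch of what a first measurement could look like (everything below is illustrative and not from this thread; `std::hint::black_box` keeps the compiler from hoisting the loop-invariant sums, and a real comparison should use a proper harness like criterion):

```rust
use std::hint::black_box;
use std::time::Instant;

fn main() {
    let v: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    let (a, b, c, d) = (v[0], v[1], v[2], v[3]);
    let (e, f, g, h) = (v[4], v[5], v[6], v[7]);

    // Chained form: every fadd depends on the previous one.
    let start = Instant::now();
    let mut acc = 0.0f32;
    for _ in 0..10_000_000 {
        acc += black_box(a) + black_box(b) + black_box(c) + black_box(d)
            + black_box(e) + black_box(f) + black_box(g) + black_box(h);
    }
    println!("chain: {:?} (acc = {})", start.elapsed(), acc);

    // Tree form: additions at the same depth are independent.
    let start = Instant::now();
    let mut acc = 0.0f32;
    for _ in 0..10_000_000 {
        acc += (black_box(a) + black_box(b)) + (black_box(c) + black_box(d))
            + ((black_box(e) + black_box(f)) + (black_box(g) + black_box(h)));
    }
    println!("tree:  {:?} (acc = {})", start.elapsed(), acc);
}
```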
---

## Bonus Track: n = 16

Well, it had to be done. Here are the results.

### Original code

```rust
pub fn add_8(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32,
    i: f32, j: f32, k: f32, l: f32,
    m: f32, n: f32, o: f32, p: f32) -> f32 {
    a + b + c + d + e + f + g + h +
    i + j + k + l + m + n + o + p
}

pub fn add_8_p(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32,
    i: f32, j: f32, k: f32, l: f32,
    m: f32, n: f32, o: f32, p: f32) -> f32 {
    (((a + b) + (c + d)) + ((e + f) + (g + h))) +
    (((i + j) + (k + l)) + ((m + n) + (o + p)))
}

pub fn mul_8(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32,
    i: f32, j: f32, k: f32, l: f32,
    m: f32, n: f32, o: f32, p: f32) -> f32 {
    a * b * c * d * e * f * g * h *
    i * j * k * l * m * n * o * p
}

pub fn mul_8_p(
    a: f32, b: f32, c: f32, d: f32,
    e: f32, f: f32, g: f32, h: f32,
    i: f32, j: f32, k: f32, l: f32,
    m: f32, n: f32, o: f32, p: f32) -> f32 {
    (((a * b) * (c * d)) * ((e * f) * (g * h))) *
    (((i * j) * (k * l)) * ((m * n) * (o * p)))
}
```

(The functions are still named `add_8`/`mul_8` even though they now take 16 parameters; the mangled names in the IR below match.)

### Generated IR

```llvm
define float @_ZN7example5add_817hfacaa9e4a302cdeaE(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h, float %i, float %j, float %k, float %l, float %m, float %n, float %o, float %p) unnamed_addr #0 !dbg !5 {
%0 = fadd float %a, %b, !dbg !8
%1 = fadd float %0, %c, !dbg !8
%2 = fadd float %1, %d, !dbg !8
%3 = fadd float %2, %e, !dbg !8
%4 = fadd float %3, %f, !dbg !8
%5 = fadd float %4, %g, !dbg !8
%6 = fadd float %5, %h, !dbg !8
%7 = fadd float %6, %i, !dbg !8
%8 = fadd float %7, %j, !dbg !8
%9 = fadd float %8, %k, !dbg !8
%10 = fadd float %9, %l, !dbg !8
%11 = fadd float %10, %m, !dbg !8
%12 = fadd float %11, %n, !dbg !8
%13 = fadd float %12, %o, !dbg !8
%14 = fadd float %13, %p, !dbg !8
ret float %14, !dbg !9
}
define float @_ZN7example7add_8_p17hf34e2f52a58a2f66E(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h, float %i, float %j, float %k, float %l, float %m, float %n, float %o, float %p) unnamed_addr #0 !dbg !10 {
%0 = insertelement <2 x float> undef, float %a, i32 0, !dbg !11
%1 = insertelement <2 x float> %0, float %i, i32 1, !dbg !11
%2 = insertelement <2 x float> undef, float %b, i32 0, !dbg !11
%3 = insertelement <2 x float> %2, float %j, i32 1, !dbg !11
%4 = fadd <2 x float> %1, %3, !dbg !11
%5 = insertelement <2 x float> undef, float %c, i32 0, !dbg !12
%6 = insertelement <2 x float> %5, float %k, i32 1, !dbg !12
%7 = insertelement <2 x float> undef, float %d, i32 0, !dbg !12
%8 = insertelement <2 x float> %7, float %l, i32 1, !dbg !12
%9 = fadd <2 x float> %6, %8, !dbg !12
%10 = fadd <2 x float> %4, %9, !dbg !13
%11 = insertelement <2 x float> undef, float %e, i32 0, !dbg !14
%12 = insertelement <2 x float> %11, float %m, i32 1, !dbg !14
%13 = insertelement <2 x float> undef, float %f, i32 0, !dbg !14
%14 = insertelement <2 x float> %13, float %n, i32 1, !dbg !14
%15 = fadd <2 x float> %12, %14, !dbg !14
%16 = insertelement <2 x float> undef, float %g, i32 0, !dbg !15
%17 = insertelement <2 x float> %16, float %o, i32 1, !dbg !15
%18 = insertelement <2 x float> undef, float %h, i32 0, !dbg !15
%19 = insertelement <2 x float> %18, float %p, i32 1, !dbg !15
%20 = fadd <2 x float> %17, %19, !dbg !15
%21 = fadd <2 x float> %15, %20, !dbg !16
%22 = fadd <2 x float> %10, %21, !dbg !17
%23 = extractelement <2 x float> %22, i32 0, !dbg !17
%24 = extractelement <2 x float> %22, i32 1, !dbg !17
%25 = fadd float %23, %24, !dbg !17
ret float %25, !dbg !18
}
define float @_ZN7example5mul_817he802d2a8e2d2b3afE(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h, float %i, float %j, float %k, float %l, float %m, float %n, float %o, float %p) unnamed_addr #0 !dbg !19 {
%0 = fmul float %a, %b, !dbg !20
%1 = fmul float %0, %c, !dbg !20
%2 = fmul float %1, %d, !dbg !20
%3 = fmul float %2, %e, !dbg !20
%4 = fmul float %3, %f, !dbg !20
%5 = fmul float %4, %g, !dbg !20
%6 = fmul float %5, %h, !dbg !20
%7 = fmul float %6, %i, !dbg !20
%8 = fmul float %7, %j, !dbg !20
%9 = fmul float %8, %k, !dbg !20
%10 = fmul float %9, %l, !dbg !20
%11 = fmul float %10, %m, !dbg !20
%12 = fmul float %11, %n, !dbg !20
%13 = fmul float %12, %o, !dbg !20
%14 = fmul float %13, %p, !dbg !20
ret float %14, !dbg !21
}
define float @_ZN7example7mul_8_p17h4ad349c7d77b6c6cE(float %a, float %b, float %c, float %d, float %e, float %f, float %g, float %h, float %i, float %j, float %k, float %l, float %m, float %n, float %o, float %p) unnamed_addr #0 !dbg !22 {
%0 = insertelement <2 x float> undef, float %a, i32 0, !dbg !23
%1 = insertelement <2 x float> %0, float %i, i32 1, !dbg !23
%2 = insertelement <2 x float> undef, float %b, i32 0, !dbg !23
%3 = insertelement <2 x float> %2, float %j, i32 1, !dbg !23
%4 = fmul <2 x float> %1, %3, !dbg !23
%5 = insertelement <2 x float> undef, float %c, i32 0, !dbg !24
%6 = insertelement <2 x float> %5, float %k, i32 1, !dbg !24
%7 = insertelement <2 x float> undef, float %d, i32 0, !dbg !24
%8 = insertelement <2 x float> %7, float %l, i32 1, !dbg !24
%9 = fmul <2 x float> %6, %8, !dbg !24
%10 = fmul <2 x float> %4, %9, !dbg !25
%11 = insertelement <2 x float> undef, float %e, i32 0, !dbg !26
%12 = insertelement <2 x float> %11, float %m, i32 1, !dbg !26
%13 = insertelement <2 x float> undef, float %f, i32 0, !dbg !26
%14 = insertelement <2 x float> %13, float %n, i32 1, !dbg !26
%15 = fmul <2 x float> %12, %14, !dbg !26
%16 = insertelement <2 x float> undef, float %g, i32 0, !dbg !27
%17 = insertelement <2 x float> %16, float %o, i32 1, !dbg !27
%18 = insertelement <2 x float> undef, float %h, i32 0, !dbg !27
%19 = insertelement <2 x float> %18, float %p, i32 1, !dbg !27
%20 = fmul <2 x float> %17, %19, !dbg !27
%21 = fmul <2 x float> %15, %20, !dbg !28
%22 = fmul <2 x float> %10, %21, !dbg !29
%23 = extractelement <2 x float> %22, i32 0, !dbg !29
%24 = extractelement <2 x float> %22, i32 1, !dbg !29
%25 = fmul float %23, %24, !dbg !29
ret float %25, !dbg !30
}
```

As you can see, vectorization is still the compiler's choice, and the number of operations in the "operation tree" is still about half (plus one) of the "operation chain"'s IR.

## Conclusions v2

I don't know what the verdict would be. From the n=4 code, or by hand, we can see there is a way of processing the parenthesized tree with the same number of operations the chain needs. Here's how to do it for n = 8:

```llvm
%0 = fmul float %a, %b
%1 = fmul float %0, %c
%2 = fmul float %1, %d
%3 = fmul float %2, %e
%4 = fmul float %3, %f
%5 = fmul float %4, %g
%6 = fmul float %5, %h
```

And here is the tree version, also with seven operations:

```llvm
%0 = fmul float %a, %b
%1 = fmul float %c, %d
%2 = fmul float %e, %f
%3 = fmul float %g, %h
%4 = fmul float %0, %1
%5 = fmul float %2, %3
%6 = fmul float %4, %5
```

However, for n = 8 and onwards, LLVM seems to prefer vectorization. That would mean the vectorized processing of the tree is faster than the linear processing of it. And since linearly processing the tree costs the same as linearly processing the chain, the vectorized processing of the tree must also be faster than the linear processing of the chain.

We might need to make some benchmarks, but I think the optimization potential of converting the "chain" into a "tree" is real; otherwise LLVM wouldn't be making that transformation. I'd implement the lint, but for floating point only, and from chains of size 8 and up.

---
## Summary

To summarize my km-length posts: (context:

---
## Thoughts on numerical stability

Since floating point ops are always lossy, reducing the number of operations is always good. However, vectorization does not reduce the number of operations; it only parallelizes them. And it can be seen while writing the parentheses that one never deletes any asterisks or pluses.

That's why I think that this lint, while definitely helping with performance, would not help with stability. What do y'all think? :)

---
Updated my comments so that they are not huge, intimidating walls of code. Now the code is mostly in "spoiler blocks" ^^

---
Update: it seems the change from linear to vectorized occurs at exactly 8 variables. I haven't been able to make LLVM vectorize the chain with fewer variables, regardless of how I group the operations. 8 might be the "magic number" we ought to look for in this lint.

---
In general, `(a+b+c)+(d+e+f)` is faster to compute than `a+b+c+d+e+f`, because the former reduces data dependencies, allowing the CPU to run additions in parallel. For float arithmetic, the tree-like addition may improve numerical stability. Suggest inserting parentheses.
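As a sketch of what the suggested parenthesization generalizes to, here is a small illustration (added for this write-up, not from the issue; the name `pairwise_sum` is invented): splitting in half and reducing each side makes same-depth additions independent, and for floats this pairwise scheme is also the classic way to keep rounding error growing like O(log n) instead of O(n).

```rust
/// Pairwise (tree-shaped) reduction of a slice: for 8 elements this computes
/// `((a+b)+(c+d)) + ((e+f)+(g+h))`, mirroring the parenthesized examples above.
fn pairwise_sum(xs: &[f32]) -> f32 {
    match xs.len() {
        0 => 0.0,
        1 => xs[0],
        n => pairwise_sum(&xs[..n / 2]) + pairwise_sum(&xs[n / 2..]),
    }
}

fn main() {
    let xs: Vec<f32> = (1..=8).map(|i| i as f32).collect();
    println!("{}", pairwise_sum(&xs)); // 36
}
```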