
Extract array value from a one-hot Vec Bool mask #132

Closed
Moelf opened this issue Oct 29, 2024 · 5 comments

Comments

@Moelf

Moelf commented Oct 29, 2024

julia> a = Vec((1,2,3,4))
<4 x Int64>[1, 2, 3, 4]

julia> mask = a == 3
<4 x Bool>[0, 0, 1, 0]

julia> b = Vec((5,6,7,8))
<4 x Int64>[5, 6, 7, 8]

julia> b[mask]
# errors

What's the canonical way to get 7 in this case, other than a long chain of mask[1] ? b[1] : mask[2] ? b[2] : ...?

@KristofferC
Collaborator

KristofferC commented Oct 29, 2024

julia> maski = convert(Vec{4, Int}, mask)
<4 x Int64>[0, 0, 1, 0]

julia> maski * b
<4 x Int64>[0, 0, 7, 0]

julia> sum(maski * b)
7

Dunno 🤷. What would you do in other languages?
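For comparison, here is a minimal Base-Julia sketch of the same select-then-reduce idea on plain tuples. It uses ifelse to zero the unselected lanes instead of multiplying by a converted mask. onehot_extract is a hypothetical helper and is not part of SIMD.jl. On Vecs, the analogous lane-wise select is, as far as I know, SIMD.jl's vifelse; check that against the package docs before relying on it.

```julia
# Hypothetical helper: return the value picked out by a one-hot mask, without branching.
# Every lane except the selected one is replaced by zero, then the lanes are summed.
function onehot_extract(vals::NTuple{N,T}, mask::NTuple{N,Bool}) where {N,T}
    # ifelse evaluates both arms, so there is no data-dependent branch per lane
    picked = map((m, v) -> ifelse(m, v, zero(T)), mask, vals)
    return sum(picked)
end

a = (1, 2, 3, 4)
b = (5, 6, 7, 8)
mask = map(==(3), a)              # (false, false, true, false)
onehot_extract(b, mask)           # 7
onehot_extract(float.(b), mask)   # 7.0, with no Bool-to-Float conversion needed
```

If the mask is not actually one-hot, this returns the sum of all selected lanes rather than a single element.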

@Moelf
Author

Moelf commented Oct 29, 2024

idk why other language matters; I mean, in Julia we do have the ary[mask] method, but of course it won't work for SIMD naively. I was just wondering what the fastest approximation to that is. What you have seems to be faster, so I will take that:

julia> a = Vec((1,2,3,4));

julia> mask = a == 3;

julia> function f(v, mask)
           maski = convert(Vec{4, Int}, mask)
           sum(maski * v)
       end
f (generic function with 1 method)

julia> @be f($a, $mask)
Benchmark: 4663 samples with 10414 evaluations
 min    1.868 ns
 median 1.918 ns
 mean   1.950 ns
 max    19.795 ns

julia> function g(v, mask)
           mask[1] ? v[1] : mask[2] ? v[2] : mask[3] ? v[3] : v[4]
       end
g (generic function with 1 method)

julia> @be g($a, $mask)
Benchmark: 4316 samples with 6734 evaluations
 min    3.101 ns
 median 3.166 ns
 mean   3.193 ns
 max    19.319 ns

Moelf closed this as completed Oct 29, 2024
Moelf changed the title from "Extract value from a Vec Bool mask" to "Extract array value from a one-hot Vec Bool mask" Oct 29, 2024
@Moelf
Author

Moelf commented Oct 29, 2024

oh, that doesn't work with floats:

julia> a = Vec((1.0,2.0,3.0,4.0));

julia> mask = a == 3.0
<4 x Bool>[0, 0, 1, 0]

julia> convert(Vec{4, Float64}, mask)
ERROR: unreachable
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
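Separately from the conversion error, multiplying by a 0/1 mask has a float-specific pitfall. Under IEEE 754, 0.0 * Inf and 0.0 * NaN are both NaN, so a non-finite value in an unselected lane poisons the sum. A lane-wise select does not have this problem. The following is a small Base-Julia illustration on tuples, not SIMD.jl Vecs, and the values are made up for the example:

```julia
vals = (5.0, Inf, 7.0, 8.0)          # lane 2 holds a non-finite value
mask = (false, false, true, false)   # one-hot: select lane 3

# Multiply-by-mask: Float64(false) * Inf == 0.0 * Inf, which is NaN
mul_sum = sum(map((m, v) -> Float64(m) * v, mask, vals))

# Select-by-mask: unselected lanes become 0.0 regardless of their contents
sel_sum = sum(map((m, v) -> ifelse(m, v, 0.0), mask, vals))

println(mul_sum)   # NaN
println(sel_sum)   # 7.0
```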

@KristofferC
Collaborator

KristofferC commented Oct 29, 2024

> idk why other language matters

Because this is just a high-level wrapper over SIMD intrinsics, so the same combination of operations used in other languages would (or should) probably work here.

Basically, can you find which operations in https://llvm.org/docs/LangRef.html would be used to do this?
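LangRef documents two plausible combinations. The first is select with the <N x i1> mask followed by a llvm.vector.reduce.add (or .fadd) reduction, which is what the sum-based answers do. The second is to bitcast the <N x i1> mask to an iN integer, count trailing zeros with llvm.cttz to get the lane index, and extractelement that lane. This thread does not confirm whether SIMD.jl exposes each step of the second route. The following is therefore only a Base-Julia model of its arithmetic on a tuple, and onehot_index is a hypothetical helper:

```julia
# Model of: bitcast <N x i1> -> iN, then cttz, then extractelement.
# Hypothetical helper; assumes N <= 64 and at least one true lane.
function onehot_index(mask::NTuple{N,Bool}) where {N}
    N <= 64 || throw(ArgumentError("mask wider than 64 lanes"))
    bits = UInt64(0)
    for i in 1:N
        # lane i contributes bit (i - 1), like packing an <N x i1> into an integer
        bits |= UInt64(mask[i]) << (i - 1)
    end
    bits == 0 && throw(ArgumentError("mask has no true lane"))
    # trailing_zeros plays the role of llvm.cttz; +1 converts to a 1-based index
    return trailing_zeros(bits) + 1
end

b = (5, 6, 7, 8)
mask = (false, false, true, false)
idx = onehot_index(mask)   # 3
b[idx]                     # 7, the extractelement step
```

Unlike the sum-based approaches, this returns exactly one element and never reads the other lanes' values. If the mask has more than one true lane, it picks the lowest one.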

@Moelf
Author

Moelf commented Oct 29, 2024


2 participants