Parallelizing conditionals on a single GPU, xmap? #14779
jakubMitura14 asked this question in Q&A
Hello, I have a function f(x) that is part of an optimization problem and is differentiable.
The function is relatively simple, although not trivial, and it will be executed millions of times in each epoch.
I have a cheap way to check whether the output of f(x) will be 0 without executing the function. Since I sum all of the outputs at the end, I also know that such outputs will not affect the final gradient computation.
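To make the setup concrete, here is a minimal JAX sketch. `f` and `is_zero` are hypothetical stand-ins, not the real function from this question: `f` is constructed so that it is exactly 0 whenever `is_zero` returns True. The sketch checks the claim above that dropping the predicted-zero terms changes neither the summed output nor its gradient. Note that this masked version still evaluates `f` on every row, so by itself it saves no work.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical stand-in for the real f: it is exactly 0 whenever x[0] <= 0.
    return jnp.maximum(x[0], 0.0) * jnp.sum(jnp.sin(x) ** 2)

def is_zero(x):
    # Hypothetical cheap check that predicts f(x) == 0 without evaluating f.
    return x[0] <= 0.0

def total_full(xs):
    # Evaluate f on every row and sum the results.
    return jnp.sum(jax.vmap(f)(xs))

def total_masked(xs):
    # Same sum, but rows predicted to be zero contribute a constant 0.
    # f is still evaluated on every row here; its value is only discarded.
    skip = jax.vmap(is_zero)(xs)
    return jnp.sum(jnp.where(skip, 0.0, jax.vmap(f)(xs)))

xs = jax.random.normal(jax.random.PRNGKey(0), (256, 8))
print(total_full(xs), total_masked(xs))
print(jnp.allclose(jax.grad(total_full)(xs), jax.grad(total_masked)(xs)))
```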
Problem:
In this case, adding a conditional that first checks whether the output is destined to be 0, and executes the function only otherwise, is crucial for performance.
Because the function is simple, looping or scanning over the inputs on a GPU leads to low occupancy.
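For comparison, a sketch of the looping approach, again with the hypothetical `f` and `is_zero`: `jax.lax.map` runs the per-row function inside a scan, so the predicate given to `jax.lax.cond` is a plain scalar rather than a batched value, and only the taken branch should run for each row. The price is that rows are processed one after another, which is the low-occupancy pattern described above.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical stand-in for the real f: it is exactly 0 whenever x[0] <= 0.
    return jnp.maximum(x[0], 0.0) * jnp.sum(jnp.sin(x) ** 2)

def is_zero(x):
    # Hypothetical cheap check that predicts f(x) == 0 without evaluating f.
    return x[0] <= 0.0

def f_or_zero(x):
    # Inside lax.map the predicate is not batched, so lax.cond stays a real
    # conditional: either the zero branch or f runs, not both.
    return jax.lax.cond(is_zero(x), lambda y: jnp.zeros((), y.dtype), f, x)

def total_sequential(xs):
    # lax.map is implemented with a scan: rows are processed one at a time.
    return jnp.sum(jax.lax.map(f_or_zero, xs))

xs = jax.random.normal(jax.random.PRNGKey(0), (256, 8))
print(total_sequential(xs))
```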
The solution would be to vmap f(x) to maximize occupancy, BUT the docs state that lax.cond is converted into lax.select when vmapped (with a batched predicate), which makes my conditional useless, since both branches will be executed either way.
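The `lax.cond` to select conversion can be observed by printing the jaxpr of a vmapped `cond` (hypothetical `f` and `is_zero` as before). Because the predicate depends on the batched input, the program is expected to contain a `select_n` that picks between the outputs of both branches, each computed for every row, and no `cond` primitive. The exact jaxpr text varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical stand-in for the real f: it is exactly 0 whenever x[0] <= 0.
    return jnp.maximum(x[0], 0.0) * jnp.sum(jnp.sin(x) ** 2)

def is_zero(x):
    # Hypothetical cheap check that predicts f(x) == 0 without evaluating f.
    return x[0] <= 0.0

def f_or_zero(x):
    return jax.lax.cond(is_zero(x), lambda y: jnp.zeros((), y.dtype), f, x)

xs = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
# Under vmap the predicate is batched, so the cond is turned into a select_n
# over the results of both branches, both evaluated for all rows.
jaxpr_text = str(jax.make_jaxpr(jax.vmap(f_or_zero))(xs))
print(jaxpr_text)
```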
Is there any workaround? Maybe xmap can help?
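One possible workaround, sketched here under stated assumptions (it is a stream-compaction technique, not `xmap`, and not taken from any reply in this thread): gather the rows that actually need `f` into a fixed-size buffer, then vmap `f` over that buffer only. JAX needs static shapes, so `jnp.nonzero` is given a static `size` (the hypothetical `capacity` parameter below) and pads with `fill_value`; padded slots are masked out of the sum. This assumes you know an upper bound on the number of rows not predicted to be zero, because active rows beyond `capacity` would be silently dropped.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical stand-in for the real f: it is exactly 0 whenever x[0] <= 0.
    return jnp.maximum(x[0], 0.0) * jnp.sum(jnp.sin(x) ** 2)

def is_zero(x):
    # Hypothetical cheap check that predicts f(x) == 0 without evaluating f.
    return x[0] <= 0.0

def total_compacted(xs, capacity):
    # capacity: hypothetical static upper bound on the number of rows that need f.
    # Active rows beyond `capacity` would be silently dropped.
    active = ~jax.vmap(is_zero)(xs)
    # Indices of active rows in increasing order, padded with index 0 up to a
    # fixed length so the shape is static (required under jit).
    (idx,) = jnp.nonzero(active, size=capacity, fill_value=0)
    # The first n_active slots hold real rows; the rest are padding.
    valid = jnp.arange(capacity) < jnp.sum(active)
    # Fully batched evaluation, but only over `capacity` rows.
    vals = jax.vmap(f)(xs[idx])
    return jnp.sum(jnp.where(valid, vals, 0.0))

total_compacted_jit = jax.jit(total_compacted, static_argnums=1)

xs = jax.random.normal(jax.random.PRNGKey(0), (256, 8))
print(total_compacted_jit(xs, 256))
```

The work per call scales with `capacity` rather than with the total number of rows, so the saving depends on how tight a bound can be chosen. The gather `xs[idx]` is differentiable, so `jax.grad` sends gradients back only to the selected rows.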