-
Notifications
You must be signed in to change notification settings - Fork 912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
make internal broadcast and unbroadcast both primitives #292
base: dev-1.2
Are you sure you want to change the base?
Conversation
Yeah looks like a good idea. I'd been meaning to switch to using numpy's |
Also I feel like we should get the optional vspace checking setup as a matter of priority so that we can rigorously test that these functions are outputting the correct thing (dtype in particular). |
target_shape, target_ndim, _, target_iscomplex = target_meta | ||
x_shape = onp.shape(x) | ||
while onp.ndim(x) > target_ndim: | ||
x = onp.sum(x, axis=broadcast_idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering if we should replace the above two lines with:
x = onp.sum(x, axis=range(broadcast_idx, broadcast_idx + onp.ndim(x) - target_ndim))
or similar. Am I right that only calling sum
once might lead to better performance, basically because only one output array has to be allocated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I briefly tried something like that (though if broadcast_idx
is -1, which is the only nonzero use case I noticed in the code, then I think we want something different) and it didn't seem to make a speed difference, so I dropped it. Now is a good time to make sure it's performant, though!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doing a few timings it looks like there is a benefit for small arrays but it's not massive:
In [15]: a = np.ones((5, 5, 5))
In [16]: %timeit np.sum(a, axis=(0, 1))
5.38 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [17]: %timeit x = np.sum(a, axis=0); x = np.sum(x, axis=0)
8.62 µs ± 124 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
and for slightly bigger arrays it's the other way round (maybe I've made some mistake?):
In [18]: a = np.ones((50, 50, 50))
In [19]: %timeit np.sum(a, axis=(0, 1))
118 µs ± 930 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [20]: %timeit x = np.sum(a, axis=0); x = np.sum(x, axis=0)
81.6 µs ± 1.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow, I got similar timings. That seems weird for the bigger arrays...
if anp.iscomplexobj(x) and not target_iscomplex: | ||
x = anp.real(x) | ||
if size == 1: # TODO(mattjj): bug here w/ passing through scalars? | ||
x = onp.sum(x, axis=axis, keepdims=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could do a similar thing for this one.
Btw I think this change was inspired by your improvements to the dot VJPs. Re: vspaces, I generally agree, though I'm thinking that if vspaces are primarily for testing, we should use them extensively in our testing code but not incur the costs at runtime for every grad eval. |
Yeah cool I'm totally agreed that's the right approach re: vspaces
…On Wed, 13 Sep 2017 at 16:48, Matthew Johnson ***@***.***> wrote:
Btw I think this change was inspired by your improvements to the dot VJPs.
Re: vspaces, I generally agree, though I'm thinking that if vspaces are
primarily for testing, we should use them extensively in our testing code
but not incur the costs at runtime for every grad eval.
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#292 (comment)>, or mute
the thread
<https://github.com/notifications/unsubscribe-auth/AOjguxl1x1F_UjUW09Y9fAc3olejst2pks5sh_lqgaJpZM4PU72r>
.
|
I think this obviates the changes in HIPS#292.
I've effectively incorporated the changes in this pr into #312. |
At @dougalm's suggestion, I took a stab at making our internal
broadcast
andunbroadcast
functions into primitives. They seem to form a nice pair!This might prevent graph expansion and be a bit faster, though I haven't actually run asv on this change yet.
Any thoughts on this first pass, @j-towns?