Add `tupdate` to `Tensor` class and start simplifying `tscatter` #101
Comments
What would be the type of this new `tupdate`?

I think the simplest one that agrees with `tindex` is

`tupdate :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r -> TensorOf (p + n) r`

which checks out with the type of the transpose of

`tindex :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r`
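To make the shapes concrete, here is a minimal sketch of that pair of signatures over a flat-vector model; the `Toy` names, the `TensorToy` record and the offset arithmetic are illustrative assumptions, not horde-ad's representation.

```haskell
import qualified Data.Vector as V

-- Illustrative model only: shape plus row-major flat storage.
data TensorToy r = TensorToy { shapeT :: [Int], elemsT :: V.Vector r }
  deriving Show

-- Row-major strides of a shape, e.g. stridesT [2,3,4] == [12,4,1].
stridesT :: [Int] -> [Int]
stridesT sh = tail (scanr (*) 1 sh)

-- Toy analogue of tindex :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r:
-- an index prefix selects a contiguous sub-tensor.
tindexToy :: TensorToy r -> [Int] -> TensorToy r
tindexToy (TensorToy sh v) ix = TensorToy sh' (V.slice off (product sh') v)
  where sh' = drop (length ix) sh
        off = sum (zipWith (*) ix (stridesT sh))

-- Toy analogue of the proposed
-- tupdate :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r -> TensorOf (p + n) r:
-- overwrite the sub-tensor selected by the index prefix.
tupdateToy :: TensorToy r -> [Int] -> TensorToy r -> TensorToy r
tupdateToy (TensorToy sh v) ix (TensorToy _ u) =
  TensorToy sh (V.update_ v (V.enumFromN off (V.length u)) u)
  where off = sum (zipWith (*) ix (stridesT sh))
```

In this toy model, `tupdateToy t ix u` has exactly the "reversed" shape of `tindexToy t ix`, which is the agreement the proposed type expresses.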
Wouldn't then …
The motivating example

```haskell
let x11 = tscatter [1] (tfromList [tsum (x3 * x9)])
                       (\[i10] -> [0])
in x11 ! [0]
```

has nothing interesting to batch in a single scatter. Similarly, a transpose of indexing has just one one-hot, not a collection of them. I guess a general rule for indexing of …

Even if we end up batching many things up in a single scatter, we have to represent them somehow while they are sprinkled in many places of the generated code. I'm guessing trivial cases of scatter may not be the best way. Then we can transform the code to get these things together and then, eventually, batch them up.
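For what it's worth, here is a standalone check of that reduction on a rank-1 toy model; `scatterToy`, its additive-collision semantics, and the concrete value standing in for `tsum (x3 * x9)` are my assumptions, not code from the repository.

```haskell
-- scatterToy n src f: scatter the elements of src into a zero vector of
-- length n at the positions chosen by f, summing elements that collide.
scatterToy :: Num r => Int -> [r] -> (Int -> Int) -> [r]
scatterToy n src f =
  [ sum [ e | (i, e) <- zip [0 ..] src, f i == j ] | j <- [0 .. n - 1] ]

main :: IO ()
main = do
  let e   = 42 :: Int                   -- stands in for tsum (x3 * x9)
      x11 = scatterToy 1 [e] (const 0)  -- the trivial scatter from the example
  print (x11 !! 0 == e)                 -- True: x11 ! [0] reduces to e
```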
Ah, I see, you want an easier-to-recognise representation for trivial scatters. Because I feel that your given trivial scatter won't really be much slower than the corresponding `tupdate`.

Though I wonder if it's necessary. Maybe we can find a way to combine (vectorise, essentially) more general forms of `tscatter`. But that depends on how they appear in the code to simplify, which in turn depends on how the indexing operations appear in the original program. If they appear easily batchable there already, then the problem doesn't even arise, because the things are immediately vectorised to a single `tscatter`.
Yes, that's the main point.
Sure, but if I have a rule …, then this is faster than leaving the …
That would be great.
Not really. But once we construct …

horde-ad/simplified/HordeAd/Core/AstSimplify.hs (lines 626 to 645 in 35ea918)

that simplify …
This is killing my CI, so I will have to at least add the …
Eventually I simplified the scatters that are the transpose of indexing, and I also started simplifying some special forms of scatters. This helped with test speed, but not nearly enough. All without introducing `tupdate`.

All in all, scatter can certainly be fused with other scatters and can be simplified a bit more, but I'm no longer certain we can just reverse arrows in the gather simplification code. Reversing arrows seems tricky.
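A plain-list illustration of why the arrows don't simply reverse (the `gatherToy`/`scatterToy` names and semantics are mine, not horde-ad's): a gather reads exactly one source cell per output position, whereas a scatter may send several source positions to the same output cell and has to sum them, so the "inverse" of an output position is a set of source positions rather than a single one.

```haskell
-- gatherToy n src f: output position j reads the single cell src !! f j.
gatherToy :: Int -> [r] -> (Int -> Int) -> [r]
gatherToy n src f = [ src !! f j | j <- [0 .. n - 1] ]

-- scatterToy n src f: output position j sums every src element that f maps to j.
scatterToy :: Num r => Int -> [r] -> (Int -> Int) -> [r]
scatterToy n src f =
  [ sum [ e | (i, e) <- zip [0 ..] src, f i == j ] | j <- [0 .. n - 1] ]

main :: IO ()
main = do
  let f i = if i == 0 then 1 else 0
  print (gatherToy 3 [1, 10 :: Int] f)        -- [10,1,1]: one read per output cell
  print (scatterToy 2 [1, 10, 100 :: Int] f)  -- [110,1]: two writes collide at cell 0
```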
It should be such that `tupdate (tzero sh) ix v` is the transpose of `tindex v ix`. Also:

horde-ad/simplified/HordeAd/Core/AstSimplify.hs (line 433 in 6f88617)
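One way to read "transpose" here (my phrasing, and a rank-1, plain-list sketch rather than the repository's definitions): indexing at `ix` and updating a zero tensor at `ix` are adjoint linear maps, i.e. `<tindex v ix, u> = <v, tupdate (tzero sh) ix u>`. A tiny sanity check:

```haskell
dotToy :: Num r => [r] -> [r] -> r
dotToy xs ys = sum (zipWith (*) xs ys)

indexToy :: [r] -> Int -> r
indexToy v i = v !! i

updateToy :: [r] -> Int -> r -> [r]
updateToy v i u = take i v ++ [u] ++ drop (i + 1) v

main :: IO ()
main = do
  let v  = [3, 5, 7 :: Int]
      u  = 11
      ix = 2
  -- <indexToy v ix, u> == <v, updateToy (zero vector) ix u>
  print (indexToy v ix * u == dotToy v (updateToy (replicate 3 0) ix u))  -- True
```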
Probably `tscatter` can then be simplified using `tupdate`, similarly to how `tgather` simplifies using `tindex` right now. I'm not sure how much of the current complex `tgather` simplification code would dualize, but at least the trivial cases should, and they offer great benefits whenever they apply.

I suppose we'd also need an Ast term for the operation, vectorization rules, and forward pass and transpose rules. A similar operation is already implemented at the low level, because it's needed to implement `scatter`:

horde-ad/src/common/HordeAd/Internal/TensorOps.hs (lines 101 to 117 in 6f88617)
This needs to be generalized to non-singleton indexes but, OTOH, it can be specialized to just one update, at least initially.
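As an analogy for that low-level side (Data.Vector over flat storage; the repository's TensorOps code uses its own representation, so this is only a sketch of the shape of the operations): the generalized form applies many additive updates at once, and the specialized form applies exactly one.

```haskell
import qualified Data.Vector as V

-- Generalized form: add each payload at its flat offset; collisions accumulate.
addAtMany :: Num r => V.Vector r -> [(Int, r)] -> V.Vector r
addAtMany = V.accum (+)

-- Specialized form: exactly one update, overwriting the old element.
setAtOne :: V.Vector r -> Int -> r -> V.Vector r
setAtOne v i x = v V.// [(i, x)]

main :: IO ()
main = do
  print (addAtMany (V.replicate 4 (0 :: Int)) [(1, 10), (1, 5), (3, 7)])
  -- fromList [0,15,0,7]
  print (setAtOne (V.replicate 4 (0 :: Int)) 2 9)
  -- fromList [0,0,9,0]
```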
Overall, this ticket is a big chunk of work, but quite modular. A couple of its parts, though probably intertwined with the others, are crucial for the performance of the simplified horde-ad.