Add tupdate to Tensor class and start simplifying tscatter #101

Open
Mikolaj opened this issue Apr 16, 2023 · 8 comments
Labels
help wanted Extra attention is needed

Comments

@Mikolaj
Owner

Mikolaj commented Apr 16, 2023

It should be such that tupdate (tzero sh) ix v is the transpose of tindex v ix. Also

-- astScatter sh v (Z, ix) = update (tzero sh 0) ix v

Probably tscatter can then be simplified using tupdate, similarly to how tgather is simplified using tindex right now. I'm not sure how much of the current complex tgather simplification code would dualize, but at least the trivial cases should, and they offer great benefits whenever they apply.

I suppose we'd also need an Ast term for the operation, vectorization rules, and forward-pass and transpose rules. A similar operation is already implemented at the low level, because it's needed to implement scatter:

-- TODO: try to weave a similar magic as in tindex0R
-- TODO: for the non-singleton case see
-- https://github.com/Mikolaj/horde-ad/pull/81#discussion_r1096532164
updateNR :: forall m n a. (Numeric a, KnownNat m, KnownNat n)
         => OR.Array (m + n) a -> [(IndexInt m, OR.Array n a)]
         -> OR.Array (m + n) a
updateNR arr upd =
  let Data.Array.Internal.RankedS.A
        (Data.Array.Internal.RankedG.A shRaw
           Data.Array.Internal.T{offset, values}) = OR.normalize arr
      !_A = assert (offset == 0) ()
  in let sh = listShapeToShape shRaw
         f t (ix, u) =
           let v = OR.toVector u
               i = fromIntegral $ toLinearIdx @m @n sh ix
           in LA.vjoin [V.take i t, v, V.drop (i + V.length v) t]
     in OR.fromVector shRaw (foldl' f values upd)

This needs to be generalized to non-singleton indexes but, OTOH, it can be specialized to just one update, at least initially.
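For readers unfamiliar with the orthotope internals, here is a minimal, dependency-free sketch of what updateNR does: treat the normalized tensor as a flat row-major buffer, turn each index prefix into a linear offset, and splice the sub-tensor in. Plain lists stand in for OR.Array and Data.Vector, and linearIdx and updateFlat are hypothetical names (linearIdx plays the role of toLinearIdx). It is an illustration under those assumptions, not horde-ad code.

```haskell
import Data.List (foldl')

-- Row-major offset of an index prefix ix into a tensor of shape sh
-- (hypothetical stand-in for horde-ad's toLinearIdx).
-- For sh = [3, 4, 5] the strides are [20, 5, 1].
linearIdx :: [Int] -> [Int] -> Int
linearIdx sh ix = sum (zipWith (*) ix strides)
  where
    strides = drop 1 (scanr (*) 1 sh)

-- Apply (prefix index, flattened sub-tensor) updates to a flat row-major
-- buffer, splicing each sub-tensor in as updateNR does with V.take/V.drop.
-- Assumes each sub-tensor has the size implied by the remaining dimensions.
updateFlat :: [Int] -> [Double] -> [([Int], [Double])] -> [Double]
updateFlat sh = foldl' splice
  where
    splice t (ix, v) =
      let i = linearIdx sh ix
      in take i t ++ v ++ drop (i + length v) t
```

For example, updateFlat [2, 3] [0 .. 5] [([1], [10, 20, 30])] replaces the second row and gives [0, 1, 2, 10, 20, 30]. As with the vjoin in updateNR, each update rebuilds the whole buffer.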

Overall, this ticket is a big chunk of work, but quite modular. A couple of its parts, though probably intertwined with others, are crucial for the performance of the simplified horde-ad.

@Mikolaj Mikolaj added the help wanted Extra attention is needed label Apr 16, 2023
@Mikolaj Mikolaj changed the title Add tupdate'to Tensor class and start simplifying tscatter Add tupdate to Tensor class and start simplifying tscatter Apr 16, 2023
@tomsmeding
Collaborator

What would be the type of this new tupdate?

@Mikolaj
Owner Author

Mikolaj commented Apr 18, 2023

I think the simplest one that agrees with

-- astScatter sh v (Z, ix) = update (tzero sh 0) ix v

which is

tupdate ::  TensorOf (p + n) r -> IndexOf p r -> TensorOf n r -> TensorOf (p + n) r

which matches the type of the transpose of update (tzero sh 0) ix v, namely tindex v ix:

tindex :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r
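To make the transpose claim concrete, here is a small sketch in a list model with p = n = 1 (a rank-2 tensor is a list of rows, and an index picks a row). It checks the adjoint law <tindex u i, v> == <u, tupdate (tzero sh) i v>, which is what "tupdate of a zero tensor is the transpose of tindex" means for these linear maps. The tensor representation, the pair-shaped tzero argument, and the helpers dot1, dot2 and transposeLawHolds are assumptions made up for the illustration.

```haskell
-- List model with p = n = 1: a rank-2 tensor is a list of rows.
type Tensor2 = [[Double]]

tzero :: (Int, Int) -> Tensor2
tzero (k, m) = replicate k (replicate m 0)

-- Models tindex :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r
tindex :: Tensor2 -> Int -> [Double]
tindex u i = u !! i

-- Models tupdate :: TensorOf (p + n) r -> IndexOf p r -> TensorOf n r
--                   -> TensorOf (p + n) r
-- Replaces row i of base with v (assumes 0 <= i < number of rows).
tupdate :: Tensor2 -> Int -> [Double] -> Tensor2
tupdate base i v = take i base ++ [v] ++ drop (i + 1) base

dot1 :: [Double] -> [Double] -> Double
dot1 a b = sum (zipWith (*) a b)

dot2 :: Tensor2 -> Tensor2 -> Double
dot2 a b = sum (zipWith dot1 a b)

-- Adjoint law: <tindex u i, v> == <u, tupdate (tzero sh) i v>.
transposeLawHolds :: (Int, Int) -> Tensor2 -> Int -> [Double] -> Bool
transposeLawHolds sh u i v =
  dot1 (tindex u i) v == dot2 u (tupdate (tzero sh) i v)
```

For u = [[1, 2], [3, 4], [5, 6]], i = 1 and v = [10, 20], both sides are 110: only row 1 of the updated zero tensor is nonzero, so the inner product just picks out that row of u.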

@tomsmeding
Collaborator

Wouldn't tupdate base idx item then necessarily copy (almost) the entirety of base? This is basically the one-hot encoding for the transposition of indexing, slightly modified to compute base + onehot i instead of just onehot i. I struggle to see how this will ever be remotely efficient if you're doing more than one indexing operation on an array; surely you want to batch them up into a single scatter?
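A rough illustration of the cost concern, using immutable boxed arrays from GHC's array package (1-D only, not horde-ad tensors): applying one-element updates one at a time copies the base once per update, while a single batched (//) copies it once. When the updates come from transposing several indexing operations, contributions at repeated indices must be summed, which one accumArray pass does. The function names are hypothetical.

```haskell
import Data.Array (Array, accumArray, elems, listArray, (//))

-- Sequential updates: each (//) builds a fresh copy of the whole array,
-- so k updates on a base of size N copy O(k * N) elements.
updateOneByOne :: Array Int Double -> [(Int, Double)] -> Array Int Double
updateOneByOne = foldl (\arr upd -> arr // [upd])

-- Batched updates: a single (//) copies the base once and applies them all
-- (indices are assumed distinct in this sketch).
updateBatched :: Array Int Double -> [(Int, Double)] -> Array Int Double
updateBatched arr upds = arr // upds

-- Transpose of several indexing operations t ! i_k: the cotangents are
-- summed into a zero array of size n in one pass, duplicates included.
scatterAdd :: Int -> [(Int, Double)] -> Array Int Double
scatterAdd n = accumArray (+) 0 (0, n - 1)
```

For example, elems (updateBatched (listArray (0, 4) [1, 2, 3, 4, 5]) [(0, 10), (3, 40)]) is [10, 2, 3, 40, 5], the same result as the one-by-one version, and elems (scatterAdd 3 [(1, 1), (1, 2), (2, 5)]) is [0, 3, 5].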

@Mikolaj
Owner Author

Mikolaj commented Apr 18, 2023

The motivating example

let x11 = tscatter [1] (tfromList [tsum (x3 * x9)])
                       (\[i10] -> [0])
  in x11 ! [0]

has nothing interesting to batch in a single scatter. Similarly, a transpose of indexing has just one one-hot, not a collection of them. I guess a general rule for indexing into tupdate would let us perform the indexing from the motivating example early and avoid materializing any of the large tensors. In other cases, we can interpret/compile sequential tupdates jointly. We could also consider associative accumulators.

Even if we end up batching many things up in a single scatter, we have to represent them somehow while they are sprinkled in many places of the generated code. I'm guessing trivial cases of scatter may not be the best way. Then we can transform the code to get these things together and then, eventually, batch them up.

@tomsmeding
Collaborator

> Even if we end up batching many things up in a single scatter, we have to represent them somehow while they are sprinkled in many places of the generated code. I'm guessing trivial cases of scatter may not be the best way. Then we can transform the code to get these things together and then, eventually, batch them up.

Ah, I see, you want an easier-to-recognise representation for trivial scatters. Because I feel that your given trivial scatter won't really be much slower than the corresponding tupdate, simply because all the overhead is in the copying of the base tensor. But if your point with tupdate is not performance but recognisability and hence easier recombination later in an efficient single scatter, then yes that makes sense.

Though I wonder if it's necessary. Maybe we can find a way to combine (vectorise, essentially) more general forms of tscatter in a way that is not too hard to implement and subsumes the cases where tupdate would be useful.

But that depends on how they appear in the code to simplify, which in turn depends on how the indexing operations appear in the original program. If they appear easily batchable there already, then the problem doesn't even arise because the things are immediately vectorised to a gather anyway. Do you happen to have a motivating example here?

@Mikolaj
Owner Author

Mikolaj commented Apr 18, 2023

> Ah, I see, you want an easier-to-recognise representation for trivial scatters.

Yes, that's the main point.

> Because I feel that your given trivial scatter won't really be much slower than the corresponding tupdate, simply because all the overhead is in the copying of the base tensor.

Sure, but if I have a rule

tupdate u ix v ! ix --> v

then this is faster than leaving the scatter be, materializing it and then projecting. But, again, the rule can be just as well written for scatter, not tupdate, so it's mostly about presentation.
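A toy version of that rule, to show how it avoids materializing the updated tensor. The Ast type below is hypothetical and far simpler than horde-ad's (constant integer indices, outermost dimension only). The second branch, which reads from the base when the two indices are different constants, is an extra assumption that only works because the indices are literals. Arguments follow the proposed signature, tupdate base ix v.

```haskell
-- Hypothetical toy AST with constant integer indices on the outermost
-- dimension; horde-ad's real Ast has symbolic index terms.
data Ast
  = Var String
  | Zero [Int]          -- zero tensor of the given shape
  | Update Ast Int Ast  -- tupdate base ix v
  | Index Ast Int       -- t ! [ix]
  deriving (Eq, Show)

-- Push indexing through updates so the updated tensor is never built.
simplify :: Ast -> Ast
simplify (Index t i) =
  case simplify t of
    Update u j v
      | i == j -> v                        -- tupdate u ix v ! ix --> v
      | otherwise -> simplify (Index u i)  -- other rows come from the base
    Zero (_ : sh) -> Zero sh               -- indexing a zero tensor
    t' -> Index t' i
simplify (Update u j v) = Update (simplify u) j (simplify v)
simplify t = t
```

With s standing for tsum (x3 * x9), the motivating example x11 ! [0] corresponds to Index (Update (Zero [1]) 0 s) 0, which simplify reduces to s without ever building x11.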

> Though I wonder if it's necessary. Maybe we can find a way to combine (vectorise, essentially) more general forms of tscatter in a way that is not too hard to implement and subsumes the cases where tupdate would be useful.

That would be great.

> But that depends on how they appear in the code to simplify, which in turn depends on how the indexing operations appear in the original program. If they appear easily batchable there already, then the problem doesn't even arise because the things are immediately vectorised to a gather anyway. Do you happen to have a motivating example here?

Not really. But once we construct tscatter in whatever smart way, we'd want to fuse tscatter and simplify it in other ways. What I have are, somewhat tangentially, the corresponding rules for tgather, e.g.,

(k :$ sh', (var ::: vars, i1 :. rest1)) ->
  if | not (any (`intVarInAstInt` i1) vars0) ->
       astGatherZOrStepOnly stepOnly sh0 (astIndex v0 (i1 :. ZI))
                            (vars0, rest1)
     | case iN of
         AstIntVar varN' ->
           varN' == varN
           && not (any (varN `intVarInAstInt`) restN)
           && case ( dropShape @(m - 1) sh0
                   , dropShape @(p - 1) (shapeAst v0) ) of
             (kN :$ _, vkN :$ _) -> kN == vkN
             _ -> error "impossible pattern needlessly required"
         _ -> False
       -> astGatherZOrStepOnly stepOnly sh0 v0 (varsN, restN)
     | intVarInIndex var ix0 ->
       astGatherCase sh0 v0 (vars0, ix0)
     | any (`intVarInIndex` ix0) vars ->
       astKonst k (astGatherZOrStepOnly stepOnly sh' v0 (vars, ix0))
     | otherwise ->
       astKonstN sh0 (astIndex v0 ix0)

that simplify tgather a lot and use indexing (astIndex). I can't write such rules for tscatter, because I don't have tupdate (and using tgather instead of tindex and tscatter instead of tupdate would quickly lead to insanity).
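Here is, as I understand it, why the dual of the first tgather case above wants tupdate: when the first target index i1 of a scatter does not mention the iteration variables, every contribution lands in row i1, so the scatter equals tupdate of a zero tensor at i1 with a smaller scatter over the remaining index. The sketch below checks this in a list model with one iteration variable and a rank-2 target. The names scatter1, scatter2 and scatterRowConst are hypothetical, and this is not code from the repository.

```haskell
type Tensor2 = [[Double]]

tzero :: (Int, Int) -> Tensor2
tzero (k, m) = replicate k (replicate m 0)

-- Replace row i of base with v (assumes 0 <= i < number of rows).
tupdate :: Tensor2 -> Int -> [Double] -> Tensor2
tupdate base i v = take i base ++ [v] ++ drop (i + 1) base

-- Scatter into a zero vector of length m: element vs !! j goes to
-- position f j, and colliding contributions are summed.
scatter1 :: Int -> [Double] -> (Int -> Int) -> [Double]
scatter1 m vs f =
  [ sum [ x | (j, x) <- zip [0 ..] vs, f j == p ] | p <- [0 .. m - 1] ]

-- Same into a zero tensor of shape (k, m), with f j = (row, column).
scatter2 :: (Int, Int) -> [Double] -> (Int -> (Int, Int)) -> Tensor2
scatter2 (k, m) vs f =
  [ [ sum [ x | (j, x) <- zip [0 ..] vs, f j == (r, c) ] | c <- [0 .. m - 1] ]
  | r <- [0 .. k - 1] ]

-- Dual of the first tgather case: if the row index i1 does not depend on
-- the iteration variable j, scatter the rest and tupdate a zero tensor.
scatterRowConst :: (Int, Int) -> [Double] -> Int -> (Int -> Int) -> Tensor2
scatterRowConst (k, m) vs i1 rest =
  tupdate (tzero (k, m)) i1 (scatter1 m vs rest)
```

For vs = [1, 2, 3] and rest j = j `mod` 2, both scatter2 (3, 2) vs (\j -> (1, rest j)) and scatterRowConst (3, 2) vs 1 rest give [[0, 0], [4, 2], [0, 0]]. Written with scatter instead of tupdate, the right-hand side would itself be a nested scatter, which is the kind of tangle the comment warns about.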

@Mikolaj
Owner Author

Mikolaj commented Apr 25, 2023

This is killing my CI, so I will have to at least add the update term, so that it takes less memory than the special case of scatter. Then I'd either start simplifying indexing of update or fuse many updates into one. That's still very ad hoc, but much easier than dualising the simplification and fusion of gather in general, if that's possible at all.

@Mikolaj
Owner Author

Mikolaj commented Apr 25, 2023

Eventually I simplified the scatters that are the transpose of indexing, and I also started simplifying some special forms of scatters. This helped with test speed, but not nearly enough. All this without introducing tupdate yet, which would probably be just tupdate (c, ix) = AstScatter sh c (Z, ix) (which seems to be precisely dual to indexing, both when transposing and when comparing scatter and gather simplification rules) or tupdate t (c, ix) = t + AstScatter sh c (Z, ix) (which may or may not fuse better in some cases). Other variants seem to have problems when vectorized.
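The two candidate definitions, sketched in a list model (rank-2 tensors as lists of rows, a rank-1 index, and a hypothetical scatterZ standing in for AstScatter sh c (Z, ix)); all names here are made up for the illustration. The sketch shows that the variants agree on a zero base and that the additive one composes by plain tensor addition, which is one plausible reason it might fuse better. Note that the additive variant adds c to row ix rather than overwriting it, so on a nonzero base it differs from a replacing update.

```haskell
type Tensor2 = [[Double]]

tzero :: (Int, Int) -> Tensor2
tzero (k, m) = replicate k (replicate m 0)

-- Elementwise sum of two tensors of the same shape.
addT :: Tensor2 -> Tensor2 -> Tensor2
addT = zipWith (zipWith (+))

-- Models AstScatter sh c (Z, ix): no iteration variables, so the single
-- row c lands at row ix of a zero tensor of shape sh.
scatterZ :: (Int, Int) -> [Double] -> Int -> Tensor2
scatterZ sh c ix =
  [ if r == ix then c else row | (r, row) <- zip [0 ..] (tzero sh) ]

-- Variant 1: tupdate (c, ix) = AstScatter sh c (Z, ix); no base tensor.
tupdate1 :: (Int, Int) -> ([Double], Int) -> Tensor2
tupdate1 sh (c, ix) = scatterZ sh c ix

-- Variant 2: tupdate t (c, ix) = t + AstScatter sh c (Z, ix).
tupdate2 :: (Int, Int) -> Tensor2 -> ([Double], Int) -> Tensor2
tupdate2 sh t (c, ix) = addT t (scatterZ sh c ix)
```

For instance, tupdate2 sh (tupdate2 sh t u1) u2 equals addT t (addT (tupdate1 sh u1) (tupdate1 sh u2)), so a chain of additive updates can be regrouped into one sum of scatters.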

All in all, scatter can certainly be fused with other scatters and can be simplified a bit more, but I'm no longer certain we can just reverse arrows in the gather simplification code. Reversing arrows seems tricky.
