I think my intuition about non-contiguous memory was correct, because I found a pretty hacky but surprisingly fast way of dealing with this:

```python
out = [0, 0, 0, 0, 0]
for i, Mi in enumerate(M):
    for j, Mij in enumerate(Mi):
        # Mij has shape (512, 512)
        out[i] += Mij * L[j]
return [x.T for x in out]
```

Anyone looking at that code will probably wonder wth I am doing there, but it works and is fast. I don't know if there is already an elegant way of circumventing the issues I had, or how viable the idea would be, but something like a lambda-allocator would be very nice: basically, a function that takes a lambda expression as input and populates a contiguous-memory tensor whose entry at index `(i, j)` is `f(i, j)`:

```python
f = lambda i, j: (i == j) * 1.0
I = lambda_allocator(f, bounds=[(0, 10), (0, 10)])   # equivalent to jnp.eye(10)

f = lambda i, j: jnp.zeros((100,))
M = lambda_allocator(f, bounds=[(0, 13), (0, 10)])   # equivalent to jnp.zeros((13, 10, 100))
```
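For what it's worth, something in this spirit can be sketched today with nested `jax.vmap`s over index ranges; the result is a single contiguous array populated from an index lambda. The `lambda_allocator` name and signature here just mirror the hypothetical API above, it is not an existing JAX function:

```python
import jax
import jax.numpy as jnp

def lambda_allocator(f, bounds):
    # Sketch: evaluate f over the index grid given by `bounds`
    # (a list of (start, stop) pairs) and pack the results into
    # one contiguous array via nested vmaps.
    ii = jnp.arange(*bounds[0])
    jj = jnp.arange(*bounds[1])
    return jax.vmap(lambda i: jax.vmap(lambda j: f(i, j))(jj))(ii)

f = lambda i, j: (i == j) * 1.0
I = lambda_allocator(f, bounds=[(0, 10), (0, 10)])  # matches jnp.eye(10)
```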
---
I am experiencing undesired behavior which leads to a significant slowdown. I am not sure if I am on the right track in solving this issue, but I think it has something to do with the memory management of my function.

I have a function that gets called hundreds of times. It is very important for my application that this function is evaluated quickly. Essentially, what the function computes is a dot product of a matrix that I need to generate and another matrix that is provided, thus essentially `M(r, d) @ L`.
`M` and `L` have a fairly large leading dimension of `N` = 250,000 to 1,000,000 (but `N` is fixed and does not vary within a run). `r` and `d` have dimensions `(3, N)`, which in turn means `M` has shape `(N, 35, 5)` and `L` has dimensions `(N, 4, 35)`. The problem I have run into is that there is no "easy" way to express `M` in terms of `r` and `d` (i.e. as a sequence of matmuls or something), so I am currently generating `M` from a list of submatrices (a stand-in sketch follows below). This is very fast! About 1.2 ms, which is the kind of speed I need, but I now have a list of arrays that I unfortunately cannot shove into `dot`.
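A minimal sketch of this kind of construction, with a made-up placeholder entry formula (the real formulas are application-specific); only the nested-list structure matters, 35 × 5 entries that are each an `(N,)`-shaped array:

```python
import jax.numpy as jnp

def build_M(r, d):
    # Hypothetical stand-in entries: each M[a][b] is an (N,)-shaped
    # array computed elementwise from r (3, N) and d (3, N).
    return [[jnp.cos(a * r[0] + b * d[1]) * r[2]  # shape (N,)
             for b in range(5)]
            for a in range(35)]

# jnp.array(build_M(r, d)) has shape (35, 5, N);
# .transpose((2, 0, 1)) then yields the desired (N, 35, 5).
```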
As soon as I turn this sequence of arrays into a `jnp.array` (`jnp.asarray` does not help) and transpose it so I can use `M @ L`, i.e. if I do `return jnp.array(M).transpose((2, 0, 1))`, the function becomes 40x slower.

I've tried various things, like using `einsum` to avoid the `transpose`, but all of these strategies have failed. I believe my problem is that when I generate the tuple of tuples `M`, the respective blocks of memory are fragmented, and as soon as I call `transpose`, XLA starts copying memory around to make the array `M` contiguous. Am I interpreting this behavior correctly?
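For reference, the `einsum` variant would look something like this (index convention assumed from the shapes above); it avoids the explicit `transpose`, but it still has to materialize the stacked array first, which is presumably where the copy happens:

```python
import jax.numpy as jnp

def apply_M(M_blocks, L):
    M = jnp.array(M_blocks)  # (35, 5, N), stacked from the nested list
    # Contract the 35-axis of M with the last axis of L (N, 4, 35),
    # batching over N; the result has shape (N, 4, 5) and no transpose
    # of M is needed, but jnp.array(...) already copies the blocks.
    return jnp.einsum('abn,nca->ncb', M, L)
```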
Is there a way of directly generating the submatrices at the "correct" location in memory, such that the result is already contiguous and does not require copying?
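One pattern that at least expresses "write each block directly into its slot" is to preallocate the target array and fill slices with `.at[...].set(...)`, which JAX can turn into in-place updates under `jit`; whether XLA then actually avoids the extra copies is exactly the open question here. The entry formula is again a hypothetical placeholder:

```python
import jax
import jax.numpy as jnp

@jax.jit
def build_M_contiguous(r, d):
    N = r.shape[1]
    M = jnp.zeros((N, 35, 5))
    for a in range(35):
        for b in range(5):
            # Hypothetical stand-in for the real entry formula.
            M = M.at[:, a, b].set(jnp.cos(a * r[0] + b * d[1]) * r[2])
    return M
```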