The use case I'm describing is a function implemented in JAX, lowered/exported to HLO, and executed with the XLA C++ runtime. I couldn't get shape polymorphism to load correctly with XLA, and it wouldn't solve my use case anyway, since I need ahead-of-time compilation. Instead, I used JAX to export functions for power-of-two sizes of the input dimensions, and then select the smallest sufficient one at runtime. This works well, but for some functions I need to pad the inputs. Here are some things I've considered:
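For concreteness, here is a minimal sketch of the bucketing scheme described above: round each dimension up to the next power of two and pad with zeros, then (in the real setup) dispatch to the AOT-compiled function for that padded shape. The helper names are my own; the padding here happens on the host in NumPy, before calling into the XLA runtime.

```python
import numpy as np

def next_pow2(n: int) -> int:
    # Smallest power of two >= n (assumes n >= 1).
    return 1 << (n - 1).bit_length()

def pad_to_pow2(x: np.ndarray) -> tuple[np.ndarray, tuple[int, ...]]:
    # Zero-pad each axis up to the next power of two; return the padded
    # array plus the original shape, which the compiled function needs
    # as a runtime argument for masking.
    padded_shape = tuple(next_pow2(d) for d in x.shape)
    pad_width = [(0, p - d) for d, p in zip(x.shape, padded_shape)]
    return np.pad(x, pad_width), x.shape
```

A `(3, 5)` input would be padded to `(4, 8)` and dispatched to the export compiled for that shape.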
This, however, doesn't work, since slicing needs to be static. I can't find a suitable dynamic slice method either, since they all require the slice shape to be static. We can instead try to construct the mask a different way, by building it up one dimension at a time:
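Since the original snippet didn't survive, here is my reconstruction of the per-dimension approach for the 2-D case: build a 0/1 vector per axis by comparing indices against the runtime size, then combine the vectors with a matrix multiplication (an outer product) to form the full mask.

```python
import jax.numpy as jnp

def make_mask(padded_shape, actual_shape):
    # Per-axis 0/1 vectors: 1 where the index is within the true extent.
    rows = (jnp.arange(padded_shape[0]) < actual_shape[0]).astype(jnp.float32)
    cols = (jnp.arange(padded_shape[1]) < actual_shape[1]).astype(jnp.float32)
    # Outer product via matmul: mask[i, j] = rows[i] * cols[j].
    return rows[:, None] @ cols[None, :]
```

The padded shape is static at trace time, while `actual_shape` can be ordinary runtime scalar arguments, so this traces and exports fine.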
The matrix multiplication is slow, but works. We could also try successive…
This problem doesn't seem like it should be difficult or expensive: after all, in a CUDA kernel, a simple comparison of the thread index against the runtime shape would perform this operation. I'm not sure exactly what the equivalent in CPU code would be. How can I implement this padding in pure JAX?
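The "compare each index against the runtime size" idea can be written directly in JAX with `jax.lax.broadcasted_iota`, which materializes the index along one axis of a statically shaped array; the dynamic sizes are passed in as ordinary arguments. This is a sketch, with the function name my own:

```python
import jax
import jax.numpy as jnp

def mask_padding(x, sizes):
    # x has a static (padded) shape; sizes are the true extents at runtime.
    mask = jnp.ones(x.shape, dtype=bool)
    for axis, n in enumerate(sizes):
        # Index of each element along this axis, same shape as x.
        idx = jax.lax.broadcasted_iota(jnp.int32, x.shape, axis)
        mask &= idx < n
    # Zero out everything beyond the true extents.
    return jnp.where(mask, x, 0)
```

This avoids the matmul entirely; it is essentially the same comparison a CUDA kernel would do per thread, expressed as elementwise ops.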
I don't know of any way to do this beyond the methods you've already suggested. One tweak, though: rather than element-wise multiplication for the mask, you might use an element-wise logical AND:

```python
import jax.numpy as jnp

mask = 1
indices = jnp.meshgrid(*(jnp.arange(s) for s in shape), sparse=True, indexing='ij')
for i, dyn_len in zip(indices, actual_shape):
    mask &= (i < dyn_len)
```
Well, this maps pretty directly to the operations the GPU will have to do to construct the output array, so I'm not sure what alternatives could exist.