Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[XLA:Python] Fix DLPack behavior with unit dimensions.
As discovered in jax-ml/jax#24680, when a PyTorch tensor has a dimension with size `1`, it seems to report the DLPack stride for that dimension as `1`. This behavior wasn't supported by the logic in XLA, resulting in an incorrect layout on the imported array. PiperOrigin-RevId: 696341186
- Loading branch information