Skip to content

Commit

Permalink
[XLA:Python] Fix DLPack behavior with unit dimensions.
Browse files Browse the repository at this point in the history
As discovered in jax-ml/jax#24680, when a PyTorch tensor has a dimension with size `1`, it seems to report the DLPack stride for that dimension as `1`. This behavior wasn't supported by the logic in XLA, resulting in an incorrect layout on the imported array.

PiperOrigin-RevId: 696341186
  • Loading branch information
dfm authored and Google-ML-Automation committed Nov 14, 2024
1 parent dde3c51 commit dc3692c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions xla/python/dlpack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,17 @@ absl::StatusOr<std::vector<int64_t>> StridesToLayout(
std::vector<int64_t> minor_to_major(dims.size());
std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
absl::c_sort(minor_to_major, [&](int a, int b) {
// If two dimensions have the same stride, prefer the major-to-minor
// interpretation of the ordering, since that's what JAX wants.
if (dims[a] <= 1 || dims[b] <= 1) {
return b < a;
}
if (strides[a] < strides[b]) {
return true;
}
if (strides[a] > strides[b]) {
return false;
}
// If two dimensions have the same stride, prefer the major-to-minor
// interpretation of the ordering, since that's what JAX wants.
return b < a;
});
int64_t stride = 1;
Expand Down

0 comments on commit dc3692c

Please sign in to comment.