Skip to content

Commit

Permalink
[Mosaic TPU] Fix mosaic alignment check in concatenate rule.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 670837792
  • Loading branch information
bythew3i authored and jax authors committed Sep 4, 2024
1 parent ebc6c18 commit c1d3c2d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2466,11 +2466,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
"Not implemented: Only native tiling with offset (0, 0) is supported "
"when concatenation along tiling dims.");
}
// Check if shapes of src and res are aligned to native tiling.
// Check if the concat dim size of src and res is aligned to native tiling.
auto check_aligned = [&](const VectorType &vty) {
auto i = dimension - res_ty.getRank();
return vty.getRank() >= 2 &&
*(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) == 0 &&
*(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) == 0;
*(vty.getShape().end() + i) % *(layout.tiling().end() + i) == 0;
};
bool is_aligned = check_aligned(res_ty);
int op_idx = 0;
Expand Down

0 comments on commit c1d3c2d

Please sign in to comment.