[Mosaic TPU] Fold sublane offset to indices when storing to untiled ref.
This optimization avoids unnecessary retiling when storing to an untiled ref, but adds at most one extra store op for the sublane offset (since the sublane offset is limited to < VregSlice[0]).

PiperOrigin-RevId: 698896373
bythew3i authored and Google-ML-Automation committed Nov 21, 2024
1 parent f3e7e68 commit f899d51
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -1640,14 +1640,14 @@ class VectorLayoutInferer {
         // Since it is untiled, we can store to any arbitrary address which
         // means the sublane offset can be any value and we can fold it to
         // 2nd minor index.
-        // TODO(jevinjiang): We can fold the sublane offset into the 2nd minor
-        // index. But we need to handle negative index in lower-to-llo. For
-        // now, we just force the sublane offset to be 0.
+        auto prev_store_layout = getLayout(op.getValueToStore());
+        TPU_CHECK_OP(prev_store_layout.has_value(), "missing vector layout");
+        offsets[0] = prev_store_layout->offsets()[0].value_or(0);
         if (offsets[1].value_or(0) >= tiling[1]) {
           offsets[1] = 0;
         }
-        store_layout = VectorLayout(bitwidth, {0, offsets[1]},
-                                    nativeTiling(bitwidth), ImplicitDim::kNone);
+        store_layout = VectorLayout(bitwidth, offsets, nativeTiling(bitwidth),
+                                    ImplicitDim::kNone);
       } else {
         store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]},
                                     ImplicitDim::kNone);
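
The commit message's claim of "at most one extra store op" can be checked with a little arithmetic. The sketch below is illustrative only and is not Mosaic code: `StoresAlongSublanes` and `FoldSublaneOffset` are hypothetical helpers, and `sublanes_per_vreg` stands in for `VregSlice[0]`. It assumes that folding means subtracting the sublane offset from the second-minor store index, which is consistent with the removed TODO's remark about negative indices but is not confirmed by the diff itself.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of Mosaic): number of vreg-high row groups,
// and hence store ops along the sublane dimension, needed to cover `rows`
// rows of data whose first row sits at `sublane_offset` in the first vreg.
int64_t StoresAlongSublanes(int64_t rows, int64_t sublane_offset,
                            int64_t sublanes_per_vreg) {
  assert(rows > 0 && sublanes_per_vreg > 0);
  // The layout offset is always smaller than the vreg height.
  assert(sublane_offset >= 0 && sublane_offset < sublanes_per_vreg);
  // Ceiling division of the offset-padded extent by the vreg height.
  return (sublane_offset + rows + sublanes_per_vreg - 1) / sublanes_per_vreg;
}

// Hypothetical helper: fold the sublane offset into the second-minor index.
// Assumption: sublane 0 of the first vreg maps to row (index - offset), so
// the result can be negative when the offset exceeds the index.
int64_t FoldSublaneOffset(int64_t second_minor_index, int64_t sublane_offset) {
  return second_minor_index - sublane_offset;
}
```

Under these assumptions, 8 rows with a vreg height of 8 need one store at offset 0 and two stores at offset 3, so keeping the offset costs one extra store instead of a retile. Because the offset is always smaller than the vreg height, the padded extent grows by less than one vreg, so the count never rises by more than one.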