Skip to content

Commit

Permalink
[Mosaic TPU] Support bitcast without forcing retiling.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676998768
  • Loading branch information
bythew3i authored and Google-ML-Automation committed Sep 20, 2024
1 parent d63afd8 commit 0a7503a
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -934,16 +934,16 @@ class VectorLayoutInferer {
auto out_ty = cast<VectorType>(op.getOutput().getType());
auto in_bitwidth = in_ty.getElementTypeBitWidth();
auto out_bitwidth = out_ty.getElementTypeBitWidth();
auto src_layout = getLayout(op.getInput());
LayoutOffsets src_offsets = src_layout->offsets();
auto implicit_dim = src_layout->implicit_dim();
if (src_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
auto in_layout = getLayout(op.getInput());
LayoutOffsets in_offsets = in_layout->offsets();
auto implicit_dim = in_layout->implicit_dim();
if (in_offsets[0].value_or(0) * in_bitwidth % out_bitwidth != 0) {
// Force the offset to zero if the input offset on the second minor dimension
// is not a multiple of the ratio of the output bitwidth to the input bitwidth.
src_offsets[0] = 0;
} else if (!src_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
in_offsets[0] = 0;
} else if (!in_offsets[0].has_value() && in_bitwidth > out_bitwidth) {
// We can't preserve replicated offset for decreasing bitwidth.
src_offsets[0] = 0;
in_offsets[0] = 0;
}
// Force implicit dim to None if the bitwidth changes, because we expect the
// 2nd minor dim size ratio between input and output to match the bitwidth
// ratio.
Expand All @@ -955,20 +955,23 @@ class VectorLayoutInferer {
}
implicit_dim = ImplicitDim::kNone;
}
// TODO(b/348485035): Instead of forcing to native tiling, bitcast should
// keep the input tiling and infer bitcastable tiling for output. For
// example, it is valid to bitcast vector<8x128xi32> with tile (1, 128) to
// vector<8x128xbf16> with tile (2, 128).
const auto &in_tiling = in_layout->tiling();
if (in_tiling[0] * in_bitwidth % out_bitwidth != 0) {
return op.emitOpError(
"Expected input sublane tiling can be bitcasted to output sublane "
"tiling.");
}
auto out_tiling = in_tiling;
out_tiling[0] = out_tiling[0] * in_bitwidth / out_bitwidth;

auto out_offsets = in_offsets;
if (in_offsets[0].has_value()) {
out_offsets[0] = in_offsets[0].value() * in_bitwidth / out_bitwidth;
}

setLayout(
op,
VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth),
implicit_dim),
VectorLayout(out_bitwidth,
{src_offsets[0].has_value()
? src_offsets[0].value() * in_bitwidth / out_bitwidth
: src_offsets[0],
src_offsets[1]},
nativeTiling(out_bitwidth), implicit_dim));
op, VectorLayout(in_bitwidth, in_offsets, in_tiling, implicit_dim),
VectorLayout(out_bitwidth, out_offsets, out_tiling, implicit_dim));
return success();
}

Expand Down

0 comments on commit 0a7503a

Please sign in to comment.