Skip to content

Commit

Permalink
[Mosaic TPU] Break out implicit dim changes from relayout
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 658752228
  • Loading branch information
apaszke authored and jax authors committed Aug 2, 2024
1 parent efba5f6 commit 99625ff
Showing 1 changed file with 49 additions and 37 deletions.
86 changes: 49 additions & 37 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5372,6 +5372,50 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
<< dst_tiling[1] << ")";
}

FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
OpBuilder &builder, const std::array<int64_t, 2> target_shape,
const Location loc, VectorType vty, const VectorLayout src,
xla::Array<Value> vregs, const VectorLayout::ImplicitDim dst_implicit_dim,
const LayoutOffsets dst_offset_hints) {
if (src.implicit_dim() == dst_implicit_dim) {
return std::make_pair(src, std::move(vregs));
}
// Remove second minor implicit dim, for values that have (8, 128) tiling.
// TODO(apaszke): We should allow replicated dst_offset_hints[0].
if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
dst_implicit_dim == VectorLayout::ImplicitDim::kNone &&
src.bitwidth() == 32 && src.tiling() == std::array<int64_t, 2>{8, 128} &&
dst_offset_hints[0]) {
int64_t dst_sublane_offset = *dst_offset_hints[0];
VectorLayout dst(src.bitwidth(), {dst_sublane_offset, src.offsets()[1]},
src.tiling(), dst_implicit_dim);
xla::Array<Value> new_vregs(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
new_vregs.Each([&](const absl::Span<const int64_t> idx,
Value *tile) {
const int64_t dst_2nd_minor_idx = idx.size() - 2;
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
src.insertImplicit<int64_t>(src_idx, 0);
const int dst_sl_start =
idx[dst_2nd_minor_idx] == 0 ? dst_sublane_offset : 0;
src_idx[dst_2nd_minor_idx] = target_shape[0] * idx[dst_2nd_minor_idx] +
dst_sl_start - dst_sublane_offset;
for (int dst_sl_idx = dst_sl_start;
dst_sl_idx < target_shape[0] &&
src_idx[dst_2nd_minor_idx] < vregs.dim(dst_2nd_minor_idx);
++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) {
*tile = copy_one_sublane(builder, vregs(src_idx),
src.offsets()[0].value_or(dst_sl_idx), *tile,
dst_sl_idx, target_shape);
}
});
return std::make_pair(dst, new_vregs);
}
return emitError(loc,
"Not implemented: Unsupported implicit dim change: from ")
<< src << " to " << dst_implicit_dim;
}

// TODO(apaszke): Test this function properly
FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
OpBuilder &builder,
Expand Down Expand Up @@ -5413,12 +5457,6 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
}
}
}
auto not_implemented = [&]() -> LogicalResult {
return emitError(v.getLoc(),
"Not implemented: Unsupported layout change for ")
<< vty << ": " << src << " -> " << dst;
};

// Save the original value of dst to use it at the end. It determines the
// out_layout of the result of assemble.
// TODO(apaszke): Retiling should not care about the implicit dim. Move
Expand Down Expand Up @@ -5468,37 +5506,11 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
dst.offsets()[0] == std::nullopt &&
src.offsets()[0] != std::nullopt));

// Remove second minor implicit dim, for values that have (8, 128) tiling.
if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor &&
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
src.bitwidth() == 32 && dst.offsets()[0] &&
src.offsets()[1] == dst.offsets()[1] && src.tiling() == dst.tiling() &&
src.tiling() == std::array<int64_t, 2>{8, 128}) {
xla::Array<Value> src_tiles_retiled(
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
src_tiles_retiled.Each([&](const absl::Span<const int64_t> idx,
Value *tile) {
const int64_t dst_2nd_minor_idx = idx.size() - 2;
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
src.insertImplicit<int64_t>(src_idx, 0);
const int dst_sl_start =
idx[dst_2nd_minor_idx] == 0 ? *dst.offsets()[0] : 0;
src_idx[dst_2nd_minor_idx] = target_shape[0] * idx[dst_2nd_minor_idx] +
dst_sl_start - *dst.offsets()[0];
for (int dst_sl_idx = dst_sl_start;
dst_sl_idx < target_shape[0] &&
src_idx[dst_2nd_minor_idx] < src_tiles.dim(dst_2nd_minor_idx);
++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) {
*tile = copy_one_sublane(builder, src_tiles(src_idx),
src.offsets()[0].value_or(dst_sl_idx), *tile,
dst_sl_idx, target_shape);
}
});
src = dst;
src_tiles = std::move(src_tiles_retiled);
} else if (src.implicit_dim() != dst.implicit_dim()) {
return not_implemented();
}
FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
changeImplicitDim(builder, ctx.target_shape, v.getLoc(), vty, src,
std::move(src_tiles), dst.implicit_dim(),
dst.offsets()));

FAILUREOR_ASSIGN_OR_RETURN(
std::tie(src, src_tiles),
Expand Down

0 comments on commit 99625ff

Please sign in to comment.