Skip to content

Commit

Permalink
[Mosaic] Fix mask creation for packed sublanes
Browse files Browse the repository at this point in the history
Unaligned concat used to be f32 only, but implicitly protected via unimplemented support for multi-row-shift in sub32 types. When this was added, we started invoking unaligned concat flow w/ sub32 types, but the masking code that assumed full rows (unpacked types) was no longer sufficient - we need better granularity for these cases. This only affects sublanes, as that is where we pack, we don't have partial lanes.

This CL, as a small benefit, also adds better error messages to the ops involved in lower_to_llo.cc.

PiperOrigin-RevId: 695538451
  • Loading branch information
Google-ML-Automation committed Nov 12, 2024
1 parent 3a5ac48 commit ee20e88
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2664,7 +2664,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,

const auto bitwidth = res_ty.getElementTypeBitWidth();
const int packing = res_layout->packing();

SmallVector<int64_t> out_idx;
vreg.Each([&](absl::Span<const int64_t> idx, Value *v) {
out_idx.assign(idx.begin(), idx.end());
Expand All @@ -2674,17 +2673,29 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
const VectorType vmask_ty = getNativeVregOrVmaskType(
builder.getI1Type(), bitwidth, ctx.target_shape);
if (tiling_dim.value() == 0) { // sublane
mask = builder.create<tpu::CreateMaskOp>(
op.getLoc(), vmask_ty,
ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(operand_offset * packing),
boundIdxConst(layout->tiling()[1])});
if (operand_offset % packing != 0) {
// Packed case, degenerate where we have a half or quarter
// sublane.
// TODO(mvoz): We can probably always use the
// CreateSubelementMaskOp if (1) optimize it on TPUv4 and (2) Add
// support for unpacked types in some of the invariants in
// lower_to_llo.
mask = builder.create<tpu::CreateSubelementMaskOp>(
op.getLoc(), vmask_ty, 0, operand_offset, packing);
} else {
auto sublane_offset = operand_offset / packing;
mask = builder.create<tpu::CreateMaskOp>(
op.getLoc(), vmask_ty,
ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(sublane_offset),
boundIdxConst(layout->tiling()[1])});
}
} else { // lane
mask = builder.create<tpu::CreateMaskOp>(
op.getLoc(), vmask_ty,
ArrayRef<Value>{boundIdxConst(0), boundIdxConst(0)},
ArrayRef<Value>{boundIdxConst(layout->tiling()[0]),
boundIdxConst(operand_offset * packing)});
boundIdxConst(operand_offset)});
}
// Blend the current value with the existing value in the output.
*v = builder.create<arith::SelectOp>(op.getLoc(), mask,
Expand Down

0 comments on commit ee20e88

Please sign in to comment.