[mlir][gpu] Add 'cluster_size' attribute to gpu.subgroup_reduce (#104851)

This enables performing several reductions in parallel, each smaller
than the size of the subgroup.

One potential application is flash attention, with a subgroup-wide matrix
multiplication and reduction combined in one kernel. The multiplication
requires a 2D matrix to be distributed over the lanes of the subgroup,
which then constrains the shape the subsequent reduction can have if we
want to keep data in registers.
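
For instance, with a 2D tile distributed so that each row fragment occupies
four contiguous lanes, all per-row reductions can run at once. A hypothetical
sketch (value names are illustrative, not from this patch):

```mlir
// Eight 4-lane row reductions proceed in parallel on a 32-lane subgroup,
// instead of one reduction spanning the whole subgroup.
%row_sum = gpu.subgroup_reduce add %frag cluster_size(4) : (f32) -> f32
```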
andfau-amd authored Aug 20, 2024
1 parent 93eda08 commit 7aa22f0
Showing 8 changed files with 209 additions and 21 deletions.
40 changes: 32 additions & 8 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1197,22 +1197,29 @@ def AnyIntegerOrFloatOr1DVector :
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]> {
let summary = "Reduce values among subgroup.";
let description = [{
- The `subgroup_reduce` op reduces the value of every lane (work item) across
- a subgroup. The result is equal for all lanes.
+ The `subgroup_reduce` op reduces the values of lanes (work items) across a
+ subgroup.

+ The subgroup is divided into clusters of `cluster_size` contiguous lanes
+ each, and a reduction is done for every lane of each cluster (in parallel).
+ The result is equal for all lanes in a cluster. When `cluster_size` is
+ omitted, there is a single cluster covering the entire subgroup.

When the reduced value is of a vector type, each vector element is reduced
independently. Only 1-d vector types are allowed.

Example:

```mlir
- %1 = gpu.subgroup_reduce add %a : (f32) -> (f32)
- %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> (vector<4xf16>)
+ %1 = gpu.subgroup_reduce add %a : (f32) -> f32
+ %2 = gpu.subgroup_reduce add %b : (vector<4xf16>) -> vector<4xf16>
+ %3 = gpu.subgroup_reduce add %c cluster_size(4) : (f32) -> f32
```

If `uniform` flag is set either none or all lanes of a subgroup need to execute
- this op in convergence. The reduction operation must be one
- of:
+ this op in convergence.
+
+ The reduction operation must be one of:
* Integer types: `add`, `mul`, `minui`, `minsi`, `maxui`, `maxsi`, `and`,
`or`, `xor`
* Floating point types: `add`, `mul`, `minnumf`, `maxnumf`, `minimumf`,
@@ -1222,12 +1229,29 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce", [SameOperandsAndResultType]
let arguments = (ins
AnyIntegerOrFloatOr1DVector:$value,
GPU_AllReduceOperationAttr:$op,
- UnitAttr:$uniform
+ UnitAttr:$uniform,
+ OptionalAttr<I32Attr>:$cluster_size
);
let results = (outs AnyIntegerOrFloatOr1DVector:$result);

+ let builders = [
+ OpBuilder<(ins "Value":$value,
+ "::mlir::gpu::AllReduceOperation":$op,
+ "bool":$uniform), [{
+ build($_builder, $_state, value, op, uniform, /*cluster_size=*/ nullptr);
+ }]>,
+ OpBuilder<(ins "Value":$value,
+ "::mlir::gpu::AllReduceOperation":$op,
+ "bool":$uniform,
+ "std::optional<uint32_t>":$cluster_size), [{
+ build($_builder, $_state, value, op, uniform, cluster_size ? $_builder.getI32IntegerAttr(*cluster_size) : nullptr);
+ }]>
+ ];

let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
- (`uniform` $uniform^)? attr-dict
+ (`uniform` $uniform^)?
+ (`cluster_size` `(` $cluster_size^ `)`)?
+ attr-dict
`:` functional-type(operands, results) }];

let hasFolder = 1;
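To make the clustering semantics concrete, here is an illustrative sketch (not
part of the patch; it assumes a 32-wide subgroup and hypothetical value names):

```mlir
// With cluster_size(4), lanes are grouped as {0-3}, {4-7}, ..., {28-31}.
// Every lane in a cluster receives that cluster's result: lanes 0-3 all
// get x0 + x1 + x2 + x3, lanes 4-7 all get x4 + x5 + x6 + x7, and so on.
%sum = gpu.subgroup_reduce add %x cluster_size(4) : (f32) -> f32
```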
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -102,6 +102,10 @@ struct GPUSubgroupReduceOpLowering

matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getClusterSize())
+ return rewriter.notifyMatchFailure(
+ op, "lowering for clustered reduce not implemented");

if (!op.getUniform())
return rewriter.notifyMatchFailure(
op, "cannot be lowered to redux as the op must be run "
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -579,6 +579,10 @@ class GPUSubgroupReduceConversion final
LogicalResult
matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getClusterSize())
+ return rewriter.notifyMatchFailure(
+ op, "lowering for clustered reduce not implemented");

if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");

12 changes: 12 additions & 0 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -620,10 +620,22 @@ LogicalResult gpu::SubgroupReduceOp::verify() {
<< "` reduction operation is not compatible with type "
<< getType();
}

+ if (auto clusterSize = getClusterSize()) {
+ uint32_t size = *clusterSize;
+ if (!llvm::isPowerOf2_32(size)) {
+ return emitOpError() << "cluster size " << size
+ << " is not a power of two";
+ }
+ }

return success();
}

OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) {
+ if (getClusterSize() == 1)
+ return getValue();

if (!getUniform() && canMakeGroupOpUniform(*this)) {
setUniform(true);
return getResult();
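The new fold covers the degenerate `cluster_size(1)` case: each cluster is a
single lane, so the "reduction" of a cluster is the lane's own value. A
hypothetical before/after sketch:

```mlir
// Before folding:
%r = gpu.subgroup_reduce add %x cluster_size(1) : (i32) -> i32
// After folding, all uses of %r are simply replaced by %x and the op goes away.
```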
49 changes: 36 additions & 13 deletions mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -50,6 +50,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();

auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy || vecTy.getNumElements() < 2)
return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -95,7 +97,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
}

Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, op.getOp(), op.getUniform());
+ loc, extracted, op.getOp(), op.getUniform(), clusterSize);
if (numElems == 1) {
res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
continue;
@@ -127,6 +129,8 @@ struct ScalarizeSingleElementReduce

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();

auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy || vecTy.getNumElements() != 1)
return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -136,7 +140,7 @@
Location loc = op.getLoc();
Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, op.getOp(), op.getUniform());
+ loc, extracted, op.getOp(), op.getUniform(), clusterSize);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
return success();
}
@@ -147,17 +151,20 @@ struct ScalarizeSingleElementReduce final
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
-/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
+/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
static Value createSubgroupShuffleReduction(
OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
- unsigned subgroupSize, function_ref<Value(Value)> packFn,
- function_ref<Value(Value)> unpackFn) {
+ unsigned clusterSize, unsigned subgroupSize,
+ function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
+ assert(llvm::isPowerOf2_32(clusterSize));
assert(llvm::isPowerOf2_32(subgroupSize));
+ assert(clusterSize <= subgroupSize);
// Lane value always stays in the original type. We use it to perform arith
// reductions.
Value laneVal = input;
// Parallel reduction using butterfly shuffles.
- for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+ for (unsigned i = 1; i < clusterSize; i <<= 1) {
Value shuffled = builder
.create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
/*width=*/subgroupSize,
@@ -183,6 +190,13 @@ struct ScalarSubgroupReduceToShuffles final

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+ if (clusterSize && *clusterSize > subgroupSize)
+ return op.emitOpError()
+ << "cluster size " << *clusterSize
+ << " is greater than subgroup size " << subgroupSize;
+ unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

Type valueTy = op.getType();
unsigned elemBitwidth =
getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
@@ -196,7 +210,8 @@ struct ScalarSubgroupReduceToShuffles final
auto identityFn = [](Value v) { return v; };
rewriter.replaceOp(op, createSubgroupShuffleReduction(
rewriter, loc, op.getValue(), op.getOp(),
- subgroupSize, identityFn, identityFn));
+ effectiveClusterSize, subgroupSize, identityFn,
+ identityFn));
return success();
}

@@ -215,9 +230,10 @@ struct ScalarSubgroupReduceToShuffles final
return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
};

- rewriter.replaceOp(op, createSubgroupShuffleReduction(
- rewriter, loc, op.getValue(), op.getOp(),
- subgroupSize, packFn, unpackFn));
+ rewriter.replaceOp(
+ op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
+ op.getOp(), effectiveClusterSize,
+ subgroupSize, packFn, unpackFn));
return success();
}

@@ -237,6 +253,13 @@ struct VectorSubgroupReduceToShuffles final

LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+ if (clusterSize && *clusterSize > subgroupSize)
+ return op.emitOpError()
+ << "cluster size " << *clusterSize
+ << " is greater than subgroup size " << subgroupSize;
+ unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy)
return rewriter.notifyMatchFailure(op, "value type is not a vector");
@@ -285,9 +308,9 @@ struct VectorSubgroupReduceToShuffles final
return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
};

- Value res =
- createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
- subgroupSize, packFn, unpackFn);
+ Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
+ op.getOp(), effectiveClusterSize,
+ subgroupSize, packFn, unpackFn);

if (vecBitwidth < shuffleBitwidth) {
res = rewriter.create<vector::ExtractStridedSliceOp>(
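With clustering, the butterfly loop above runs log2(clusterSize) shuffle steps
instead of log2(subgroupSize). As a rough sketch of what a scalar f32 `add`
reduce with `cluster_size(4)` might expand to on a 32-wide subgroup (eliding
the pack/unpack casts; value names are hypothetical):

```mlir
// Two butterfly steps with XOR offsets 1 and 2; the loop stops at i == 4.
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%w = arith.constant 32 : i32
%s0, %p0 = gpu.shuffle xor %x, %c1, %w : f32
%r0 = arith.addf %x, %s0 : f32
%s1, %p1 = gpu.shuffle xor %r0, %c2, %w : f32
%r1 = arith.addf %r0, %s1 : f32
// %r1 now holds the 4-lane cluster sum in every lane of each cluster.
```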
18 changes: 18 additions & 0 deletions mlir/test/Dialect/GPU/canonicalize.mlir
@@ -246,6 +246,24 @@ func.func @make_subgroup_reduce_uniform() {

// -----

+ // CHECK-LABEL: func @subgroup_reduce_cluster_size_1
+ // CHECK: gpu.launch blocks
+ // CHECK: %[[V1:.*]] = "test.test2"() : () -> i32
+ // CHECK: "test.test3"(%[[V1]]) : (i32) -> ()
+ func.func @subgroup_reduce_cluster_size_1() {
+ %0:6 = "test.test1"() : () -> (index, index, index, index, index, index)
+ gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2)
+ threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) {
+ %1 = "test.test2"() : () -> i32
+ %2 = gpu.subgroup_reduce add %1 cluster_size(1) : (i32) -> (i32)
+ "test.test3"(%2) : (i32) -> ()
+ gpu.terminator
+ }
+ return
+ }
+
+ // -----
+
// The GPU kernel does not have any side effecting ops, so the entire
// gpu.launch op can fold away.

16 changes: 16 additions & 0 deletions mlir/test/Dialect/GPU/invalid.mlir
@@ -333,6 +333,22 @@ func.func @reduce_invalid_op_type_maximumf(%arg0 : i32) {

// -----

+ func.func @subgroup_reduce_zero_cluster_size(%arg0 : vector<4xf32>) {
+ // expected-error@+1 {{cluster size 0 is not a power of two}}
+ %res = gpu.subgroup_reduce add %arg0 cluster_size(0) : (vector<4xf32>) -> vector<4xf32>
+ return
+ }
+
+ // -----
+
+ func.func @subgroup_reduce_npot_cluster_size(%arg0 : vector<4xf32>) {
+ // expected-error@+1 {{cluster size 3 is not a power of two}}
+ %res = gpu.subgroup_reduce add %arg0 cluster_size(3) : (vector<4xf32>) -> vector<4xf32>
+ return
+ }
+
+ // -----
+
func.func @subgroup_reduce_bad_type(%arg0 : vector<2x2xf32>) {
// expected-error@+1 {{'gpu.subgroup_reduce' op operand #0 must be Integer or Float or vector of}}
%res = gpu.subgroup_reduce add %arg0 : (vector<2x2xf32>) -> vector<2x2xf32>
