Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
"and replacing with supported ones";
}

def SPIRVReplicatedConstantCompositePass
: Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
let summary = "Convert splat composite constants and spec constants to "
"corresponding replicated constant composite ops defined by "
"SPV_EXT_replicated_composites";
}

#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CanonicalizeGLPass.cpp
ConvertToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
Expand Down Expand Up @@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion

add_mlir_dialect_library(MLIRSPIRVTransforms
CanonicalizeGLPass.cpp
ConvertToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
//===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert a splat composite spirv.Constant and
// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
// spirv.EXT.SpecConstantCompositeReplicate respectively.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir::spirv {
#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"

namespace {

static std::pair<Attribute, uint32_t>
getSplatAttributeAndCount(Attribute valueAttr) {
Attribute attr;
uint32_t splatCount = 0;
if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
if (denseAttr.isSplat()) {
attr = denseAttr.getSplatValue<Attribute>();
splatCount = denseAttr.size();
}
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
std::not_equal_to<>()) == arrayAttr.end()) {
attr = arrayAttr[0];
splatCount = arrayAttr.size();
}
}

if (attr) {
auto typedAttr = dyn_cast<TypedAttr>(attr);
if ((typedAttr && isa<spirv::CompositeType>(typedAttr.getType())) ||
isa<ArrayAttr>(attr)) {
std::pair<Attribute, uint32_t> newSplatAttrAndCount =
getSplatAttributeAndCount(attr);
if (newSplatAttrAndCount.first) {
return newSplatAttrAndCount;
}
}
}

return {attr, splatCount};
}

struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::ConstantOp op,
PatternRewriter &rewriter) const override {
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
if (!compositeType)
return rewriter.notifyMatchFailure(op, "not a composite constant");

auto [splatAttr, splatCount] = getSplatAttributeAndCount(op.getValue());
if (!splatAttr)
return rewriter.notifyMatchFailure(op, "composite is not splat");

if (splatCount == 1)
return rewriter.notifyMatchFailure(op,
"composite has only one constituent");

rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
op, op.getType(), splatAttr);

return success();
}
};

struct SpecConstantCompositeOpConversion final
: OpRewritePattern<spirv::SpecConstantCompositeOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
PatternRewriter &rewriter) const override {
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
if (!compositeType)
return rewriter.notifyMatchFailure(op, "not a composite constant");

ArrayAttr constituents = op.getConstituents();
if (constituents.size() == 1)
return rewriter.notifyMatchFailure(op,
"composite has only one consituent");

if (!(std::adjacent_find(constituents.begin(), constituents.end(),
std::not_equal_to<>()) == constituents.end()))
return rewriter.notifyMatchFailure(op, "composite is not splat");

auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
if (!splatConstituent)
return rewriter.notifyMatchFailure(
op, "expected flat symbol reference for splat constituent");

rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);

return success();
}
};

struct ConvertToReplicatedConstantCompositePass final
: spirv::impl::SPIRVReplicatedConstantCompositePassBase<
ConvertToReplicatedConstantCompositePass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(
context);
walkAndApplyPatterns(getOperation(), std::move(patterns));
}
};

} // namespace
} // namespace mlir::spirv
219 changes: 219 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/replicated-const-composites.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// RUN: mlir-opt --spirv-convert-to-replicated-const-composite --split-input-file --verify-diagnostics %s | FileCheck %s

spirv.module Logical GLSL450 {
spirv.func @splat_vector_of_i32() -> (vector<3xi32>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : vector<3xi32>
%0 = spirv.Constant dense<2> : vector<3xi32>
spirv.ReturnValue %0 : vector<3xi32>
}

spirv.func @splat_array_of_i32() -> (!spirv.array<3 x i32>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1 : i32] : !spirv.array<3 x i32>
%0 = spirv.Constant [1 : i32, 1 : i32, 1 : i32] : !spirv.array<3 x i32>
spirv.ReturnValue %0 : !spirv.array<3 x i32>
}

spirv.func @splat_array_of_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
%0 = spirv.Constant [[3 : i32, 3 : i32, 3 : i32], [3 : i32, 3 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}

spirv.func @splat_array_of_non_splat_array_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
%0 = spirv.Constant [[1 : i32, 2 : i32, 3 : i32], [1 : i32, 2 : i32, 3 : i32]] : !spirv.array<2 x !spirv.array<3 x i32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}

spirv.func @splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
%0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 2]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
}

spirv.func @splat_array_of_splat_vectors_of_i32() -> (!spirv.array<2 x vector<2xi32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.array<2 x vector<2xi32>>
%0 = spirv.Constant [dense<2> : vector<2xi32>, dense<2> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
}

spirv.func @splat_tensor_of_i32() -> (!spirv.array<2 x !spirv.array<3 x i32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3 : i32] : !spirv.array<2 x !spirv.array<3 x i32>>
%0 = spirv.Constant dense<3> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x i32>>
}

spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2 : i32] : !spirv.arm.tensor<2x3xi32>
%0 = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
}

spirv.func @splat_vector_of_f32() -> (vector<3xf32>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : vector<3xf32>
%0 = spirv.Constant dense<2.0> : vector<3xf32>
spirv.ReturnValue %0 : vector<3xf32>
}

spirv.func @splat_array_of_f32() -> (!spirv.array<3 x f32>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : !spirv.array<3 x f32>
%0 = spirv.Constant [1.0 : f32, 1.0 : f32, 1.0 : f32] : !spirv.array<3 x f32>
spirv.ReturnValue %0 : !spirv.array<3 x f32>
}

spirv.func @splat_array_of_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
%0 = spirv.Constant [[3.0 : f32, 3.0 : f32, 3.0 : f32], [3.0 : f32, 3.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
}

spirv.func @splat_array_of_non_splat_array_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate {{\[}}[1.000000e+00 : f32, 2.000000e+00 : f32, 3.000000e+00 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
%0 = spirv.Constant [[1.0 : f32, 2.0 : f32, 3.0 : f32], [1.0 : f32, 2.0 : f32, 3.0 : f32]] : !spirv.array<2 x !spirv.array<3 x f32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
}

spirv.func @splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [dense<[1.000000e+00, 2.000000e+00]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 2.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
}

spirv.func @splat_array_of_splat_vectors_of_f32() -> (!spirv.array<2 x vector<2xf32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.array<2 x vector<2xf32>>
%0 = spirv.Constant [dense<2.0> : vector<2xf32>, dense<2.0> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
}

spirv.func @splat_tensor_of_f32() -> (!spirv.array<2 x !spirv.array<3 x f32>>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [3.000000e+00 : f32] : !spirv.array<2 x !spirv.array<3 x f32>>
%0 = spirv.Constant dense<3.0> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
spirv.ReturnValue %0 : !spirv.array<2 x !spirv.array<3 x f32>>
}

spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
// CHECK: {{%.*}} = spirv.EXT.ConstantCompositeReplicate [2.000000e+00 : f32] : !spirv.arm.tensor<2x3xf32>
%0 = spirv.Constant dense<2.0> : !spirv.arm.tensor<2x3xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
}

spirv.func @array_of_one_i32() -> (!spirv.array<1 x i32>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant [1 : i32] : !spirv.array<1 x i32>
spirv.ReturnValue %0 : !spirv.array<1 x i32>
}

spirv.func @arm_tensor_of_one_i32() -> (!spirv.arm.tensor<1xi32>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi32>
spirv.ReturnValue %0 : !spirv.arm.tensor<1xi32>
}

spirv.func @non_splat_vector_of_i32() -> (vector<3xi32>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant dense<[0, 1, 2]> : vector<3xi32>
spirv.ReturnValue %0 : vector<3xi32>
}

spirv.func @non_splat_array_of_vectors_of_i32() -> (!spirv.array<2xvector<2xi32>>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant [dense<[1, 2]> : vector<2xi32>, dense<[1, 3]> : vector<2xi32>] : !spirv.array<2 x vector<2xi32>>
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xi32>>
}

spirv.func @array_of_one_f32() -> (!spirv.array<1 x f32>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant [1.0 : f32] : !spirv.array<1 x f32>
spirv.ReturnValue %0 : !spirv.array<1 x f32>
}

spirv.func @arm_tensor_of_one_f32() -> (!spirv.arm.tensor<1xf32>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant dense<1.0> : !spirv.arm.tensor<1xf32>
spirv.ReturnValue %0 : !spirv.arm.tensor<1xf32>
}
}

// -----

spirv.module Logical GLSL450 {
spirv.func @non_splat_vector_of_f32() -> (vector<3xf32>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant dense<[0.0, 1.0, 2.0]> : vector<3xf32>
spirv.ReturnValue %0 : vector<3xf32>
}
}

// -----

spirv.module Logical GLSL450 {
spirv.func @non_splat_array_of_vectors_of_f32() -> (!spirv.array<2xvector<2xf32>>) "None" {
// CHECK-NOT: spirv.EXT.ConstantCompositeReplicate
%0 = spirv.Constant [dense<[1.0, 2.0]> : vector<2xf32>, dense<[1.0, 3.0]> : vector<2xf32>] : !spirv.array<2 x vector<2xf32>>
spirv.ReturnValue %0 : !spirv.array<2 x vector<2xf32>>
}
}

// -----

spirv.module Logical GLSL450 {

spirv.SpecConstant @sc_i32_1 = 1 : i32

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_i32 (@sc_i32_1) : !spirv.array<3 x i32>
spirv.SpecConstantComposite @scc_splat_array_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.array<3 x i32>

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_i32 (@sc_i32_1) : !spirv.struct<(i32, i32, i32)>
spirv.SpecConstantComposite @scc_splat_struct_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.struct<(i32, i32, i32)>

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_i32 (@sc_i32_1) : vector<3xi32>
spirv.SpecConstantComposite @scc_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : vector<3 x i32>

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_i32 (@sc_i32_1) : !spirv.arm.tensor<3xi32>
spirv.SpecConstantComposite @scc_splat_arm_tensor_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>

spirv.SpecConstant @sc_f32_1 = 1.0 : f32

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_array_of_f32 (@sc_f32_1) : !spirv.array<3 x f32>
spirv.SpecConstantComposite @scc_splat_array_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.array<3 x f32>

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_struct_of_f32 (@sc_f32_1) : !spirv.struct<(f32, f32, f32)>
spirv.SpecConstantComposite @scc_splat_struct_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.struct<(f32, f32, f32)>

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_vector_of_f32 (@sc_f32_1) : vector<3xf32>
spirv.SpecConstantComposite @scc_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : vector<3 x f32>

// CHECK: spirv.EXT.SpecConstantCompositeReplicate @scc_splat_arm_tensor_of_f32 (@sc_f32_1) : !spirv.arm.tensor<3xf32>
spirv.SpecConstantComposite @scc_splat_arm_tensor_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>

spirv.SpecConstant @sc_i32_2 = 2 : i32

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_array_of_one_i32 (@sc_i32_1) : !spirv.array<1 x i32>

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_arm_tensor_of_one_i32 (@sc_i32_1) : !spirv.arm.tensor<1xi32>

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_non_splat_vector_of_i32 (@sc_i32_1, @sc_i32_1, @sc_i32_2) : vector<3 x i32>

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_i32 (@sc_i32_2, @sc_i32_1, @sc_i32_1) : !spirv.arm.tensor<3xi32>

spirv.SpecConstant @sc_f32_2 = 2.0 : f32

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_array_of_one_f32 (@sc_f32_1) : !spirv.array<1 x f32>

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_arm_tensor_of_one_f32 (@sc_f32_1) : !spirv.arm.tensor<1xf32>

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_non_splat_vector_of_f32 (@sc_f32_1, @sc_f32_1, @sc_f32_2) : vector<3 x f32>

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_non_splat_arm_tensor_of_f32 (@sc_f32_2, @sc_f32_1, @sc_f32_1) : !spirv.arm.tensor<3xf32>

// CHECK-NOT: spirv.EXT.SpecConstantCompositeReplicate
spirv.SpecConstantComposite @scc_struct_of_i32_and_f32 (@sc_i32_1, @sc_i32_1, @sc_f32_1) : !spirv.struct<(i32, i32, f32)>
}