Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ cscope.out
autoconf/aclocal.m4
autoconf/autom4te.cache
compile_commands.json
_codeql_detected_source_root
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityPackedCooperativeMatrixINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixPrefetchINTEL,
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
"JointMatrixPackedInt2ComponentTypeINTEL");
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
"JointMatrixPackedInt4ComponentTypeINTEL");
add(internal::CapabilityPackedCooperativeMatrixINTEL,
"PackedCooperativeMatrixINTEL");
add(internal::CapabilityCooperativeMatrixPrefetchINTEL,
"CooperativeMatrixPrefetchINTEL");
add(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,
Expand Down
13 changes: 13 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,19 @@ void SPIRVTypeJointMatrixINTEL::decode(std::istream &I) {
Decoder >> Id >> CompType >> Args;
}

SPIRVCapVec SPIRVTypeJointMatrixINTEL::getRequiredCapability() const {
auto CV = getVec(internal::CapabilityJointMatrixINTEL);
if (SPIRVValue *LayoutVal = getLayout()) {
if (isConstantOpCode(LayoutVal->getOpCode())) {
uint64_t Layout =
static_cast<SPIRVConstant *>(LayoutVal)->getZExtIntValue();
if (Layout == internal::PackedA || Layout == internal::PackedB)
CV.push_back(internal::CapabilityPackedCooperativeMatrixINTEL);
}
}
return CV;
}

SPIRVTypeCooperativeMatrixKHR::SPIRVTypeCooperativeMatrixKHR(
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
std::vector<SPIRVValue *> Args)
Expand Down
4 changes: 1 addition & 3 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -1196,9 +1196,7 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
std::optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_INTEL_joint_matrix;
}
SPIRVCapVec getRequiredCapability() const override {
return {internal::CapabilityJointMatrixINTEL};
}
SPIRVCapVec getRequiredCapability() const override;
void setWordCount(SPIRVWord WordCount) override {
SPIRVType::setWordCount(WordCount);
Args.resize(WordCount - FixedWC);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_joint_matrix -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM

; Test that PackedCooperativeMatrixINTEL capability is emitted when using
; PackedA (layout=2) or PackedB (layout=3) matrix layouts.

; CHECK-SPIRV-DAG: Capability JointMatrixINTEL
; CHECK-SPIRV-DAG: Capability PackedCooperativeMatrixINTEL
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
; CHECK-SPIRV-DAG: TypeInt [[#Int32Ty:]] 32 0
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const12:]] 12
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const48:]] 48
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const3:]] 3
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const2:]] 2
; CHECK-SPIRV-DAG: Constant [[#Int32Ty]] [[#Const0:]] 0
; Layout 2 = PackedA, Layout 3 = PackedB
; TypeJointMatrixINTEL: Result Type Rows Cols Layout Scope Use
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTyA:]] [[#Int8Ty]] [[#Const12]] [[#Const48]] [[#Const2]] [[#Const3]] [[#Const0]]
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#MatTyB:]] [[#Int8Ty]] [[#Const48]] [[#Const12]] [[#Const3]] [[#Const3]] {{[0-9]+}}

; CHECK-SPIRV: JointMatrixLoadINTEL [[#MatTyA]]
; CHECK-SPIRV: JointMatrixLoadINTEL [[#MatTyB]]

; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", i8, 12, 48, 2, 3, 0)
; CHECK-LLVM: call spir_func target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 1)

; ModuleID = 'joint_matrix_packed.bc'
source_filename = "joint_matrix_packed.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

define weak_odr dso_local spir_kernel void @test_packed_matrix(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, i64 noundef %_arg_K) {
entry:
%call.ascast.i.a = addrspacecast ptr addrspace(1) %_arg_accA to ptr addrspace(4)
%call.ascast.i.b = addrspacecast ptr addrspace(1) %_arg_accB to ptr addrspace(4)
; Load matrix A with PackedA layout (layout=2)
%matrixA = tail call spir_func noundef target("spirv.JointMatrixINTEL", i8, 12, 48, 2, 3, 0) @_Z28__spirv_JointMatrixLoadINTEL_PackedA(ptr addrspace(4) noundef %call.ascast.i.a, i64 noundef %_arg_K, i32 noundef 2, i32 noundef 3, i32 noundef 0) #1
; Load matrix B with PackedB layout (layout=3)
%matrixB = tail call spir_func noundef target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 1) @_Z28__spirv_JointMatrixLoadINTEL_PackedB(ptr addrspace(4) noundef %call.ascast.i.b, i64 noundef %_arg_K, i32 noundef 3, i32 noundef 3, i32 noundef 0) #1
ret void
}

; Function declaration for loading matrix with PackedA layout
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i8, 12, 48, 2, 3, 0) @_Z28__spirv_JointMatrixLoadINTEL_PackedA(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) #0

; Function declaration for loading matrix with PackedB layout
declare dso_local spir_func noundef target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 1) @_Z28__spirv_JointMatrixLoadINTEL_PackedB(ptr addrspace(4) noundef, i64 noundef, i32 noundef, i32 noundef, i32 noundef) #0

attributes #0 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #1 = { convergent }