From 0e082f74d85da92ccddb63d12478cba6ad6834d8 Mon Sep 17 00:00:00 2001 From: Dmitry Sidorov Date: Wed, 15 Feb 2023 19:49:45 +0100 Subject: [PATCH] [Backport to 16] Split JointMatrixMadINTEL instruction into 4 (#1833) JointMatrixMadINTEL will stand for signed/signed Matrix type JointMatrixSUMadINTEL will stand for signed/signed Matrix type JointMatrixUSMadINTEL will stand for unsigned/signed Matrix type JointMatrixUUMadINTEL will stand for unsigned/unsigned Matrix type Spec update: intel/llvm#8175 Signed-off-by: Dmitry Sidorov dmitry.sidorov@intel.com --- lib/SPIRV/SPIRVWriter.cpp | 11 +++------- lib/SPIRV/libSPIRV/SPIRVInstruction.h | 3 +++ lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h | 3 +++ lib/SPIRV/libSPIRV/spirv_internal.hpp | 6 +++++ .../SPV_INTEL_joint_matrix/joint_matrix.ll | 22 ++++++++++++++++++- 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 5650ad49fd..a6266a820f 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -608,14 +608,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) { return TranslatedTy; } -// Representation in LLVM IR before the translator is a pointer array wrapped -// in a structure: -// %struct.__spirv_JointMatrixINTEL = type { [R x [C x [L x [S x type]]]]* } -// where R = Rows, C = Columnts, L = Layout + 1, S = Scope + 1 -// this '+1' for the Layout and Scope is required because both of them can -// be '0', but array size can not be '0'. -// The result should look like SPIR-V friendly LLVM IR: -// %spirv.JointMatrixINTEL._char_2_2_0_3 +// Representation in LLVM IR before the translator is a pointer to an opaque +// structure: +// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use% // Here we check the structure name yet again. Another option would be to // check SPIR-V friendly function calls (by their name) and obtain return // or their parameter types, assuming, that the appropriate types are Matrix diff --git a/lib/SPIRV/libSPIRV/SPIRVInstruction.h b/lib/SPIRV/libSPIRV/SPIRVInstruction.h index d322d08d21..57add6b58d 100644 --- a/lib/SPIRV/libSPIRV/SPIRVInstruction.h +++ b/lib/SPIRV/libSPIRV/SPIRVInstruction.h @@ -3324,6 +3324,9 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase { _SPIRV_OP(JointMatrixLoad, true, 6, true) _SPIRV_OP(JointMatrixStore, false, 5, true) _SPIRV_OP(JointMatrixMad, true, 7) +_SPIRV_OP(JointMatrixSUMad, true, 7) +_SPIRV_OP(JointMatrixUSMad, true, 7) +_SPIRV_OP(JointMatrixUUMad, true, 7) _SPIRV_OP(JointMatrixWorkItemLength, true, 4) #undef _SPIRV_OP diff --git a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h index ea888d8aad..2682f86937 100644 --- a/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h +++ b/lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h @@ -9,6 +9,9 @@ _SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL) _SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL) _SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL) _SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL) +_SPIRV_OP_INTERNAL(JointMatrixSUMadINTEL, internal::OpJointMatrixSUMadINTEL) +_SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL) +_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL) _SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL, internal::OpJointMatrixWorkItemLengthINTEL) _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL) diff --git a/lib/SPIRV/libSPIRV/spirv_internal.hpp b/lib/SPIRV/libSPIRV/spirv_internal.hpp index bb7555de85..203f6afadc 100644 --- a/lib/SPIRV/libSPIRV/spirv_internal.hpp +++ b/lib/SPIRV/libSPIRV/spirv_internal.hpp @@ -42,6 +42,9 @@ enum InternalOp { IOpJointMatrixLoadINTEL = 6120, IOpJointMatrixStoreINTEL = 6121, IOpJointMatrixMadINTEL = 6122, + IOpJointMatrixSUMadINTEL = 6128, + IOpJointMatrixUSMadINTEL = 6129, + IOpJointMatrixUUMadINTEL = 6130, IOpArithmeticFenceINTEL = 6145, IOpJointMatrixWorkItemLengthINTEL = 6410, IOpComplexFMulINTEL = 6415, @@ -105,6 +108,9 @@ _SPIRV_OP(Op, TypeJointMatrixINTEL) _SPIRV_OP(Op, JointMatrixLoadINTEL) _SPIRV_OP(Op, JointMatrixStoreINTEL) _SPIRV_OP(Op, JointMatrixMadINTEL) +_SPIRV_OP(Op, JointMatrixSUMadINTEL) +_SPIRV_OP(Op, JointMatrixUSMadINTEL) +_SPIRV_OP(Op, JointMatrixUUMadINTEL) _SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL) _SPIRV_OP(Capability, HWThreadQueryINTEL) _SPIRV_OP(BuiltIn, SubDeviceIDINTEL) diff --git a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll index 4c5a470ab1..72010ed93e 100644 --- a/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll +++ b/test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll @@ -34,6 +34,10 @@ ; CHECK-SPIRV: JointMatrixLoadINTEL [[#ATy]] [[#A:]] [[#Aptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]] ; CHECK-SPIRV: JointMatrixLoadINTEL [[#BTy]] [[#B:]] [[#Bptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]] ; CHECK-SPIRV: JointMatrixMadINTEL [[#CTy]] [[#CMad]] [[#A]] [[#B]] [[#C]] [[#Three]] +; CHECK-SPIRV: JointMatrixSUMadINTEL [[#CTy]] [[#UnusedMad1:]] [[#A]] [[#B]] [[#C]] [[#Three]] +; CHECK-SPIRV: JointMatrixUSMadINTEL [[#CTy]] [[#UnusedMad2:]] [[#A]] [[#B]] [[#C]] [[#Three]] +; CHECK-SPIRV: JointMatrixUUMadINTEL [[#CTy]] [[#UnusedMad3:]] [[#A]] [[#B]] [[#C]] [[#Three]] + ; CHECK-SPIRV: JointMatrixStoreINTEL [[#Cptr:]] [[#C]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]] ; CHECK-SPIRV: CompositeConstruct [[#CTy]] [[#Cnew:]] [[#FortyTwo]] ; CHECK-SPIRV: Store [[#PtrToZero:]] [[#Zero]] @@ -49,7 +53,11 @@ ; CHECK-LLVM: [[C:%.*]] = phi %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [ [[CLoaded]], %entry ], [ [[CMad:%.*]], %for.body.i ] ; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* @_Z79__spirv_JointMatrixLoadINTEL_RPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0) ; CHECK-LLVM: [[B:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS4cliii(i8 addrspace(4)* [[BPtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0) -; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3) +; CHECK-LLVM: [[CMad1:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3) +; CHECK-LLVM: [[CMad2:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixSUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3) +; CHECK-LLVM: [[CMad3:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUSMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3) +; CHECK-LLVM: [[CMad4:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3) + ; CHECK-LLVM: call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS4sPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3liii(i16 addrspace(4)* [[CPtr]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i64 [[Stride]], i32 0, i32 3, i32 0) ; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42) ; CHECK-LLVM: store i32 0, i32 addrspace(4)* [[StoredZero:%.*]], align 4 @@ -115,6 +123,9 @@ for.body.i: ; preds = %for.cond.i %add.ptr17.i = addrspacecast i8 addrspace(1)* %add.ptr17.i56 to i8 addrspace(4)* %call18.i = tail call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* @_Z28__spirv_JointMatrixLoadINTELIaLm16ELm2ELN5__spv12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEEPS5_mS1_S3_i(i8 addrspace(4)* %add.ptr17.i, i64 %_arg_1, i32 0, i32 3, i32 0) #3 %call19.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3 + %call20.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3 + %call21.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3 + %call22.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3 %add.i = add nuw nsw i32 %k.0.i, 16 br label %for.cond.i, !llvm.loop !19 @@ -142,6 +153,15 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* ; Function Attrs: convergent declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1 +; Function Attrs: convergent +declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1 + +; Function Attrs: convergent +declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1 + +; Function Attrs: convergent +declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1 + ; Function Attrs: convergent declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIsLm2ELm2ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(i16 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1