Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{CapabilitySubgroupAvcMotionEstimationINTEL});
ADD_VEC_INIT(CapabilitySubgroupAvcMotionEstimationChromaINTEL,
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
{internal::CapabilityJointMatrixINTEL});
}

template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {
Expand Down
15 changes: 15 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3327,9 +3327,24 @@ _SPIRV_OP(JointMatrixMad, true, 7)
_SPIRV_OP(JointMatrixSUMad, true, 7)
_SPIRV_OP(JointMatrixUSMad, true, 7)
_SPIRV_OP(JointMatrixUUMad, true, 7)
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we move this right away in this PR? To preserve backward compatibility by not requiring support of JointMatrixWIInstructions capability for now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usage of OpJointMatrixWorkItemLength will currently emit just CapabilityJointMatrixINTEL. Which is supported in our downstream. Given the fact, that it's SYCL's counterpart is already implemented and used in tests and user code making the SPIR-V instruction be bound with CapabilityJointMatrixWIInstructionsINTEL would make these tests/user code not working, since this capability is not yet introduced in our downstream.

With this patch I'm introducing it, and then, when our drivers adopt it - the TODO can be resolved.

_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
#undef _SPIRV_OP

class SPIRVJointMatrixINTELWorkItemInst : public SPIRVJointMatrixINTELInstBase {
protected:
SPIRVCapVec getRequiredCapability() const override {
return getVec(internal::CapabilityJointMatrixWIInstructionsINTEL);
}
};

#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate<SPIRVJointMatrixINTELWorkItemInst, \
internal::Op##x##INTEL, __VA_ARGS__> \
SPIRV##x##INTEL;
_SPIRV_OP(JointMatrixGetElementCoord, true, 5)
#undef _SPIRV_OP

class SPIRVSplitBarrierINTELBase : public SPIRVInstTemplateBase {
protected:
SPIRVCapVec getRequiredCapability() const override {
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
add(internal::CapabilityTensorFloat32ConversionINTEL,
"TensorFloat32ConversionINTEL");
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
"JointMatrixWIInstructionsINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ _SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
internal::OpJointMatrixWorkItemLengthINTEL)
_SPIRV_OP_INTERNAL(JointMatrixGetElementCoordINTEL,
internal::OpJointMatrixGetElementCoordINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
Expand Down
7 changes: 6 additions & 1 deletion lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ enum InternalOp {
IOpConvertFToTF32INTEL = 6426,
IOpMaskedGatherINTEL = 6428,
IOpMaskedScatterINTEL = 6429,
IOpJointMatrixGetElementCoordINTEL = 6440,
IOpPrev = OpMax - 2,
IOpForward
};
Expand All @@ -76,7 +77,8 @@ enum InternalCapability {
ICapGlobalVariableDecorationsINTEL = 6146,
ICapabilityComplexFloatMulDivINTEL = 6414,
ICapabilityTensorFloat32ConversionINTEL = 6425,
ICapabilityMaskedGatherScatterINTEL = 6427
ICapabilityMaskedGatherScatterINTEL = 6427,
ICapabilityJointMatrixWIInstructionsINTEL = 6435
};

enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 };
Expand Down Expand Up @@ -104,6 +106,7 @@ enum InternalBuiltIn {

#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
_SPIRV_OP(Capability, JointMatrixINTEL)
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
_SPIRV_OP(Op, TypeJointMatrixINTEL)
_SPIRV_OP(Op, JointMatrixLoadINTEL)
_SPIRV_OP(Op, JointMatrixStoreINTEL)
Expand All @@ -112,6 +115,8 @@ _SPIRV_OP(Op, JointMatrixSUMadINTEL)
_SPIRV_OP(Op, JointMatrixUSMadINTEL)
_SPIRV_OP(Op, JointMatrixUUMadINTEL)
_SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL)
_SPIRV_OP(Op, JointMatrixGetElementCoordINTEL)

_SPIRV_OP(Capability, HWThreadQueryINTEL)
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
_SPIRV_OP(BuiltIn, GlobalHWThreadIDINTEL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,26 @@
; RUN: llvm-spirv -r -emit-opaque-pointers %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV: Capability JointMatrixINTEL
; CHECK-SPIRV: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 64
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]]
; CHECK-SPIRV-DAG: Capability JointMatrixINTEL
; CHECK-SPIRV-DAG: Capability JointMatrixWIInstructionsINTEL
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
; CHECK-SPIRV-DAG: TypeInt [[#TypeInt32:]] 32
; CHECK-SPIRV-DAG: TypeInt [[#TypeInt64:]] 64
; CHECK-SPIRV-DAG: TypeFloat [[#TypeFloat:]] 32
; CHECK-SPIRV-DAG: TypeJointMatrixINTEL [[#TypeMatrix:]] [[#TypeFloat]] [[#]] [[#]] [[#]] [[#]]
; CHECK-SPIRV-DAG: TypeVector [[#TypeVec:]] [[#TypeInt32]] 2
; CHECK-SPIRV: Phi [[#TypeMatrix]] [[#Matrix:]]
; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt]] [[#]] [[#Matrix]]
; CHECK-SPIRV: JointMatrixWorkItemLengthINTEL [[#TypeInt64]] [[#]] [[#Matrix]]
; CHECK-SPIRV: VectorExtractDynamic [[#TypeFloat]] [[#]] [[#Matrix]] [[#Index:]]
; CHECK-SPIRV: FMul [[#TypeFloat]] [[#NewVal:]] [[#]] [[#]]
; CHECK-SPIRV: VectorInsertDynamic [[#TypeMatrix]] [[#]] [[#Matrix]] [[#NewVal]] [[#Index]]
; CHECK-SPIRV: JointMatrixGetElementCoordINTEL [[#TypeVec]] [[#]] [[#Matrix]] [[#Index]]

; CHECK-LLVM: [[Length:%.*]] = call spir_func i64 @_Z38__spirv_JointMatrixWorkItemLengthINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3(ptr addrspace(1) [[Matrix:%.*]])
; CHECK-LLVM: [[Elem:%.*]] = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(ptr addrspace(1) [[Matrix]], i64 [[Index:%.*]])
; CHECK-LLVM: [[NewVal:%.*]] = fmul float [[Elem]], 5.000000e+00
; CHECK-LLVM: {{%.*}} = call spir_func ptr addrspace(1) @_Z27__spirv_VectorInsertDynamicPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3fl(ptr addrspace(1) [[Matrix]], float [[NewVal]], i64 [[Index]])
; CHECK-LLVM: {{%.*}} = call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELPU3AS141__spirv_JointMatrixINTEL__float_16_16_0_3l(ptr addrspace(1) [[Matrix]], i64 [[Index]])

source_filename = "/work/tmp/matrix-slice.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
Expand Down Expand Up @@ -69,6 +74,7 @@ for.body.i: ; preds = %for.cond.i
%call.i.i = tail call spir_func float @_Z28__spirv_VectorExtractDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmET_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEET4_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2
%mul.i.i = fmul float %call.i.i, 5.000000e+00
%call5.i.i = tail call spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* @_Z27__spirv_VectorInsertDynamicIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EmEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEES7_T4_S5_(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, float %mul.i.i, i64 %conv.i) #2
%call6 = tail call spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)* %A.sroa.0.0.i, i64 %conv.i) #2
%inc.i = add nuw nsw i32 %i.0.i, 1
br label %for.cond.i, !llvm.loop !7

Expand All @@ -92,6 +98,9 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4
; Function Attrs: convergent
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIfLm16ELm16ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(float addrspace(4)*, %spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare dso_local spir_func <2 x i32> @_Z39__spirv_JointMatrixGetElementCoordINTELIaLm8ELm32ELN5__spv9MatrixUseE0ELNS0_12MatrixLayoutE0ELNS0_5Scope4FlagE3EEDv2_jPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT3_EXT4_EXT2_EEEm(%spirv.JointMatrixINTEL._float_16_16_0_3 addrspace(4)*, i64) #2

attributes #0 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="/work/tmp/matrix-slice.cpp" "uniform-work-group-size"="true" }
attributes #1 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #2 = { convergent }
Expand Down