Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ EXT(SPV_INTEL_hw_thread_queries)
EXT(SPV_INTEL_global_variable_decorations)
EXT(SPV_INTEL_complex_float_mul_div)
EXT(SPV_INTEL_split_barrier)
EXT(SPV_INTEL_tensor_float32_conversion)
EXT(SPV_INTEL_masked_gather_scatter)
EXT(SPV_INTEL_tensor_float32_conversion) // TODO: to remove old extension
EXT(SPV_INTEL_tensor_float32_rounding)
EXT(SPV_EXT_relaxed_printf_string_address_space)
116 changes: 58 additions & 58 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3414,64 +3414,6 @@ _SPIRV_OP(ComplexFMulINTEL)
_SPIRV_OP(ComplexFDivINTEL)
#undef _SPIRV_OP

template <Op OC>
class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
return getVec(internal::CapabilityTensorFloat32ConversionINTEL);
}

llvm::Optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_INTEL_tensor_float32_conversion;
}

void validate() const override {
SPIRVUnaryInst<OC>::validate();

SPIRVType *ResCompTy = this->getType();
SPIRVWord ResCompCount = 1;
if (ResCompTy->isTypeVector()) {
ResCompCount = ResCompTy->getVectorComponentCount();
ResCompTy = ResCompTy->getVectorComponentType();
}

// validate is a const method, whilst getOperand is non-const method
// because it may call a method of class Module that may modify LiteralMap
// of Module field. That modification is not impacting validate method for
// these instructions, so const_cast is safe here.
using SPVTF32ConvTy = SPIRVTensorFloat32ConversionINTELInstBase<OC>;
SPIRVValue *Input = const_cast<SPVTF32ConvTy *>(this)->getOperand(0);

SPIRVType *InCompTy = Input->getType();
SPIRVWord InCompCount = 1;
if (InCompTy->isTypeVector()) {
InCompCount = InCompTy->getVectorComponentCount();
InCompTy = InCompTy->getVectorComponentType();
}

auto InstName = OpCodeNameMap::map(OC);
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();

SPVErrLog.checkError(
ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
InstName + "\nResult value must be a scalar or vector of floating-point"
" 32-bit type\n");
SPVErrLog.checkError(InCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
InstName +
"\nInput value must be a scalar or vector of "
"floating-point 32-bit type\n");
SPVErrLog.checkError(
ResCompCount == InCompCount, SPIRVEC_InvalidInstruction,
InstName + "\nInput type must have the same number of components as "
"result type\n");
}
};

#define _SPIRV_OP(x) \
typedef SPIRVTensorFloat32ConversionINTELInstBase<internal::Op##x> SPIRV##x;
_SPIRV_OP(ConvertFToTF32INTEL)
#undef _SPIRV_OP

class SPIRVMaskedGatherScatterINTELInstBase : public SPIRVInstTemplateBase {
protected:
SPIRVCapVec getRequiredCapability() const override {
Expand Down Expand Up @@ -3619,6 +3561,64 @@ class SPIRVMaskedScatterINTELInst
_SPIRV_OP(MaskedGather, true, 7)
_SPIRV_OP(MaskedScatter, false, 5)
#undef _SPIRV_OP

template <Op OC>
class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
return getVec(internal::CapabilityTensorFloat32RoundingINTEL);
}

llvm::Optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_INTEL_tensor_float32_conversion;
}

void validate() const override {
SPIRVUnaryInst<OC>::validate();

SPIRVType *ResCompTy = this->getType();
SPIRVWord ResCompCount = 1;
if (ResCompTy->isTypeVector()) {
ResCompCount = ResCompTy->getVectorComponentCount();
ResCompTy = ResCompTy->getVectorComponentType();
}

// validate is a const method, whilst getOperand is non-const method
// because it may call a method of class Module that may modify LiteralMap
// of Module field. That modification is not impacting validate method for
// these instructions, so const_cast is safe here.
using SPVTF32RoundTy = SPIRVTensorFloat32RoundingINTELInstBase<OC>;
SPIRVValue *Input = const_cast<SPVTF32RoundTy *>(this)->getOperand(0);

SPIRVType *InCompTy = Input->getType();
SPIRVWord InCompCount = 1;
if (InCompTy->isTypeVector()) {
InCompCount = InCompTy->getVectorComponentCount();
InCompTy = InCompTy->getVectorComponentType();
}

auto InstName = OpCodeNameMap::map(OC);
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();

SPVErrLog.checkError(
ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
InstName + "\nResult value must be a scalar or vector of floating-point"
" 32-bit type\n");
SPVErrLog.checkError(InCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
InstName +
"\nInput value must be a scalar or vector of "
"floating-point 32-bit type\n");
SPVErrLog.checkError(
ResCompCount == InCompCount, SPIRVEC_InvalidInstruction,
InstName + "\nInput type must have the same number of components as "
"result type\n");
}
};

#define _SPIRV_OP(x) \
typedef SPIRVTensorFloat32RoundingINTELInstBase<internal::Op##x> SPIRV##x;
_SPIRV_OP(RoundFToTF32INTEL)
#undef _SPIRV_OP
} // namespace SPIRV

#endif // SPIRV_LIBSPIRV_SPIRVINSTRUCTION_H
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,9 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(internal::CapabilityGlobalVariableDecorationsINTEL,
"GlobalVariableDecorationsINTEL");
add(internal::CapabilityComplexFloatMulDivINTEL, "ComplexFloatMulDivINTEL");
add(internal::CapabilityTensorFloat32ConversionINTEL,
"TensorFloat32ConversionINTEL");
add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
add(internal::CapabilityTensorFloat32RoundingINTEL,
"TensorFloat32RoundingINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ _SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
internal::OpJointMatrixWorkItemLengthINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
_SPIRV_OP_INTERNAL(ConvertFToTF32INTEL, internal::ConvertFToTF32INTEL)
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL)
_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL)
_SPIRV_OP_INTERNAL(RoundFToTF32INTEL, internal::RoundFToTF32INTEL)
10 changes: 5 additions & 5 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ enum InternalOp {
IOpJointMatrixWorkItemLengthINTEL = 6410,
IOpComplexFMulINTEL = 6415,
IOpComplexFDivINTEL = 6416,
IOpConvertFToTF32INTEL = 6426,
IOpRoundFToTF32INTEL = 6426,
IOpMaskedGatherINTEL = 6428,
IOpMaskedScatterINTEL = 6429,
IOpPrev = OpMax - 2,
Expand Down Expand Up @@ -79,7 +79,7 @@ enum InternalCapability {
ICapFPArithmeticFenceINTEL = 6144,
ICapGlobalVariableDecorationsINTEL = 6146,
ICapabilityComplexFloatMulDivINTEL = 6414,
ICapabilityTensorFloat32ConversionINTEL = 6425,
ICapabilityTensorFloat32RoundingINTEL = 6425,
ICapabilityMaskedGatherScatterINTEL = 6427
};

Expand Down Expand Up @@ -124,12 +124,12 @@ _SPIRV_OP(Capability, ComplexFloatMulDivINTEL)
_SPIRV_OP(Op, ComplexFMulINTEL)
_SPIRV_OP(Op, ComplexFDivINTEL)

_SPIRV_OP(Capability, TensorFloat32ConversionINTEL)
_SPIRV_OP(Op, ConvertFToTF32INTEL)

_SPIRV_OP(Capability, MaskedGatherScatterINTEL)
_SPIRV_OP(Op, MaskedGatherINTEL)
_SPIRV_OP(Op, MaskedScatterINTEL)

_SPIRV_OP(Capability, TensorFloat32RoundingINTEL)
_SPIRV_OP(Op, RoundFToTF32INTEL)
#undef _SPIRV_OP

constexpr Op OpForward = static_cast<Op>(IOpForward);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; CHECK-SPIRV: Capability TensorFloat32ConversionINTEL
; CHECK-SPIRV: Capability TensorFloat32RoundingINTEL
; CHECK-SPIRV: Extension "SPV_INTEL_tensor_float32_conversion"
; CHECK-SPIRV: TypeFloat [[#FP32Ty:]] 32
; CHECK-SPIRV: TypeVector [[#FP32v8Ty:]] [[#FP32Ty]] 8
Expand All @@ -22,24 +22,24 @@ target triple = "spir64-unknown-unknown"
; CHECK-SPIRV: FunctionParameter [[#FP32Ty]] [[FP32ValId:.*]]
; CHECK-SPIRV: FunctionParameter [[#FP32v8Ty]] [[FP32v8ValId:.*]]

; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[FP32ValId]]
; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]]
; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[#CONST]]
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32Ty]] [[#]] [[FP32ValId]]
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]]
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32Ty]] [[#]] [[#CONST]]

; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float
; CHECK-LLVM: call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>
; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
; CHECK-LLVM: call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float
; CHECK-LLVM: call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>
; CHECK-LLVM: call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.000000e+00)

define spir_func void @_Z2opffv8(float %a, <8 x float> %in) {
%1 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float %a)
%2 = tail call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float> %in)
%3 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
%1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
%2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
%3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.000000e+00)
ret void
}

declare spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float)
declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)

declare spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>)
declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)

!opencl.spir.version = !{!0}
!spirv.Source = !{!1}
Expand Down