Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ EXT(SPV_INTEL_global_variable_decorations)
EXT(SPV_INTEL_complex_float_mul_div)
EXT(SPV_INTEL_split_barrier)
EXT(SPV_INTEL_masked_gather_scatter)
EXT(SPV_INTEL_tensor_float32_conversion)
EXT(SPV_INTEL_tensor_float32_conversion) // TODO: to remove old extension
EXT(SPV_INTEL_tensor_float32_rounding)
EXT(SPV_EXT_relaxed_printf_string_address_space)
EXT(SPV_INTEL_fpga_argument_interfaces)
12 changes: 6 additions & 6 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3563,10 +3563,10 @@ _SPIRV_OP(MaskedScatter, false, 5)
#undef _SPIRV_OP

template <Op OC>
class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
return getVec(internal::CapabilityTensorFloat32ConversionINTEL);
return getVec(internal::CapabilityTensorFloat32RoundingINTEL);
}

std::optional<ExtensionID> getRequiredExtension() const override {
Expand All @@ -3587,8 +3587,8 @@ class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
// because it may call a method of class Module that may modify LiteralMap
// of Module field. That modification is not impacting validate method for
// these instructions, so const_cast is safe here.
using SPVTF32ConvTy = SPIRVTensorFloat32ConversionINTELInstBase<OC>;
SPIRVValue *Input = const_cast<SPVTF32ConvTy *>(this)->getOperand(0);
using SPVTF32RoundTy = SPIRVTensorFloat32RoundingINTELInstBase<OC>;
SPIRVValue *Input = const_cast<SPVTF32RoundTy *>(this)->getOperand(0);

SPIRVType *InCompTy = Input->getType();
SPIRVWord InCompCount = 1;
Expand Down Expand Up @@ -3616,8 +3616,8 @@ class SPIRVTensorFloat32ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
};

#define _SPIRV_OP(x) \
typedef SPIRVTensorFloat32ConversionINTELInstBase<internal::Op##x> SPIRV##x;
_SPIRV_OP(ConvertFToTF32INTEL)
typedef SPIRVTensorFloat32RoundingINTELInstBase<internal::Op##x> SPIRV##x;
_SPIRV_OP(RoundFToTF32INTEL)
#undef _SPIRV_OP
} // namespace SPIRV

Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,8 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
"GlobalVariableDecorationsINTEL");
add(internal::CapabilityComplexFloatMulDivINTEL, "ComplexFloatMulDivINTEL");
add(internal::CapabilityMaskedGatherScatterINTEL, "MaskedGatherScatterINTEL");
add(internal::CapabilityTensorFloat32ConversionINTEL,
"TensorFloat32ConversionINTEL");
add(internal::CapabilityTensorFloat32RoundingINTEL,
"TensorFloat32RoundingINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
2 changes: 1 addition & 1 deletion lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ _SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
_SPIRV_OP_INTERNAL(MaskedScatterINTEL, internal::OpMaskedScatterINTEL)
_SPIRV_OP_INTERNAL(ConvertFToTF32INTEL, internal::ConvertFToTF32INTEL)
_SPIRV_OP_INTERNAL(RoundFToTF32INTEL, internal::RoundFToTF32INTEL)
8 changes: 4 additions & 4 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ enum InternalOp {
IOpJointMatrixWorkItemLengthINTEL = 6410,
IOpComplexFMulINTEL = 6415,
IOpComplexFDivINTEL = 6416,
IOpConvertFToTF32INTEL = 6426,
IOpRoundFToTF32INTEL = 6426,
IOpMaskedGatherINTEL = 6428,
IOpMaskedScatterINTEL = 6429,
IOpPrev = OpMax - 2,
Expand All @@ -72,7 +72,7 @@ enum InternalCapability {
ICapFPArithmeticFenceINTEL = 6144,
ICapGlobalVariableDecorationsINTEL = 6146,
ICapabilityComplexFloatMulDivINTEL = 6414,
ICapabilityTensorFloat32ConversionINTEL = 6425,
ICapabilityTensorFloat32RoundingINTEL = 6425,
ICapabilityMaskedGatherScatterINTEL = 6427
};

Expand Down Expand Up @@ -118,8 +118,8 @@ _SPIRV_OP(Capability, MaskedGatherScatterINTEL)
_SPIRV_OP(Op, MaskedGatherINTEL)
_SPIRV_OP(Op, MaskedScatterINTEL)

_SPIRV_OP(Capability, TensorFloat32ConversionINTEL)
_SPIRV_OP(Op, ConvertFToTF32INTEL)
_SPIRV_OP(Capability, TensorFloat32RoundingINTEL)
_SPIRV_OP(Op, RoundFToTF32INTEL)
#undef _SPIRV_OP

constexpr Op OpForward = static_cast<Op>(IOpForward);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; CHECK-SPIRV: Capability TensorFloat32ConversionINTEL
; CHECK-SPIRV: Capability TensorFloat32RoundingINTEL
; CHECK-SPIRV: Extension "SPV_INTEL_tensor_float32_conversion"
; CHECK-SPIRV: TypeFloat [[#FP32Ty:]] 32
; CHECK-SPIRV: TypeVector [[#FP32v8Ty:]] [[#FP32Ty]] 8
Expand All @@ -22,24 +22,24 @@ target triple = "spir64-unknown-unknown"
; CHECK-SPIRV: FunctionParameter [[#FP32Ty]] [[FP32ValId:.*]]
; CHECK-SPIRV: FunctionParameter [[#FP32v8Ty]] [[FP32v8ValId:.*]]

; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[FP32ValId]]
; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]]
; CHECK-SPIRV: ConvertFToTF32INTEL [[#FP32Ty]] [[#]] [[#CONST]]
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32Ty]] [[#]] [[FP32ValId]]
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32v8Ty]] [[#]] [[FP32v8ValId]]
; CHECK-SPIRV: RoundFToTF32INTEL [[#FP32Ty]] [[#]] [[#CONST]]

; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float
; CHECK-LLVM: call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>
; CHECK-LLVM: call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
; CHECK-LLVM: call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float
; CHECK-LLVM: call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>
; CHECK-LLVM: call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.000000e+00)

define spir_func void @_Z2opffv8(float %a, <8 x float> %in) {
%1 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float %a)
%2 = tail call spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float> %in)
%3 = tail call spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float 1.000000e+00)
%1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
%2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
%3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.000000e+00)
ret void
}

declare spir_func float @_Z27__spirv_ConvertFToTF32INTELf(float)
declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)

declare spir_func <8 x float> @_Z27__spirv_ConvertFToTF32INTELDv8_f(<8 x float>)
declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)

!opencl.spir.version = !{!0}
!spirv.Source = !{!1}
Expand Down