Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions lib/SPIRV/OCL20ToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,9 @@ void OCL20ToSPIRV::visitCallInst(CallInst &CI) {
return;
}
if (DemangledName == kOCLBuiltinName::Dot ||
DemangledName == kOCLBuiltinName::DotAccSat) {
DemangledName == kOCLBuiltinName::DotAccSat ||
DemangledName.find(kOCLBuiltinName::Dot4x8PackedPrefix, 0) == 0 ||
DemangledName.find(kOCLBuiltinName::DotAccSat4x8PackedPrefix, 0) == 0) {
if (CI.getOperand(0)->getType()->isVectorTy()) {
auto *VT = (VectorType *)(CI.getOperand(0)->getType());
if (!isa<llvm::IntegerType>(VT->getElementType())) {
Expand Down Expand Up @@ -1498,19 +1500,11 @@ void OCL20ToSPIRV::visitCallDot(CallInst *CI, StringRef MangledName,
// translation for dot function calls,
// to differentiate between integer dot products

SmallVector<Value *, 3> Args;
Args.push_back(CI->getOperand(0));
Args.push_back(CI->getOperand(1));
bool IsFirstSigned, IsSecondSigned;
bool IsDot = DemangledName == kOCLBuiltinName::Dot;
std::string FunName = (IsDot) ? "DotKHR" : "DotAccSatKHR";
if (CI->getNumArgOperands() > 2) {
Args.push_back(CI->getOperand(2));
}
if (CI->getNumArgOperands() > 3) {
Args.push_back(CI->getOperand(3));
}
if (CI->getOperand(0)->getType()->isVectorTy()) {
bool IsAccSat = DemangledName.contains(kOCLBuiltinName::DotAccSat);
bool IsPacked = CI->getOperand(0)->getType()->isIntegerTy();
if (!IsPacked) {
if (IsDot) {
// dot(char4, char4) _Z3dotDv4_cS_
// dot(char4, uchar4) _Z3dotDv4_cDv4_h
Expand Down Expand Up @@ -1551,21 +1545,28 @@ void OCL20ToSPIRV::visitCallDot(CallInst *CI, StringRef MangledName,
}
} else {
// for packed format
// dot(int, int, int) _Z3dotiii
// dot(int, uint, int) _Z3dotiji
// dot(uint, int, int) _Z3dotjii
// dot(uint, uint, int) _Z3dotjji
// dot_4x8packed_ss_int(uint, uint) _Z20dot_4x8packed_ss_intjj
// dot_4x8packed_su_int(uint, uint) _Z20dot_4x8packed_su_intjj
// dot_4x8packed_us_int(uint, uint) _Z20dot_4x8packed_us_intjj
// dot_4x8packed_uu_uint(uint, uint) _Z21dot_4x8packed_uu_uintjj
// or
// dot_acc_sat(int, int, int, int) _Z11dot_acc_satiiii
// dot_acc_sat(int, uint, int, int) _Z11dot_acc_satijii
// dot_acc_sat(uint, int, int, int) _Z11dot_acc_satjiii
// dot_acc_sat(uint, uint, int, int) _Z11dot_acc_satjjii
assert(MangledName.startswith("_Z3dot") ||
MangledName.startswith("_Z11dot_acc_sat"));
IsFirstSigned = (IsDot) ? (MangledName[MangledName.size() - 3] == 'i')
: (MangledName[MangledName.size() - 4] == 'i');
IsSecondSigned = (IsDot) ? (MangledName[MangledName.size() - 2] == 'i')
: (MangledName[MangledName.size() - 3] == 'i');
// dot_acc_sat_4x8packed_ss_int(uint, uint, int)
// _Z28dot_acc_sat_4x8packed_ss_intjji
// dot_acc_sat_4x8packed_su_int(uint, uint, int)
// _Z28dot_acc_sat_4x8packed_su_intjji
// dot_acc_sat_4x8packed_us_int(uint, uint, int)
// _Z28dot_acc_sat_4x8packed_us_intjji
// dot_acc_sat_4x8packed_uu_uint(uint, uint, uint)
// _Z29dot_acc_sat_4x8packed_uu_uintjjj
assert(MangledName.startswith("_Z20dot_4x8packed") ||
MangledName.startswith("_Z21dot_4x8packed") ||
MangledName.startswith("_Z28dot_acc_sat_4x8packed") ||
MangledName.startswith("_Z29dot_acc_sat_4x8packed"));
size_t SignIndex = IsAccSat
? strlen(kOCLBuiltinName::DotAccSat4x8PackedPrefix)
: strlen(kOCLBuiltinName::Dot4x8PackedPrefix);
IsFirstSigned = DemangledName[SignIndex] == 's';
IsSecondSigned = DemangledName[SignIndex + 1] == 's';
}
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstSPIRV(
Expand All @@ -1578,7 +1579,7 @@ void OCL20ToSPIRV::visitCallDot(CallInst *CI, StringRef MangledName,
std::swap(Args[0], Args[1]);
}
Op OC;
if (IsDot) {
if (!IsAccSat) {
OC = (IsFirstSigned != IsSecondSigned
? OpSUDotKHR
: ((IsFirstSigned) ? OpSDotKHR : OpUDotKHR));
Expand All @@ -1587,6 +1588,14 @@ void OCL20ToSPIRV::visitCallDot(CallInst *CI, StringRef MangledName,
? OpSUDotAccSatKHR
: ((IsFirstSigned) ? OpSDotAccSatKHR : OpUDotAccSatKHR));
}
if (IsPacked) {
// As per SPIRV specification the dot OpCodes
// which use scalar integers to represent
// packed vectors need additional argument
// specified - the Packed Vector Format
Args.push_back(
getInt32(M, PackedVectorFormatPackedVectorFormat4x8BitKHR));
}
return getSPIRVFuncName(OC);
},
&Attrs);
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/OCLUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ const static char Clamp[] = "clamp";
const static char ConvertPrefix[] = "convert_";
// Exact demangled names of the "dot" and "dot_acc_sat" builtins; the
// translator compares a call's demangled name against these for equality.
const static char Dot[] = "dot";
const static char DotAccSat[] = "dot_acc_sat";
// Name prefixes of the packed 4x8 dot product builtins, matched at position 0
// of the demangled name. The two characters directly after the prefix encode
// operand signedness: the first for operand 0, the second for operand 1, with
// 's' meaning signed (e.g. "dot_4x8packed_su_int",
// "dot_acc_sat_4x8packed_uu_uint").
const static char Dot4x8PackedPrefix[] = "dot_4x8packed_";
const static char DotAccSat4x8PackedPrefix[] = "dot_acc_sat_4x8packed_";
const static char EnqueueKernel[] = "enqueue_kernel";
const static char FMax[] = "fmax";
const static char FMin[] = "fmin";
Expand Down
32 changes: 16 additions & 16 deletions test/transcoding/SPV_KHR_integer_dot_product_OCLtoSPIRV_int.ll
Original file line number Diff line number Diff line change
Expand Up @@ -31,40 +31,40 @@ target triple = "spir"
; Kernel that calls the integer dot product builtins on scalar i32 operands.
; Every call result is left unused and the kernel returns void.
;
; Two call groups are visible:
;   * _Z3dot* / _Z11dot_acc_sat* overloads, which take a trailing constant
;     `i32 0` argument;
;   * _Z20/_Z21dot_4x8packed_* and _Z28/_Z29dot_acc_sat_4x8packed_* builtins,
;     whose operand signedness (ss, su, us, uu) is part of the function name
;     and which take no trailing constant argument.
;
; NOTE(review): %call..%call7 are each defined twice below, which is not valid
; SSA. This looks like a rendered diff showing the replaced calls followed by
; their replacements (the file summary reports 16 additions and 16 deletions)
; -- confirm against the real test file before relying on this text.
; Function Attrs: convergent norecurse nounwind
define spir_kernel void @test1(i32 %ia, i32 %ua, i32 %ib, i32 %ub, i32 %ires, i32 %ures) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !5 !kernel_arg_type_qual !6 {
entry:
; dot / dot_acc_sat overloads with the extra trailing `i32 0` operand.
%call = tail call spir_func i32 @_Z3dotiii(i32 %ia, i32 %ib, i32 0) #2
%call1 = tail call spir_func i32 @_Z3dotiji(i32 %ia, i32 %ub, i32 0) #2
%call2 = tail call spir_func i32 @_Z3dotjii(i32 %ua, i32 %ib, i32 0) #2
%call3 = tail call spir_func i32 @_Z3dotjji(i32 %ua, i32 %ub, i32 0) #2
%call4 = tail call spir_func i32 @_Z11dot_acc_satiiii(i32 %ia, i32 %ib, i32 %ires, i32 0) #2
%call5 = tail call spir_func i32 @_Z11dot_acc_satijii(i32 %ia, i32 %ub, i32 %ires, i32 0) #2
%call6 = tail call spir_func i32 @_Z11dot_acc_satjiii(i32 %ua, i32 %ib, i32 %ires, i32 0) #2
%call7 = tail call spir_func i32 @_Z11dot_acc_satjjji(i32 %ua, i32 %ub, i32 %ures, i32 0) #2
; dot_4x8packed_* / dot_acc_sat_4x8packed_* builtins: two packed i32 operands,
; plus an accumulator operand for the dot_acc_sat forms.
%call = tail call spir_func i32 @_Z20dot_4x8packed_ss_intjj(i32 %ia, i32 %ib) #2
%call1 = tail call spir_func i32 @_Z20dot_4x8packed_su_intjj(i32 %ia, i32 %ub) #2
%call2 = tail call spir_func i32 @_Z20dot_4x8packed_us_intjj(i32 %ua, i32 %ib) #2
%call3 = tail call spir_func i32 @_Z21dot_4x8packed_uu_uintjj(i32 %ua, i32 %ub) #2
%call4 = tail call spir_func i32 @_Z28dot_acc_sat_4x8packed_ss_intjji(i32 %ia, i32 %ib, i32 %ires) #2
%call5 = tail call spir_func i32 @_Z28dot_acc_sat_4x8packed_su_intjji(i32 %ia, i32 %ub, i32 %ires) #2
%call6 = tail call spir_func i32 @_Z28dot_acc_sat_4x8packed_us_intjji(i32 %ua, i32 %ib, i32 %ires) #2
%call7 = tail call spir_func i32 @_Z29dot_acc_sat_4x8packed_uu_uintjjj(i32 %ua, i32 %ub, i32 %ures) #2
ret void
}

; External declarations of the dot product builtins called from @test1. All of
; them take and return i32 values and carry attribute group #1.
;
; NOTE(review): each pair below appears to be a replaced declaration followed by
; its replacement, as rendered by a diff view -- confirm which ones the real
; test file actually contains.
; Function Attrs: convergent
declare spir_func i32 @_Z3dotiii(i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z20dot_4x8packed_ss_intjj(i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotiji(i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z20dot_4x8packed_su_intjj(i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotjii(i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z20dot_4x8packed_us_intjj(i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotjji(i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z21dot_4x8packed_uu_uintjj(i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satiiii(i32, i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z28dot_acc_sat_4x8packed_ss_intjji(i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satijii(i32, i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z28dot_acc_sat_4x8packed_su_intjji(i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satjiii(i32, i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z28dot_acc_sat_4x8packed_us_intjji(i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satjjji(i32, i32, i32, i32) local_unnamed_addr #1
declare spir_func i32 @_Z29dot_acc_sat_4x8packed_uu_uintjjj(i32, i32, i32) local_unnamed_addr #1

attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
Expand Down