Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion lib/SPIRV/OCLToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ class OCLToSPIRVBase : public InstVisitor<OCLToSPIRVBase> {
/// Transforms OpDot instructions with a scalar type to a fmul instruction
void visitCallDot(CallInst *CI);

/// Transforms OpDot instructions with a vector or scalar (packed vector) type
/// to dot or dot_acc_sat instructions
void visitCallDot(CallInst *CI, StringRef MangledName,
StringRef DemangledName);

/// Fixes for built-in functions with vector+scalar arguments that are
/// translated to the SPIR-V instructions where all arguments must have the
/// same type.
Expand Down Expand Up @@ -521,10 +526,23 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) {
return;
}
if (DemangledName == kOCLBuiltinName::Dot &&
!(CI.getOperand(0)->getType()->isVectorTy())) {
(CI.getOperand(0)->getType()->isFloatTy() ||
CI.getOperand(1)->getType()->isDoubleTy())) {
visitCallDot(&CI);
return;
}
if (DemangledName == kOCLBuiltinName::Dot ||
DemangledName == kOCLBuiltinName::DotAccSat) {
if (CI.getOperand(0)->getType()->isVectorTy()) {
auto *VT = (VectorType *)(CI.getOperand(0)->getType());
if (!isa<llvm::IntegerType>(VT->getElementType())) {
visitCallBuiltinSimple(&CI, MangledName, DemangledName);
return;
}
}
visitCallDot(&CI, MangledName, DemangledName);
return;
}
if (DemangledName == kOCLBuiltinName::FMin ||
DemangledName == kOCLBuiltinName::FMax ||
DemangledName == kOCLBuiltinName::Min ||
Expand Down Expand Up @@ -1480,6 +1498,105 @@ void OCLToSPIRVBase::visitCallDot(CallInst *CI) {
CI->eraseFromParent();
}

void OCLToSPIRVBase::visitCallDot(CallInst *CI, StringRef MangledName,
StringRef DemangledName) {
// translation for dot function calls,
// to differentiate between integer dot products

SmallVector<Value *, 3> Args;
Args.push_back(CI->getOperand(0));
Args.push_back(CI->getOperand(1));
bool IsFirstSigned, IsSecondSigned;
bool IsDot = DemangledName == kOCLBuiltinName::Dot;
std::string FunName = (IsDot) ? "DotKHR" : "DotAccSatKHR";
if (CI->getNumArgOperands() > 2) {
Args.push_back(CI->getOperand(2));
}
if (CI->getNumArgOperands() > 3) {
Args.push_back(CI->getOperand(3));
}
if (CI->getOperand(0)->getType()->isVectorTy()) {
if (IsDot) {
// dot(char4, char4) _Z3dotDv4_cS_
// dot(char4, uchar4) _Z3dotDv4_cDv4_h
// dot(uchar4, char4) _Z3dotDv4_hDv4_c
// dot(uchar4, uchar4) _Z3dotDv4_hS_
// or
// dot(short2, short2) _Z3dotDv2_sS_
// dot(short2, ushort2) _Z3dotDv2_sDv2_t
// dot(ushort2, short2) _Z3dotDv2_tDv2_s
// dot(ushort2, ushort2) _Z3dotDv2_tS_
assert(MangledName.startswith("_Z3dotDv"));
if (MangledName[MangledName.size() - 1] == '_') {
IsFirstSigned = ((MangledName[MangledName.size() - 3] == 'c') ||
(MangledName[MangledName.size() - 3] == 's'));
IsSecondSigned = IsFirstSigned;
} else {
IsFirstSigned = ((MangledName[MangledName.size() - 6] == 'c') ||
(MangledName[MangledName.size() - 6] == 's'));
IsSecondSigned = ((MangledName[MangledName.size() - 1] == 'c') ||
(MangledName[MangledName.size() - 1] == 's'));
}
} else {
// dot_acc_sat(char4, char4, int) _Z11dot_acc_satDv4_cS_i
// dot_acc_sat(char4, uchar4, int) _Z11dot_acc_satDv4_cDv4_hi
// dot_acc_sat(uchar4, char4, int) _Z11dot_acc_satDv4_hDv4_ci
// dot_acc_sat(uchar4, uchar4, uint) _Z11dot_acc_satDv4_hS_j
// or
// dot_acc_sat(short2, short2, int) _Z11dot_acc_satDv4_sS_i
// dot_acc_sat(short2, ushort2, int) _Z11dot_acc_satDv4_sDv4_ti
// dot_acc_sat(ushort2, short2, int) _Z11dot_acc_satDv4_tDv4_si
// dot_acc_sat(ushort2, ushort2, uint) _Z11dot_acc_satDv4_tS_j
assert(MangledName.startswith("_Z11dot_acc_satDv"));
IsFirstSigned = ((MangledName[19] == 'c') || (MangledName[19] == 's'));
IsSecondSigned = (MangledName[20] == 'S'
? IsFirstSigned
: ((MangledName[MangledName.size() - 2] == 'c') ||
(MangledName[MangledName.size() - 2] == 's')));
}
} else {
// for packed format
// dot(int, int, int) _Z3dotiii
// dot(int, uint, int) _Z3dotiji
// dot(uint, int, int) _Z3dotjii
// dot(uint, uint, int) _Z3dotjji
// or
// dot_acc_sat(int, int, int, int) _Z11dot_acc_satiiii
// dot_acc_sat(int, uint, int, int) _Z11dot_acc_satijii
// dot_acc_sat(uint, int, int, int) _Z11dot_acc_satjiii
// dot_acc_sat(uint, uint, int, int) _Z11dot_acc_satjjii
assert(MangledName.startswith("_Z3dot") ||
MangledName.startswith("_Z11dot_acc_sat"));
IsFirstSigned = (IsDot) ? (MangledName[MangledName.size() - 3] == 'i')
: (MangledName[MangledName.size() - 4] == 'i');
IsSecondSigned = (IsDot) ? (MangledName[MangledName.size() - 2] == 'i')
: (MangledName[MangledName.size() - 3] == 'i');
}
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstSPIRV(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
// If arguments are in order unsigned -> signed
// then the translator should swap them,
// so that the OpSUDotKHR can be used properly
if (IsFirstSigned == false && IsSecondSigned == true) {
std::swap(Args[0], Args[1]);
}
Op OC;
if (IsDot) {
OC = (IsFirstSigned != IsSecondSigned
? OpSUDotKHR
: ((IsFirstSigned) ? OpSDotKHR : OpUDotKHR));
} else {
OC = (IsFirstSigned != IsSecondSigned
? OpSUDotAccSatKHR
: ((IsFirstSigned) ? OpSDotAccSatKHR : OpUDotAccSatKHR));
}
return getSPIRVFuncName(OC);
},
&Attrs);
}

void OCLToSPIRVBase::visitCallScalToVec(CallInst *CI, StringRef MangledName,
StringRef DemangledName) {
// Check if all arguments have the same type - it's simple case.
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/OCLUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ const static char Barrier[] = "barrier";
const static char Clamp[] = "clamp";
const static char ConvertPrefix[] = "convert_";
const static char Dot[] = "dot";
const static char DotAccSat[] = "dot_acc_sat";
const static char EnqueueKernel[] = "enqueue_kernel";
const static char FixedSqrtINTEL[] = "intel_arbitrary_fixed_sqrt";
const static char FixedRecipINTEL[] = "intel_arbitrary_fixed_recip";
Expand Down
85 changes: 85 additions & 0 deletions test/transcoding/SPV_KHR_integer_dot_product_OCLtoSPIRV_char4.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv -s %t.bc -o %t.regularized.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_integer_dot_product -o %t-spirv.spv
; RUN: spirv-val %t-spirv.spv
; RUN: llvm-dis %t.regularized.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
; RUN: llvm-spirv %t.bc -spirv-text --spirv-ext=+SPV_KHR_integer_dot_product -o - | FileCheck %s --check-prefix=CHECK-SPIRV

;CHECK-LLVM: call spir_func i32 @_Z15__spirv_SDotKHR
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
;CHECK-LLVM: call spir_func i32 @_Z15__spirv_UDotKHR

;CHECK-LLVM: call spir_func i32 @_Z21__spirv_SDotAccSatKHR
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
;CHECK-LLVM: call spir_func i32 @_Z21__spirv_UDotAccSatKHR

;CHECK-SPIRV: SDotKHR
;CHECK-SPIRV: SUDotKHR
;CHECK-SPIRV: SUDotKHR
;CHECK-SPIRV: UDotKHR

;CHECK-SPIRV: SDotAccSatKHR
;CHECK-SPIRV: SUDotAccSatKHR
;CHECK-SPIRV: SUDotAccSatKHR
;CHECK-SPIRV: UDotAccSatKHR

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent norecurse nounwind
define spir_kernel void @test1(<4 x i8> %ia, <4 x i8> %ua, <4 x i8> %ib, <4 x i8> %ub, <4 x i8> %ires, <4 x i8> %ures) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !6 !kernel_arg_type_qual !7 {
entry:
%call = tail call spir_func i32 @_Z3dotDv4_cS_(<4 x i8> %ia, <4 x i8> %ib) #2
%call1 = tail call spir_func i32 @_Z3dotDv4_cDv4_h(<4 x i8> %ia, <4 x i8> %ub) #2
%call2 = tail call spir_func i32 @_Z3dotDv4_hDv4_c(<4 x i8> %ua, <4 x i8> %ib) #2
%call3 = tail call spir_func i32 @_Z3dotDv4_hS_(<4 x i8> %ua, <4 x i8> %ub) #2
%call4 = tail call spir_func i32 @_Z11dot_acc_satDv4_cS_i(<4 x i8> %ia, <4 x i8> %ib, i32 %call2) #2
%call5 = tail call spir_func i32 @_Z11dot_acc_satDv4_cDv4_hi(<4 x i8> %ia, <4 x i8> %ub, i32 %call4) #2
%call6 = tail call spir_func i32 @_Z11dot_acc_satDv4_hDv4_ci(<4 x i8> %ua, <4 x i8> %ib, i32 %call5) #2
%call7 = tail call spir_func i32 @_Z11dot_acc_satDv4_hS_j(<4 x i8> %ua, <4 x i8> %ub, i32 %call3) #2
ret void
}

; Function Attrs: convergent
declare spir_func i32 @_Z3dotDv4_cS_(<4 x i8>, <4 x i8>) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotDv4_cDv4_h(<4 x i8>, <4 x i8>) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotDv4_hDv4_c(<4 x i8>, <4 x i8>) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotDv4_hS_(<4 x i8>, <4 x i8>) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satDv4_cS_i(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satDv4_cDv4_hi(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satDv4_hDv4_ci(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satDv4_hS_j(<4 x i8>, <4 x i8>, i32) local_unnamed_addr #1

attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="128" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pocharer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { convergent nounwind }

!llvm.module.flags = !{!0}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!1}
!llvm.ident = !{!2}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 2, i32 0}
!2 = !{!"clang version 11.0.0 (https://github.com/c199914007/llvm.git 8b94769313ca84cb9370b60ed008501ff692cb71)"}
!3 = !{i32 0, i32 0, i32 0, i32 0, i32 0, i32 0}
!4 = !{!"none", !"none", !"none", !"none", !"none", !"none"}
!5 = !{!"char4", !"uchar4", !"char4", !"uchar4", !"char4", !"uchar4"}
!6 = !{!"char __attribute__((ext_vector_type(4)))", !"uchar __attribute__((ext_vector_type(4)))", !"char __attribute__((ext_vector_type(4)))", !"uchar __attribute__((ext_vector_type(4)))", !"char __attribute__((ext_vector_type(4)))", !"uchar __attribute__((ext_vector_type(4)))"}
!7 = !{!"", !"", !"", !"", !"", !""}
84 changes: 84 additions & 0 deletions test/transcoding/SPV_KHR_integer_dot_product_OCLtoSPIRV_int.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv -s %t.bc -o %t.regularized.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_integer_dot_product -o %t-spirv.spv
; RUN: spirv-val %t-spirv.spv
; RUN: llvm-dis %t.regularized.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
; RUN: llvm-spirv %t.bc -spirv-text --spirv-ext=+SPV_KHR_integer_dot_product -o - | FileCheck %s --check-prefix=CHECK-SPIRV

;CHECK-LLVM: call spir_func i32 @_Z15__spirv_SDotKHR
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
;CHECK-LLVM: call spir_func i32 @_Z16__spirv_SUDotKHR
;CHECK-LLVM: call spir_func i32 @_Z15__spirv_UDotKHR

;CHECK-LLVM: call spir_func i32 @_Z21__spirv_SDotAccSatKHR
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
;CHECK-LLVM: call spir_func i32 @_Z22__spirv_SUDotAccSatKHR
;CHECK-LLVM: call spir_func i32 @_Z21__spirv_UDotAccSatKHR

;CHECK-SPIRV: SDotKHR
;CHECK-SPIRV: SUDotKHR
;CHECK-SPIRV: SUDotKHR
;CHECK-SPIRV: UDotKHR

;CHECK-SPIRV: SDotAccSatKHR
;CHECK-SPIRV: SUDotAccSatKHR
;CHECK-SPIRV: SUDotAccSatKHR
;CHECK-SPIRV: UDotAccSatKHR

target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent norecurse nounwind
define spir_kernel void @test1(i32 %ia, i32 %ua, i32 %ib, i32 %ub, i32 %ires, i32 %ures) local_unnamed_addr #0 !kernel_arg_addr_space !3 !kernel_arg_access_qual !4 !kernel_arg_type !5 !kernel_arg_base_type !5 !kernel_arg_type_qual !6 {
entry:
%call = tail call spir_func i32 @_Z3dotiii(i32 %ia, i32 %ib, i32 0) #2
%call1 = tail call spir_func i32 @_Z3dotiji(i32 %ia, i32 %ub, i32 0) #2
%call2 = tail call spir_func i32 @_Z3dotjii(i32 %ua, i32 %ib, i32 0) #2
%call3 = tail call spir_func i32 @_Z3dotjji(i32 %ua, i32 %ub, i32 0) #2
%call4 = tail call spir_func i32 @_Z11dot_acc_satiiii(i32 %ia, i32 %ib, i32 %ires, i32 0) #2
%call5 = tail call spir_func i32 @_Z11dot_acc_satijii(i32 %ia, i32 %ub, i32 %ires, i32 0) #2
%call6 = tail call spir_func i32 @_Z11dot_acc_satjiii(i32 %ua, i32 %ib, i32 %ires, i32 0) #2
%call7 = tail call spir_func i32 @_Z11dot_acc_satjjji(i32 %ua, i32 %ub, i32 %ures, i32 0) #2
ret void
}

; Function Attrs: convergent
declare spir_func i32 @_Z3dotiii(i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotiji(i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotjii(i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z3dotjji(i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satiiii(i32, i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satijii(i32, i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satjiii(i32, i32, i32, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare spir_func i32 @_Z11dot_acc_satjjji(i32, i32, i32, i32) local_unnamed_addr #1

attributes #0 = { convergent norecurse nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { convergent "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #2 = { convergent nounwind }

!llvm.module.flags = !{!0}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!1}
!llvm.ident = !{!2}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 2, i32 0}
!2 = !{!"clang version 11.0.0 (https://github.com/c199914007/llvm.git f2b7028a3598d4d88ddf1f76b50946da4e135845)"}
!3 = !{i32 0, i32 0, i32 0, i32 0, i32 0, i32 0}
!4 = !{!"none", !"none", !"none", !"none", !"none", !"none"}
!5 = !{!"int", !"uint", !"int", !"uint", !"int", !"uint"}
!6 = !{!"", !"", !"", !"", !"", !""}
Loading