diff --git a/sycl/include/CL/sycl/detail/stl_type_traits.hpp b/sycl/include/CL/sycl/detail/stl_type_traits.hpp index 03403b7ba063e..82b8509c6a0e4 100644 --- a/sycl/include/CL/sycl/detail/stl_type_traits.hpp +++ b/sycl/include/CL/sycl/detail/stl_type_traits.hpp @@ -78,6 +78,12 @@ struct is_output_iterator> { static constexpr bool value = true; }; +template +inline constexpr bool is_same_v = std::is_same::value; + +template +inline constexpr bool is_convertible_v = std::is_convertible::value; + } // namespace detail } // namespace sycl } // __SYCL_INLINE_NAMESPACE(cl) diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index bc2be4de427b7..bf9ca91dae24a 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -917,11 +917,14 @@ class __SYCL_EXPORT handler { } template struct TransformUserItemType { - using type = typename std::conditional< - std::is_convertible, LambdaArgType>::value, nd_item, - typename std::conditional< - std::is_convertible, LambdaArgType>::value, item, - LambdaArgType>::type>::type; + using type = typename std::conditional_t< + detail::is_same_v, LambdaArgType>, LambdaArgType, + typename std::conditional_t< + detail::is_convertible_v, LambdaArgType>, + nd_item, + typename std::conditional_t< + detail::is_convertible_v, LambdaArgType>, item, + LambdaArgType>>>; }; /// Defines and invokes a SYCL kernel function for the specified range. diff --git a/sycl/test/basic_tests/parallel_for_type_check.cpp b/sycl/test/basic_tests/parallel_for_type_check.cpp new file mode 100644 index 0000000000000..984657c430e4c --- /dev/null +++ b/sycl/test/basic_tests/parallel_for_type_check.cpp @@ -0,0 +1,30 @@ +// RUN: %clangxx -fsycl -fsycl-device-only -D__SYCL_INTERNAL_API -O0 -c -emit-llvm -S -o - %s | FileCheck %s + +// This test performs basic type check for sycl::id that is used in result type. + +#include +#include + +int main() { + sycl::queue q; + + // Initialize data array + const int sz = 16; + int data[sz] = {0}; + for (int i = 0; i < sz; ++i) { + data[i] = i; + } + + // Check user defined sycl::item wrapper + sycl::buffer data_buf(data, sz); + q.submit([&](sycl::handler &h) { + auto buf_acc = data_buf.get_access(h); + h.parallel_for( + sycl::range<1>{sz}, + // CHECK: cl{{.*}}sycl{{.*}}detail{{.*}}RoundedRangeKernel{{.*}}id{{.*}}main{{.*}}handler + [=](sycl::id<1> item) { buf_acc[item] += 1; }); + }); + q.wait(); + + return 0; +}