Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sycl/include/sycl/detail/nd_range_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class nd_range_view {
nd_range_view &operator=(const nd_range_view &Desc) = default;
nd_range_view &operator=(nd_range_view &&Desc) = default;

/// Builds a view over a plain sycl::range. The view records the address of
/// the range's first extent and the number of dimensions; nothing is copied.
/// NOTE(review): since only a pointer into \p N is kept, \p N presumably has
/// to outlive the view - confirm against the users of MGlobalSize.
template <int Dims_>
nd_range_view(sycl::range<Dims_> &N)
    : MGlobalSize(&N[0]), MDims(static_cast<size_t>(Dims_)) {}

template <int Dims_>
nd_range_view(sycl::nd_range<Dims_> &ExecutionRange)
: MGlobalSize(&(ExecutionRange.globalSize[0])),
Expand Down
143 changes: 18 additions & 125 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <sycl/nd_range.hpp>
#include <sycl/property_list.hpp>
#include <sycl/range.hpp>
#include <sycl/range_rounding.hpp>
#include <sycl/sampler.hpp>

#include <assert.h>
Expand Down Expand Up @@ -262,106 +263,6 @@ __SYCL_EXPORT void *getValueFromDynamicParameter(
ext::oneapi::experimental::detail::dynamic_parameter_base
&DynamicParamBase);

template <int Dims> class RoundedRangeIDGenerator {
id<Dims> Id;
id<Dims> InitId;
range<Dims> UserRange;
range<Dims> RoundedRange;
bool Done = false;

public:
RoundedRangeIDGenerator(const id<Dims> &Id, const range<Dims> &UserRange,
const range<Dims> &RoundedRange)
: Id(Id), InitId(Id), UserRange(UserRange), RoundedRange(RoundedRange) {
for (int i = 0; i < Dims; ++i)
if (Id[i] >= UserRange[i])
Done = true;
}

explicit operator bool() { return !Done; }

void updateId() {
for (int i = 0; i < Dims; ++i) {
Id[i] += RoundedRange[i];
if (Id[i] < UserRange[i])
return;
Id[i] = InitId[i];
}
Done = true;
}

id<Dims> getId() { return Id; }

template <typename KernelType> auto getItem() {
if constexpr (std::is_invocable_v<KernelType, item<Dims> &> ||
std::is_invocable_v<KernelType, item<Dims> &, kernel_handler>)
return detail::Builder::createItem<Dims, true>(UserRange, getId(), {});
else {
static_assert(std::is_invocable_v<KernelType, item<Dims, false> &> ||
std::is_invocable_v<KernelType, item<Dims, false> &,
kernel_handler>,
"Kernel must be invocable with an item!");
return detail::Builder::createItem<Dims, false>(UserRange, getId());
}
}
};

// TODO: The wrappers can be optimized further so that the body
// essentially looks like this:
// for (auto z = it[2]; z < UserRange[2]; z += it.get_range(2))
// for (auto y = it[1]; y < UserRange[1]; y += it.get_range(1))
// for (auto x = it[0]; x < UserRange[0]; x += it.get_range(0))
// KernelFunc({x,y,z});
/// Kernel wrapper that is launched over a rounded range and forwards to the
/// user kernel only for ids that lie inside the original user range.
/// The two data members are public and are filled by aggregate
/// initialization, so their order must stay as is.
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernel {
public:
  range<Dims> UserRange;
  KernelType KernelFunc;

  /// Calls KernelFunc once for every id produced by RoundedRangeIDGenerator,
  /// starting from the id of \p It and stepping by the range \p It reports.
  void operator()(item<Dims> It) const {
    RoundedRangeIDGenerator Gen(It.get_id(), UserRange, It.get_range());
    while (Gen) {
      // Kept as a named local so the kernel receives an lvalue item.
      auto UserItem = Gen.template getItem<KernelType>();
      KernelFunc(UserItem);
      Gen.updateId();
    }
  }

  // Mirror the properties_tag getter of the wrapped kernel so that its
  // property(s) are propagated through the wrapper. Only available when the
  // wrapped kernel type provides such a getter.
  template <
      typename T = KernelType,
      typename = std::enable_if_t<ext::oneapi::experimental::detail::
                                      HasKernelPropertiesGetMethod<T>::value>>
  auto get(ext::oneapi::experimental::properties_tag) const {
    using PropertiesTag = ext::oneapi::experimental::properties_tag;
    return KernelFunc.get(PropertiesTag{});
  }
};

/// Kernel wrapper for a rounded range whose user kernel additionally takes a
/// kernel_handler. It forwards to the user kernel only for ids that lie
/// inside the original user range.
/// The two data members are public and are filled by aggregate
/// initialization, so their order must stay as is.
template <typename TransformedArgType, int Dims, typename KernelType>
class RoundedRangeKernelWithKH {
public:
  range<Dims> UserRange;
  KernelType KernelFunc;

  /// Calls KernelFunc with \p KH once for every id produced by
  /// RoundedRangeIDGenerator, starting from the id of \p It and stepping by
  /// the range \p It reports.
  void operator()(item<Dims> It, kernel_handler KH) const {
    RoundedRangeIDGenerator Gen(It.get_id(), UserRange, It.get_range());
    while (Gen) {
      // Kept as a named local so the kernel receives an lvalue item.
      auto UserItem = Gen.template getItem<KernelType>();
      KernelFunc(UserItem, KH);
      Gen.updateId();
    }
  }

  // Mirror the properties_tag getter of the wrapped kernel so that its
  // property(s) are propagated through the wrapper. Only available when the
  // wrapped kernel type provides such a getter.
  template <
      typename T = KernelType,
      typename = std::enable_if_t<ext::oneapi::experimental::detail::
                                      HasKernelPropertiesGetMethod<T>::value>>
  auto get(ext::oneapi::experimental::properties_tag) const {
    using PropertiesTag = ext::oneapi::experimental::properties_tag;
    return KernelFunc.get(PropertiesTag{});
  }
};

using std::enable_if_t;
using sycl::detail::queue_impl;

Expand All @@ -384,6 +285,13 @@ template <int Dims> bool range_size_fits_in_size_t(const range<Dims> &r) {
return true;
}

/// Selects the type used to transport item information to a user kernel whose
/// argument type is \p LambdaArgType:
///  - nd_item<Dims> if an nd_item<Dims> converts to \p LambdaArgType,
///  - otherwise item<Dims> if an item<Dims> converts to \p LambdaArgType,
///  - otherwise \p LambdaArgType itself.
template <int Dims, typename LambdaArgType> struct TransformUserItemType {
  using type = std::conditional_t<
      !std::is_convertible_v<nd_item<Dims>, LambdaArgType>,
      std::conditional_t<!std::is_convertible_v<item<Dims>, LambdaArgType>,
                         LambdaArgType, item<Dims>>,
      nd_item<Dims>>;
};
Comment on lines +288 to +293
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// Compile-time type lookup over a flat list of (from, to) pairs.
// Example usage:
//   using mapped = typename map_type<type_to_map, from0, /*->*/ to0,
//                                                 from1, /*->*/ to1,
//                                                 ...>::type;
// The first pair whose `from` is the same type as `type_to_map` wins.

// Fallback: no (from, to) pair is left to try, so the result is void.
template <typename... Ts> struct map_type {
  using type = void;
};

// Recursive case: test the leading (Src, Dst) pair and otherwise continue
// with the remaining pairs.
template <typename Key, typename Src, typename Dst, typename... Tail>
struct map_type<Key, Src, Dst, Tail...> {
  using type =
      std::conditional_t<!std::is_same_v<Src, Key>,
                         typename map_type<Key, Tail...>::type, Dst>;
};
might be helpful.


} // namespace detail

/// Command group handler class.
Expand Down Expand Up @@ -1019,6 +927,9 @@ class __SYCL_EXPORT handler {

bool eventNeeded() const;

device get_device() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an extra std::shared_ptr copy. https://github.com/intel/llvm/pull/20698/files#r2561353949 is related. Can you collaborate with @lbushi25 for a better fix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point. I am planning to create a separate PR with a refactor to avoid this.


#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're in the ABI-breaking window, so just drop it.

template <int Dims, typename LambdaArgType> struct TransformUserItemType {
using type = std::conditional_t<
std::is_convertible_v<nd_item<Dims>, LambdaArgType>, nd_item<Dims>,
Expand Down Expand Up @@ -1154,6 +1065,7 @@ class __SYCL_EXPORT handler {
return {range<Dims>{}, false};
return {RoundedRange, true};
}
#endif

/// Defines and invokes a SYCL kernel function for the specified range.
///
Expand Down Expand Up @@ -1193,7 +1105,7 @@ class __SYCL_EXPORT handler {
// sycl::item/sycl::nd_item to transport item information
using TransformedArgType = std::conditional_t<
std::is_integral<LambdaArgType>::value && Dims == 1, item<Dims>,
typename TransformUserItemType<Dims, LambdaArgType>::type>;
typename detail::TransformUserItemType<Dims, LambdaArgType>::type>;

static_assert(!std::is_same_v<TransformedArgType, sycl::nd_item<Dims>>,
"Kernel argument cannot have a sycl::nd_item type in "
Expand All @@ -1216,11 +1128,12 @@ class __SYCL_EXPORT handler {
// Range rounding is supported only for newer SYCL standards.
#if !defined(__SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING__) && \
SYCL_LANGUAGE_VERSION >= 202012L
auto [RoundedRange, HasRoundedRange] = getRoundedRange(UserRange);
auto [RoundedRange, HasRoundedRange] =
detail::getRoundedRange(UserRange, get_device());
if (HasRoundedRange) {
using NameWT = typename detail::get_kernel_wrapper_name_t<NameT>::name;
auto Wrapper =
getRangeRoundedKernelLambda<NameWT, TransformedArgType, Dims>(
detail::getRangeRoundedKernelLambda<NameWT, TransformedArgType, Dims>(
KernelFunc, UserRange);

using KName = std::conditional_t<std::is_same<KernelType, NameT>::value,
Expand Down Expand Up @@ -1743,7 +1656,7 @@ class __SYCL_EXPORT handler {
using LambdaArgType = sycl::detail::lambda_arg_type<KernelType, item<Dims>>;
using TransformedArgType = std::conditional_t<
std::is_integral<LambdaArgType>::value && Dims == 1, item<Dims>,
typename TransformUserItemType<Dims, LambdaArgType>::type>;
typename detail::TransformUserItemType<Dims, LambdaArgType>::type>;
wrap_kernel<detail::WrapAs::parallel_for, KernelName, TransformedArgType,
Dims>(KernelFunc, {} /*Props*/, NumWorkItems, WorkItemOffset);
}
Expand Down Expand Up @@ -3260,34 +3173,14 @@ class __SYCL_EXPORT handler {
friend class ext::oneapi::experimental::detail::dynamic_parameter_impl;
friend class ext::oneapi::experimental::detail::dynamic_command_group_impl;

#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just drop it.

// NOTE(review): only the declaration is visible here. Judging by the name,
// this presumably reports whether range rounding is disabled - confirm
// against the out-of-line definition.
bool DisableRangeRounding();

// NOTE(review): presumably reports whether tracing of range rounding is
// enabled - confirm against the out-of-line definition.
bool RangeRoundingTrace();

// Takes three non-const size_t references, which presumably act as
// out-parameters receiving the range-rounding settings (minimum factor, good
// factor, minimum range) - confirm against the out-of-line definition.
void GetRangeRoundingSettings(size_t &MinFactor, size_t &GoodFactor,
size_t &MinRange);

template <typename WrapperT, typename TransformedArgType, int Dims,
typename KernelType,
std::enable_if_t<detail::KernelLambdaHasKernelHandlerArgT<
KernelType, TransformedArgType>::value> * = nullptr>
auto getRangeRoundedKernelLambda(KernelType KernelFunc,
range<Dims> UserRange) {
return detail::RoundedRangeKernelWithKH<TransformedArgType, Dims,
KernelType>{UserRange, KernelFunc};
}

template <typename WrapperT, typename TransformedArgType, int Dims,
typename KernelType,
std::enable_if_t<!detail::KernelLambdaHasKernelHandlerArgT<
KernelType, TransformedArgType>::value> * = nullptr>
auto getRangeRoundedKernelLambda(KernelType KernelFunc,
range<Dims> UserRange) {
return detail::RoundedRangeKernel<TransformedArgType, Dims, KernelType>{
UserRange, KernelFunc};
}

#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
const std::shared_ptr<detail::context_impl> &getContextImplPtr() const;
#endif
detail::context_impl &getContextImpl() const;
Expand Down
Loading