Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 86 additions & 27 deletions include/hip/hcc_detail/hip_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ THE SOFTWARE.
//#include <cstring>
#if __cplusplus
#include <cmath>
#include <cstdint>
#else
#include <math.h>
#include <string.h>
Expand Down Expand Up @@ -198,35 +199,93 @@ __device__ int __hip_move_dpp_N(int src);

#if defined __HCC__

template <
typename std::common_type<decltype(hc_get_group_id), decltype(hc_get_group_size),
decltype(hc_get_num_groups), decltype(hc_get_workitem_id)>::type f>
class Coordinates {
using R = decltype(f(0));

struct X {
__device__ operator R() const { return f(0); }
__device__ uint32_t operator=(R _) { return f(0); }
};
struct Y {
__device__ operator R() const { return f(1); }
__device__ uint32_t operator=(R _) { return f(1); }
};
struct Z {
__device__ operator R() const { return f(2); }
__device__ uint32_t operator=(R _) { return f(2); }
};

public:
static constexpr X x{};
static constexpr Y y{};
static constexpr Z z{};
namespace hip_impl {
struct GroupId {
using R = decltype(hc_get_group_id(0));

__device__
R operator()(std::uint32_t x) const noexcept { return hc_get_group_id(x); }
};
struct GroupSize {
using R = decltype(hc_get_group_size(0));

__device__
R operator()(std::uint32_t x) const noexcept {
return hc_get_group_size(x);
}
};
struct NumGroups {
using R = decltype(hc_get_num_groups(0));

__device__
R operator()(std::uint32_t x) const noexcept {
return hc_get_num_groups(x);
}
};
struct WorkitemId {
using R = decltype(hc_get_workitem_id(0));

__device__
R operator()(std::uint32_t x) const noexcept {
return hc_get_workitem_id(x);
}
};
} // Namespace hip_impl.

template <typename F>
struct Coordinates {
using R = decltype(F{}(0));

struct X { __device__ operator R() const noexcept { return F{}(0); } };
struct Y { __device__ operator R() const noexcept { return F{}(1); } };
struct Z { __device__ operator R() const noexcept { return F{}(2); } };

static constexpr X x{};
static constexpr Y y{};
static constexpr Z z{};
};

static constexpr Coordinates<hc_get_group_size> blockDim;
static constexpr Coordinates<hc_get_group_id> blockIdx;
static constexpr Coordinates<hc_get_num_groups> gridDim;
static constexpr Coordinates<hc_get_workitem_id> threadIdx;
inline
__device__
std::uint32_t operator*(Coordinates<hip_impl::NumGroups>::X,
Coordinates<hip_impl::GroupSize>::X) noexcept {
return hc_get_grid_size(0);
}
inline
__device__
std::uint32_t operator*(Coordinates<hip_impl::GroupSize>::X,
Coordinates<hip_impl::NumGroups>::X) noexcept {
return hc_get_grid_size(0);
}
inline
__device__
std::uint32_t operator*(Coordinates<hip_impl::NumGroups>::Y,
Coordinates<hip_impl::GroupSize>::Y) noexcept {
return hc_get_grid_size(1);
}
inline
__device__
std::uint32_t operator*(Coordinates<hip_impl::GroupSize>::Y,
Coordinates<hip_impl::NumGroups>::Y) noexcept {
return hc_get_grid_size(1);
}
inline
__device__
std::uint32_t operator*(Coordinates<hip_impl::NumGroups>::Z,
Coordinates<hip_impl::GroupSize>::Z) noexcept {
return hc_get_grid_size(2);
}
inline
__device__
std::uint32_t operator*(Coordinates<hip_impl::GroupSize>::Z,
Coordinates<hip_impl::NumGroups>::Z) noexcept {
return hc_get_grid_size(2);
}

static constexpr Coordinates<hip_impl::GroupSize> blockDim{};
static constexpr Coordinates<hip_impl::GroupId> blockIdx{};
static constexpr Coordinates<hip_impl::NumGroups> gridDim{};
static constexpr Coordinates<hip_impl::WorkitemId> threadIdx{};

#define hipThreadIdx_x (hc_get_workitem_id(0))
#define hipThreadIdx_y (hc_get_workitem_id(1))
Expand Down