Skip to content

Commit

Permalink
fix sycl accessor (#2733)
Browse files Browse the repository at this point in the history
  • Loading branch information
guizili0 committed Jul 5, 2024
1 parent e0bf4a3 commit 1f84fc9
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions itex/core/kernels/gpu/stateful_random_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@ struct FillKernelTask {
FillKernelTask(sycl::local_accessor<char, 1> local_philox_acc,
StateElementType* state_data,
typename Distribution::ResultElementType* output_data,
int64_t output_size, int* item_count_ptr, Distribution dist)
int64_t output_size, sycl::accessor<int, 1> item_count_acc,
Distribution dist)
:

local_philox_acc(local_philox_acc),
state_data(state_data),
output_data(output_data),
output_size(output_size),
item_count_ptr(item_count_ptr),
item_count_acc(item_count_acc),
dist(dist) {}
void operator()(sycl::nd_item<1> myItem) const {
// Items in a group share `philox`. Item 0 is responsible for
Expand All @@ -68,6 +69,8 @@ struct FillKernelTask {
f(myItem);
// The last item updates the state.
auto total_item_count = myItem.get_global_range()[0];
auto item_count_ptr =
item_count_acc.template get_multi_ptr<sycl::access::decorated::no>();
auto atomic_val =
sycl::atomic_ref<int, sycl::memory_order::relaxed,
sycl::memory_scope::device,
Expand All @@ -84,7 +87,7 @@ struct FillKernelTask {
StateElementType* state_data;
typename Distribution::ResultElementType* output_data;
int64_t output_size;
int* item_count_ptr;
sycl::accessor<int, 1> item_count_acc;
Distribution dist;
};

Expand All @@ -106,16 +109,13 @@ void FillKernel(const GPUDevice& d, const int total_count, Distribution dist,
sycl::buffer<int, 1> item_count_buf{&item_count, 1};

stream->submit([&](sycl::handler& cgh) {
auto item_count_ptr =
item_count_buf
.get_access<sycl::access::mode::read_write,
sycl::access::target::device>(cgh)
.template get_multi_ptr<sycl::access::decorated::no>()
.get();
auto item_count_acc =
item_count_buf.get_access<sycl::access::mode::read_write,
sycl::access::target::device>(cgh);
sycl::local_accessor<char, 1> local_philox_acc(
sycl::range<1>(sizeof(PhiloxRandom)), cgh);
FillKernelTask<Distribution> task(local_philox_acc, state_data, output_data,
output_size, item_count_ptr, dist);
output_size, item_count_acc, dist);
cgh.parallel_for<FillKernelTask<Distribution>>(
sycl::nd_range<1>(sycl::range<1>(work_group * work_group_size),
sycl::range<1>(work_group_size)),
Expand Down

0 comments on commit 1f84fc9

Please sign in to comment.