Skip to content

Commit 6d09a9e

Browse files
committed
shared(AbstractArray) and sync_workgroup
1 parent 3ffce4d commit 6d09a9e

File tree

4 files changed

+19
-9
lines changed

4 files changed

+19
-9
lines changed

ext/JACCAMDGPU/JACCAMDGPU.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,9 @@ function reduce_kernel_amdgpu_MN((M, N), op, red, ret)
493493
return nothing
494494
end
495495

496-
function JACC.shared(x::ROCDeviceArray{T, N}) where {T, N}
496+
function JACC.shared(::AMDGPUBackend, x::AbstractArray)
497497
size = length(x)
498-
shmem = @ROCDynamicLocalArray(T, size)
498+
shmem = @ROCDynamicLocalArray(eltype(x), size)
499499
num_threads = workgroupDim().x * workgroupDim().y
500500
if (size <= num_threads)
501501
if workgroupDim().y == 1
@@ -539,6 +539,8 @@ function JACC.shared(x::ROCDeviceArray{T, N}) where {T, N}
539539
return shmem
540540
end
541541

542+
# `sync_workgroup` method for the AMDGPU backend: delegates to
# `AMDGPU.sync_workgroup()` and returns whatever that call returns.
function JACC.sync_workgroup(::AMDGPUBackend)
    return AMDGPU.sync_workgroup()
end
543+
542544
# Array type reported for the AMDGPU backend.
function JACC.array_type(::AMDGPUBackend)
    return AMDGPU.ROCArray
end
543545

544546
# Construct an `AMDGPU.ROCArray` from the `Base.Array` `x`.
function JACC.array(::AMDGPUBackend, x::Base.Array)
    return AMDGPU.ROCArray(x)
end

ext/JACCCUDA/JACCCUDA.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -504,9 +504,9 @@ function reduce_kernel_cuda_MN((M, N), op, red, ret)
504504
return nothing
505505
end
506506

507-
function JACC.shared(x::CuDeviceArray{T, N}) where {T, N}
507+
function JACC.shared(::CUDABackend, x::AbstractArray)
508508
size = length(x)
509-
shmem = CuDynamicSharedArray(T, size)
509+
shmem = CuDynamicSharedArray(eltype(x), size)
510510
num_threads = blockDim().x * blockDim().y
511511
if (size <= num_threads)
512512
if blockDim().y == 1
@@ -556,6 +556,8 @@ function JACC.shared(x::CuDeviceArray{T, N}) where {T, N}
556556
return shmem
557557
end
558558

559+
# `sync_workgroup` method for the CUDA backend: delegates to
# `CUDA.sync_threads()` and returns whatever that call returns.
function JACC.sync_workgroup(::CUDABackend)
    return CUDA.sync_threads()
end
560+
559561
# Array type reported for the CUDA backend.
function JACC.array_type(::CUDABackend)
    return CUDA.CuArray
end
560562

561563
# Construct a `CUDA.CuArray` from the `Base.Array` `x`.
function JACC.array(::CUDABackend, x::Base.Array)
    return CUDA.CuArray(x)
end

ext/JACCONEAPI/JACCONEAPI.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -418,10 +418,10 @@ function reduce_kernel_oneapi_MN((M, N), op, red, ret)
418418
return nothing
419419
end
420420

421-
function JACC.shared(x::oneDeviceArray{T, N}) where {T, N}
421+
function JACC.shared(::oneAPIBackend, x::AbstractArray)
422422
size::Int32 = length(x)
423423
# FIXME: the local array length is hard-coded to 512; it should be `size` instead.
424-
shmem = oneLocalArray(T, 512)
424+
shmem = oneLocalArray(eltype(x), 512)
425425
num_threads = get_local_size(0) * get_local_size(1)
426426
if (size <= num_threads)
427427
if get_local_size(1) == 1
@@ -465,6 +465,8 @@ function JACC.shared(x::oneDeviceArray{T, N}) where {T, N}
465465
return shmem
466466
end
467467

468+
# `sync_workgroup` method for the oneAPI backend: delegates to
# `oneAPI.barrier()` and returns whatever that call returns.
function JACC.sync_workgroup(::oneAPIBackend)
    return oneAPI.barrier()
end
469+
468470
# Array type reported for the oneAPI backend.
function JACC.array_type(::oneAPIBackend)
    return oneAPI.oneArray
end
469471

470472
# Construct a `oneAPI.oneArray` from the `Base.Array` `x`.
function JACC.array(::oneAPIBackend, x::Base.Array)
    return oneAPI.oneArray(x)
end

src/JACC.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,19 @@ function parallel_reduce(
123123
return ret
124124
end
125125

126+
# `shared` method for the threads backend: returns the array `x` itself,
# without building any separate buffer.
function shared(::ThreadsBackend, x::AbstractArray)
127+
return x
128+
end
129+
126130
# Array type reported for the threads backend.
function array_type(::ThreadsBackend)
    return Base.Array
end
127131

128132
# For the threads backend a `Base.Array` is handed back as-is (same object).
function array(::ThreadsBackend, x::Base.Array)
    return x
end
129133

130134
# Fallback method: any argument maps to `Float64`.
function default_float(::Any)
    return Float64
end
131135

132-
# Single-argument `shared` for `Base.Array`: returns `x` unchanged.
# The type parameters `T` and `N` are not used in the body.
function shared(x::Base.Array{T, N}) where {T, N}
133-
return x
134-
end
136+
# Convenience method: look up the backend via `default_backend()` and forward
# to the two-argument `shared(backend, x)`.
function shared(x::AbstractArray)
    backend = default_backend()
    return shared(backend, x)
end
137+
138+
# Convenience method: look up the backend via `default_backend()` and forward
# to `sync_workgroup(backend)`.
function sync_workgroup()
    backend = default_backend()
    return sync_workgroup(backend)
end
135139

136140
# Convenience method: look up the backend via `default_backend()` and forward
# to `array_type(backend)`.
function array_type()
    backend = default_backend()
    return array_type(backend)
end
137141

0 commit comments

Comments
 (0)