Skip to content

Commit 1e9028d

Browse files
committed
shared(AbstractArray) and sync_workgroup
1 parent 48eef90 commit 1e9028d

File tree

5 files changed

+17
-9
lines changed

5 files changed

+17
-9
lines changed

ext/AMDGPUExt/AMDGPUExt.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -488,9 +488,9 @@ function reduce_kernel_amdgpu_MN((M, N), op, red, ret)
488488
return nothing
489489
end
490490

491-
function JACC.shared(x::ROCDeviceArray{T, N}) where {T, N}
491+
function JACC.shared(::AMDGPUBackend, x::AbstractArray)
492492
size = length(x)
493-
shmem = @ROCDynamicLocalArray(T, size)
493+
shmem = @ROCDynamicLocalArray(eltype(x), size)
494494
num_threads = workgroupDim().x * workgroupDim().y
495495
if (size <= num_threads)
496496
if workgroupDim().y == 1
@@ -534,6 +534,8 @@ function JACC.shared(x::ROCDeviceArray{T, N}) where {T, N}
534534
return shmem
535535
end
536536

537+
JACC.sync_workgroup(::AMDGPUBackend) = AMDGPU.sync_workgroup()
538+
537539
JACC.array_type(::AMDGPUBackend) = AMDGPU.ROCArray
538540

539541
JACC.array(::AMDGPUBackend, x::Base.Array) = AMDGPU.ROCArray(x)

ext/CUDAExt/CUDAExt.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,9 @@ function reduce_kernel_cuda_MN((M, N), op, red, ret)
493493
return nothing
494494
end
495495

496-
function JACC.shared(x::CuDeviceArray{T, N}) where {T, N}
496+
function JACC.shared(::CUDABackend, x::AbstractArray)
497497
size = length(x)
498-
shmem = CuDynamicSharedArray(T, size)
498+
shmem = CuDynamicSharedArray(eltype(x), size)
499499
num_threads = blockDim().x * blockDim().y
500500
if (size <= num_threads)
501501
if blockDim().y == 1
@@ -545,6 +545,8 @@ function JACC.shared(x::CuDeviceArray{T, N}) where {T, N}
545545
return shmem
546546
end
547547

548+
JACC.sync_workgroup(::CUDABackend) = CUDA.sync_threads()
549+
548550
JACC.array_type(::CUDABackend) = CUDA.CuArray
549551

550552
JACC.array(::CUDABackend, x::Base.Array) = CUDA.CuArray(x)

ext/oneAPIExt/oneAPIExt.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,10 @@ function reduce_kernel_oneapi_MN((M, N), op, red, ret)
375375
return nothing
376376
end
377377

378-
function JACC.shared(x::oneDeviceArray{T, N}) where {T, N}
378+
function JACC.shared(::oneAPIBackend, x::AbstractArray)
379379
size::Int32 = length(x)
380380
# This is wrong, we should use size not 512 ...
381-
shmem = oneLocalArray(T, 512)
381+
shmem = oneLocalArray(eltype(x), 512)
382382
num_threads = get_local_size(0) * get_local_size(1)
383383
if (size <= num_threads)
384384
if get_local_size(1) == 1
@@ -422,6 +422,8 @@ function JACC.shared(x::oneDeviceArray{T, N}) where {T, N}
422422
return shmem
423423
end
424424

425+
JACC.sync_workgroup(::oneAPIBackend) = oneAPI.barrier()
426+
425427
JACC.array_type(::oneAPIBackend) = oneAPI.oneArray
426428

427429
JACC.array(::oneAPIBackend, x::Base.Array) = oneAPI.oneArray(x)

src/JACC.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ launch_spec(; kw...) = LaunchSpec{typeof(default_backend())}(; kw...)
4747

4848
default_float(::Any) = Float64
4949

50-
function shared(x::Base.Array{T, N}) where {T, N}
51-
return x
52-
end
50+
shared(x::AbstractArray) = shared(default_backend(), x)
51+
52+
sync_workgroup() = sync_workgroup(default_backend())
5353

5454
array_type() = array_type(default_backend())
5555

src/threads/threads.jl

+2
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,6 @@ JACC.array_type(::ThreadsBackend) = Base.Array
105105

106106
JACC.array(::ThreadsBackend, x::Base.Array) = x
107107

108+
JACC.shared(::ThreadsBackend, x::AbstractArray) = x
109+
108110
end

0 commit comments

Comments
 (0)