Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,19 @@ function Base.copyto!(A::ArrayPartition, src::ArrayPartition)
A
end

function Base.fill!(A::ArrayPartition, x)
unrolled_foreach!(A.x) do x_
fill!(x_, x)
end
A
end

function recursivefill!(b::ArrayPartition, a::T2) where {T2 <: Union{Number, Bool}}
unrolled_foreach!(b.x) do x
fill!(x, a)
end
end


## indexing

# Interface for the linear indexing. This is just a view of the underlying nested structure
Expand Down
4 changes: 4 additions & 0 deletions test/gpu/arraypartition_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ mask = pA .> 0
# Test recursive filling is done using GPU kernels and not scalar indexing
RecursiveArrayTools.recursivefill!(pA, true)
@test all(pA .== true)

# Test that regular filling is done using GPU kernels and not scalar indexing
fill!(pA, false)
@test all(pA .== false)