From f0362535b91337060f09d8df506460ace9e7df9e Mon Sep 17 00:00:00 2001 From: Dongdong Kong Date: Thu, 26 Sep 2024 22:26:40 +0800 Subject: [PATCH] improve the efficiency of agg_time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 必须要有findall --- src/agg.jl | 9 +++++---- src/apply.jl | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/agg.jl b/src/agg.jl index 6223b8c..edc3d4b 100644 --- a/src/agg.jl +++ b/src/agg.jl @@ -25,7 +25,7 @@ function agg!(R::AbstractArray{FT,3}, A::AbstractArray{<:Real,3}; end progress && next!(p) end - R + return R end function agg(A::AbstractArray{<:Real,3}; fact=2, parallel=true, fun=mean) @@ -47,7 +47,7 @@ function agg_time(A::AbstractArray{T,3}; fact::Int=2, parallel=true, progress=fa R[i, j, k] = fun(@view A[i, j, I]) end end - R + return R end @@ -66,12 +66,13 @@ function agg_time(A::AbstractArray{T,3}, by::Vector; parallel=true, progress=fal p = Progress(ntime) @inbounds @par parallel for k = 1:_ntime progress && next!(p) - I = grps[k] .== by + I = (grps[k] .== by) |> findall # 必须要有findall + for j = 1:nlat, i = 1:nlon R[i, j, k] = fun(@view A[i, j, I]) end end - R + return R end export agg!, agg, agg_time diff --git a/src/apply.jl b/src/apply.jl index f4f312d..27040d1 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -57,7 +57,7 @@ function apply(A::AbstractArray, dims_by=3, args...; dims=dims_by, else grps = unique(by) |> sort res = par_map(grp -> begin - ind = by .== grp + ind = findall(by .== grp) data = selectdim(A, dims_by, ind) # |> collect # ans = fun(data, args...; kw...) # ans = par_mapslices(fun2, data; dims, parallel, progress) @@ -70,7 +70,7 @@ function apply(A::AbstractArray, dims_by=3, args...; dims=dims_by, res = cat(res..., dims=along) end end - res + return res end