Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kongdd committed Oct 3, 2024
1 parent bdb17d7 commit 5e8c52b
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/IO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export writelines
! `x` 需要是string,不然文件错误
"""
function writelines(x::AbstractVector{<:AbstractString}, f; mode="w", eof="\n")
function writelines(x::AbstractVector{<:AbstractString}, f::AbstractString; mode="w", eof="\n")
fid = open(f, mode)
@inbounds for _x in x
write(fid, _x)
Expand All @@ -31,7 +31,7 @@ function writelines(x::AbstractVector{<:AbstractString}, f; mode="w", eof="\n")
close(fid)
end

function writelines(x::AbstractString, f; mode="w", eof="\n")
function writelines(x::AbstractString, f::AbstractString; mode="w", eof="\n")
fid = open(f, mode)
write(fid, x)
write(fid, eof)
Expand Down
10 changes: 8 additions & 2 deletions src/Statistics/match2.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# import StatsBase: countmap
# table = countmap
function dict2order(dict; rev=true)
_keys = keys(dict) |> collect
_values = values(dict) |> collect
inds = sortperm(_values; rev)
OrderedDict(_keys[i] => _values[i] for i in inds)
end

"""
table(x::AbstractVector)
Expand All @@ -8,7 +14,7 @@
This function is about 5X slower than `StatsBase: countmap`.
If speed matters for you, use `StatsBase.countmap` instead.
"""
function table(x::AbstractVector)
function table(x::AbstractVector; rev=true)
tbl = Dict{eltype(x),Int}()
for element in x
if haskey(tbl, element)
Expand All @@ -17,7 +23,7 @@ function table(x::AbstractVector)
tbl[element] = 1
end
end
return tbl
return dict2order(tbl; rev)
end


Expand Down
1 change: 1 addition & 0 deletions src/r_base.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import DataStructures: OrderedDict
export OrderedDict
# using LoopVectorization: @turbo

# function r_in(x::AbstractVector, y::AbstractVector)::BitVector
Expand Down
13 changes: 13 additions & 0 deletions test/test-Ipaper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,18 @@ end
str = ["a", "b"]
writelines(str, f)
@test readlines(f) == str

writelines("Hello world", f)
@test readlines("a.txt")[1] == "Hello world"
rm(f)
end

@testset "table" begin
tbl = table([2, 2, 2, 1, 1, 1, 1]; rev=true)
@test collect(values(tbl)) == [4, 3]

tbl = table([2, 2, 2, 1, 1, 1, 1]; rev=false)
@test collect(values(tbl)) == [3, 4]
end


14 changes: 11 additions & 3 deletions test/test-stat_weighted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ end
@testset "weighted_nansum" begin
x = [1.0, NaN, 3]
A = reshape(x, 1, 1, 3)
weighted_nansum([1.0, 2, 3], [1, 1, 1]) == 6.0
weighted_nansum(x, [1, 1, 1]) == 4.0
weighted_nansum(A, [1, 1, 1])[1] == 4.0
@test weighted_nansum([1.0, 2, 3], [1, 1, 1]) == 6.0
@test weighted_nansum(x, [1, 1, 1]) == 4.0
@test weighted_nansum(A, [1, 1, 1])[1] == 4.0
end

@testset "weighted_nansum" begin
x = [1.0, NaN, 3]
A = reshape(x, 1, 1, 3)
@test weighted_nanmean([1.0, 2, 3], [1, 1, 1]) == 2.0
@test weighted_nanmean(x, [1, 1, 1]) == 2.0
# weighted_nanmean(A, [1, 1, 1])[1] == 4.0
end

0 comments on commit 5e8c52b

Please sign in to comment.