diff --git a/base/abstractdict.jl b/base/abstractdict.jl index d9989265e0e0c..2056c55ab9b77 100644 --- a/base/abstractdict.jl +++ b/base/abstractdict.jl @@ -533,38 +533,22 @@ mutable struct IdDict{K,V} <: AbstractDict{K,V} end IdDict{K,V}(d::IdDict{K,V}) where {K, V} = new{K,V}(copy(d.ht)) - function IdDict{K,V}(d::IdDict{K}) where {K, V} - d = IdDict{K,V}() - sizehint!(d, length(pairs)) - for (k,v) in pairs; d[k] = v; end - d - end end -const _IdDict = IdDict{Any,Any} # this is needed to make src/dump.c and src/staticdata.c work IdDict() = IdDict{Any,Any}() IdDict(kv::Tuple{}) = IdDict() -IdDict(ps::Pair{K,V}...) where {K,V} = IdDict{K,V}(ps) -IdDict(ps::Pair{K}...) where {K} = IdDict{K,Any}(ps) -IdDict(ps::(Pair{K,V} where K)...) where {V} = IdDict{Any,V}(ps) -IdDict(ps::Pair...) = IdDict{Any,Any}(ps) - -function IdDict(kv) - try - dict_with_eltype((K, V) -> IdDict{K, V}, kv, eltype(kv)) - catch e - if !applicable(start, kv) || !all(x->isa(x,Union{Tuple,Pair}),kv) - throw(ArgumentError( - "IdDict(kv): kv needs to be an iterator of tuples or pairs")) - else - rethrow(e) - end - end -end +IdDict(ps::Pair...) = IdDict{Any,Any}(ps) +IdDict(itr) = IdDict{Any,Any}(itr) empty(d::IdDict, ::Type{K}, ::Type{V}) where {K, V} = IdDict{K,V}() +# TODO: this should probably be removed +function haskey(d::IdDict, k) + v = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, k, secret_table_token) + v !== secret_table_token +end + function rehash!(d::IdDict, newsz = length(d.ht)) d.ht = ccall(:jl_idtable_rehash, Any, (Any, Csize_t), d.ht, newsz) d @@ -580,7 +564,8 @@ function sizehint!(d::IdDict, newsz) rehash!(d, newsz) end -function setindex!(d::IdDict{K,V}, v, k::K) where {K, V} +function setindex!(d::IdDict{K,V}, v, k) where {K, V} + !isa(k, K) && throw(KeyError(k)) v = convert(V, v) if d.ndel >= ((3*length(d.ht))>>2) rehash!(d, max(length(d.ht)>>1, 32)) @@ -590,19 +575,20 @@ function setindex!(d::IdDict{K,V}, v, k::K) where {K, V} return d end -function get(d::IdDict{K,V}, key::K, default) where {K, V} - v = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, default) - v===default ? default : v::V +function get(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} + val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, default) + val === default ? default : val::V end -function getindex(d::IdDict{K,V}, key::K) where {K, V} - v = get(d, key, secret_table_token) - v == secret_table_token ? throw(KeyError(key)) : v::V +function getindex(d::IdDict{K,V}, key) where {K, V} + val = get(d, key, secret_table_token) + val === secret_table_token && throw(KeyError(key)) + return val::V end -function pop!(d::IdDict{K,V}, key::K, default) where {K, V} +function pop!(d::IdDict{K,V}, key, default) where {K, V} val = ccall(:jl_eqtable_pop, Any, (Any, Any, Any), d.ht, key, default) # TODO: this can underestimate `ndel` - if val==default + if val === default return default else (d.ndel += 1) @@ -612,7 +598,8 @@ end function pop!(d::IdDict{K,V}, key::K) where {K, V} val = pop!(d, key, secret_table_token) - val !== secret_table_token ? val::V : throw(KeyError(key)) + val === secret_table_token && throw(KeyError(key)) + return val::V end function delete!(d::IdDict{K}, key::K) where K @@ -627,7 +614,7 @@ function empty!(d::IdDict) return d end -_oidd_nextind(a, i) = reinterpret(Int,ccall(:jl_eqtable_nextind, Csize_t, (Any, Csize_t), a, i)) +_oidd_nextind(a, i) = reinterpret(Int, ccall(:jl_eqtable_nextind, Csize_t, (Any, Csize_t), a, i)) start(d::IdDict) = _oidd_nextind(d.ht, 0) done(d::IdDict, i) = (i == -1) diff --git a/base/codevalidation.jl b/base/codevalidation.jl index 957a304b9b9d7..a28fda60505b8 100644 --- a/base/codevalidation.jl +++ b/base/codevalidation.jl @@ -1,7 +1,7 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license # Expr head => argument count bounds -const VALID_EXPR_HEADS = IdDict( +const VALID_EXPR_HEADS = IdDict{Any,Any}( :call => 1:typemax(Int), :invoke => 2:typemax(Int), :static_parameter => 1:1, diff --git a/src/dump.c b/src/dump.c index d36dc422e195b..92b2321e91564 100644 --- a/src/dump.c +++ b/src/dump.c @@ -2297,7 +2297,7 @@ JL_DLLEXPORT int jl_save_incremental(const char *fname, jl_array_t *worklist) htable_new(&backref_table, 5000); ptrhash_put(&backref_table, jl_main_module, (char*)HT_NOTFOUND + 1); backref_table_numel = 1; - jl_idtable_type = jl_base_module ? jl_get_global(jl_base_module, jl_symbol("_IdDict")) : NULL; + jl_idtable_type = jl_base_module ? jl_get_global(jl_base_module, jl_symbol("IdDict")) : NULL; int en = jl_gc_enable(0); // edges map is not gc-safe jl_array_t *lambdas = jl_alloc_vec_any(0); diff --git a/src/staticdata.c b/src/staticdata.c index 3a81270bd43c0..64b7ca4cbe48e 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -1205,7 +1205,7 @@ static void jl_save_system_image_to_stream(ios_t *f) } } - jl_idtable_type = jl_base_module ? jl_get_global(jl_base_module, jl_symbol("_IdDict")) : NULL; + jl_idtable_type = jl_base_module ? jl_get_global(jl_base_module, jl_symbol("IdDict")) : NULL; { // step 1: record values (recursively) that need to go in the image jl_serialize_value(&s, jl_core_module); diff --git a/test/dict.jl b/test/dict.jl index 6fc268ae62bbc..848f057b72ee5 100644 --- a/test/dict.jl +++ b/test/dict.jl @@ -384,14 +384,14 @@ mutable struct T10647{T}; x::T; end Base.show(Base.IOContext(IOBuffer(), :limit => true), a) end -@testset "Base.ObjectIdDict" begin - a = Base.ObjectIdDict() +@testset "IdDict{Any,Any}" begin + a = IdDict{Any,Any}() a[1] = a a[a] = 2 sa = empty(a) @test isempty(sa) - @test isa(sa, Base.ObjectIdDict) + @test isa(sa, IdDict{Any,Any}) @test length(a) == 2 @test 1 in keys(a) @@ -411,16 +411,16 @@ end d = Dict('a'=>1, 'b'=>1, 'c'=> 3) @test a != d - @test length(Base.ObjectIdDict(1=>2, 1.0=>3)) == 2 + @test length(IdDict{Any,Any}(1=>2, 1.0=>3)) == 2 @test length(Dict(1=>2, 1.0=>3)) == 1 - d = @inferred Base.ObjectIdDict(i=>i for i=1:3) - @test isa(d, Base.ObjectIdDict) - @test d == Base.ObjectIdDict(1=>1, 2=>2, 3=>3) + d = @inferred IdDict{Any,Any}(i=>i for i=1:3) + @test isa(d, IdDict{Any,Any}) + @test d == IdDict{Any,Any}(1=>1, 2=>2, 3=>3) - d = @inferred Base.ObjectIdDict(Pair(1,1), Pair(2,2), Pair(3,3)) - @test isa(d, Base.ObjectIdDict) - @test d == Base.ObjectIdDict(1=>1, 2=>2, 3=>3) + d = @inferred IdDict{Any,Any}(Pair(1,1), Pair(2,2), Pair(3,3)) + @test isa(d, IdDict{Any,Any}) + @test d == IdDict{Any,Any}(1=>1, 2=>2, 3=>3) @test eltype(d) == Pair{Any,Any} end @@ -461,13 +461,23 @@ end d = @inferred IdDict(Pair(1,1), Pair(2,2), Pair(3,3)) @test isa(d, IdDict) @test d == IdDict(1=>1, 2=>2, 3=>3) + @test eltype(d) == Pair{Any,Any} + + d = @inferred IdDict{Int,Int}(Pair(1,1), Pair(2,2), Pair(3,3)) + @test d == IdDict{Int,Int}(1=>1, 2=>2, 3=>3) + @test d == IdDict{Any,Any}(1=>1, 2=>2, 3=>3) @test eltype(d) == Pair{Int,Int} - @test_throws MethodError d[:a] - @test_throws MethodError d[:a] = 1 + @test_throws KeyError d[:a] + @test_throws KeyError d[:a] = 1 @test_throws MethodError d[1] = :a + # copy constructor + d = IdDict{Int,Int}(Pair(1,1), Pair(2,2), Pair(3,3)) + @test collect(values(IdDict{Int,Float64}(d))) == collect(values(d)) + @test_throws KeyError IdDict{Float64,Int}(d) + # check that returned values are inferred - d = @inferred IdDict(Pair(1,1), Pair(2,2), Pair(3,3)) + d = @inferred IdDict{Int,Int}(Pair(1,1), Pair(2,2), Pair(3,3)) @test 1 == @inferred d[1] @inferred setindex!(d, -1, 10) @test d[10] == -1