diff --git a/base/random/RNGs.jl b/base/random/RNGs.jl index 113b7e3dc5bc1f..a28d0a6afcc434 100644 --- a/base/random/RNGs.jl +++ b/base/random/RNGs.jl @@ -2,6 +2,7 @@ ## RandomDevice + # SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...} SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...} const SamplerBoolBitInteger = SamplerUnion(Union{Bool, BitInteger}) @@ -60,24 +61,33 @@ srand(rng::RandomDevice) = rng ## MersenneTwister -const MTCacheLength = dsfmt_get_min_array_size() +const MT_CACHE_F = dsfmt_get_min_array_size() +const MT_CACHE_I = 501 << 4 mutable struct MersenneTwister <: AbstractRNG seed::Vector{UInt32} state::DSFMT_state vals::Vector{Float64} - idx::Int - - function MersenneTwister(seed, state, vals, idx) - length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength || - throw(DomainError((length(vals), idx), - "`length(vals)` and `idx` must be consistent with $MTCacheLength")) - new(seed, state, vals, idx) + ints::Vector{UInt128} + idxF::Int + idxI::Int + + function MersenneTwister(seed, state, vals, ints, idxF, idxI) + length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F || + throw(DomainError((length(vals), idxF), + "`length(vals)` and `idxF` must be consistent with $MT_CACHE_F")) + length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I || + throw(DomainError((length(ints), idxI), + "`length(ints)` and `idxI` must be consistent with $MT_CACHE_I")) + new(seed, state, vals, ints, idxF, idxI) end end MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) = - MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength) + MersenneTwister(seed, state, + Vector{Float64}(uninitialized, MT_CACHE_F), + Vector{UInt128}(uninitialized, MT_CACHE_I >> 4), + MT_CACHE_F, 0) """ MersenneTwister(seed) @@ -120,27 +130,38 @@ function copy!(dst::MersenneTwister, src::MersenneTwister) copyto!(resize!(dst.seed, length(src.seed)), src.seed) copy!(dst.state, src.state) copyto!(dst.vals, src.vals) - dst.idx = 
src.idx + copyto!(dst.ints, src.ints) + dst.idxF = src.idxF + dst.idxI = src.idxI dst end copy(src::MersenneTwister) = - MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), src.idx) + MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints), + src.idxF, src.idxI) + ==(r1::MersenneTwister, r2::MersenneTwister) = - r1.seed == r2.seed && r1.state == r2.state && isequal(r1.vals, r2.vals) && - r1.idx == r2.idx + r1.seed == r2.seed && r1.state == r2.state && + isequal(r1.vals, r2.vals) && + isequal(r1.ints, r2.ints) && + r1.idxF == r2.idxF && r1.idxI == r2.idxI -hash(r::MersenneTwister, h::UInt) = foldr(hash, h, (r.seed, r.state, r.vals, r.idx)) +hash(r::MersenneTwister, h::UInt) = + foldr(hash, h, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI)) +### Wrapper which caches generated integers + ### low level API -mt_avail(r::MersenneTwister) = MTCacheLength - r.idx -mt_empty(r::MersenneTwister) = r.idx == MTCacheLength -mt_setfull!(r::MersenneTwister) = r.idx = 0 -mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength -mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1] +#### floats + +mt_avail(r::MersenneTwister) = MT_CACHE_F - r.idxF +mt_empty(r::MersenneTwister) = r.idxF == MT_CACHE_F +mt_setfull!(r::MersenneTwister) = r.idxF = 0 +mt_setempty!(r::MersenneTwister) = r.idxF = MT_CACHE_F +mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idxF+=1] function gen_rand(r::MersenneTwister) @gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals)) @@ -149,9 +170,55 @@ end reserve_1(r::MersenneTwister) = (mt_empty(r) && gen_rand(r); nothing) # `reserve` allows one to call `rand_inbounds` n times -# precondition: n <= MTCacheLength +# precondition: n <= MT_CACHE_F reserve(r::MersenneTwister, n::Int) = (mt_avail(r) < n && gen_rand(r); nothing) +#### ints + +logsizeof(::Type{<:Union{Bool,Int8,UInt8}}) = 0 +logsizeof(::Type{<:Union{Int16,UInt16}}) = 1 +logsizeof(::Type{<:Union{Int32,UInt32}}) = 2 
+logsizeof(::Type{<:Union{Int64,UInt64}}) = 3 +logsizeof(::Type{<:Union{Int128,UInt128}}) = 4 + +idxmask(::Type{<:Union{Bool,Int8,UInt8}}) = 15 +idxmask(::Type{<:Union{Int16,UInt16}}) = 7 +idxmask(::Type{<:Union{Int32,UInt32}}) = 3 +idxmask(::Type{<:Union{Int64,UInt64}}) = 1 +idxmask(::Type{<:Union{Int128,UInt128}}) = 0 + + +mt_avail(r::MersenneTwister, ::Type{T}) where {T<:BitInteger} = + r.idxI >> logsizeof(T) + +function mt_setfull!(r::MersenneTwister, ::Type{<:BitInteger}) + rand!(r, r.ints) + r.idxI = MT_CACHE_I +end + +mt_setempty!(r::MersenneTwister, ::Type{<:BitInteger}) = r.idxI = 0 + +function reserve1(r::MersenneTwister, ::Type{T}) where T<:BitInteger + r.idxI < sizeof(T) && mt_setfull!(r, T) + nothing +end + +function mt_pop!(r::MersenneTwister, ::Type{T}) where T<:BitInteger + reserve1(r, T) + r.idxI -= sizeof(T) + i = r.idxI + @inbounds x128 = r.ints[1 + i >> 4] + i128 = (i >> logsizeof(T)) & idxmask(T) # 0-based index in x128 + (x128 >> (i128 * (sizeof(T) << 3))) % T +end +#= +function mt_pop!(r::MersenneTwister, ::Type{T}) where {T<:Union{Int128,UInt128}} + reserve1(r, T) + @inbounds res = r.ints[r.idxI >> 4] + r.idxI -= 16 + res +end +=# ### seeding @@ -193,6 +260,9 @@ function srand(r::MersenneTwister, seed::Vector{UInt32}) copyto!(resize!(r.seed, length(seed)), seed) dsfmt_init_by_array(r.state, r.seed) mt_setempty!(r) + fill!(r.vals, 0.0) # not strictly necessary, but why not, makes comparing two MT easier + mt_setempty!(r, UInt128) + fill!(r.ints, 0) return r end @@ -243,24 +313,8 @@ rand(r::MersenneTwister, sp::SamplerTrivial{Close1Open2_64}) = #### integers -rand(r::MersenneTwister, - T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) = - rand(r, UInt52Raw()) % T[] - -function rand(r::MersenneTwister, ::SamplerType{UInt64}) - reserve(r, 2) - rand_inbounds(r, UInt52Raw()) << 32 ⊻ rand_inbounds(r, UInt52Raw()) -end - -function rand(r::MersenneTwister, ::SamplerType{UInt128}) - reserve(r, 3) - xor(rand_inbounds(r, 
UInt52Raw(UInt128)) << 96, - rand_inbounds(r, UInt52Raw(UInt128)) << 48, - rand_inbounds(r, UInt52Raw(UInt128))) -end - -rand(r::MersenneTwister, ::SamplerType{Int64}) = rand(r, UInt64) % Int64 -rand(r::MersenneTwister, ::SamplerType{Int128}) = rand(r, UInt128) % Int128 +rand(r::MersenneTwister, T::SamplerUnion(BitInteger)) = mt_pop!(r, T[]) +rand(r::MersenneTwister, ::SamplerType{Bool}) = rand(r, UInt8) % Bool #### arrays of floats @@ -315,13 +369,13 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter mt_avail(r) == 0 && gen_rand(r) # from now on, at most one call to gen_rand(r) will be necessary m = min(n, mt_avail(r)) - @gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idx+1), m) + @gc_preserve r unsafe_copyto!(A.ptr, pointer(r.vals, r.idxF+1), m) if m == n - r.idx += m + r.idxF += m else # m < n gen_rand(r) @gc_preserve r unsafe_copyto!(A.ptr+m*sizeof(Float64), pointer(r.vals), n-m) - r.idx = n-m + r.idxF = n-m end if I isa CloseOpen for i=1:n @@ -470,7 +524,7 @@ end #### from a range -for T in (Bool, BitInteger_types...) # eval because of ambiguity otherwise +for T in BitInteger_types # eval because of ambiguity otherwise @eval Sampler(rng::MersenneTwister, r::UnitRange{$T}, ::Val{1}) = SamplerRangeFast(r) end diff --git a/base/random/generation.jl b/base/random/generation.jl index 3b4aa83469b7f6..05d5569d666062 100644 --- a/base/random/generation.jl +++ b/base/random/generation.jl @@ -12,7 +12,6 @@ # Note that the 1) is automated when the sampler is not intended to carry information, # i.e. the default fall-backs SamplerType and SamplerTrivial are used. 
- ## from types: rand(::Type, [dims...]) ### random floats @@ -101,6 +100,8 @@ rand(rng::AbstractRNG, sp::SamplerBigFloat{T}) where {T<:FloatInterval{BigFloat} ### random integers +#### UniformBits + rand(r::AbstractRNG, ::SamplerTrivial{UInt10Raw{UInt16}}) = rand(r, UInt16) rand(r::AbstractRNG, ::SamplerTrivial{UInt23Raw{UInt32}}) = rand(r, UInt32) @@ -111,7 +112,7 @@ _rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Ope _rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64) rand(r::AbstractRNG, ::SamplerTrivial{UInt104Raw{UInt128}}) = - rand(r, UInt52Raw(UInt128)) << 52 ⊻ rand_inbounds(r, UInt52Raw(UInt128)) + rand(r, UInt52Raw(UInt128)) << 52 ⊻ rand(r, UInt52Raw(UInt128)) rand(r::AbstractRNG, ::SamplerTrivial{UInt10{UInt16}}) = rand(r, UInt10Raw()) & 0x03ff rand(r::AbstractRNG, ::SamplerTrivial{UInt23{UInt32}}) = rand(r, UInt23Raw()) & 0x007fffff @@ -121,6 +122,32 @@ rand(r::AbstractRNG, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw()) rand(r::AbstractRNG, sp::SamplerTrivial{<:UniformBits{T}}) where {T} = rand(r, uint_default(sp[])) % T +#### BitInteger + +# rand_generic methods are intended to help RNG implementors with common operations +# we don't call them simply `rand` as this can easily contribute to creating +# ambiguities with user-side methods (forcing the user to resort to @eval) + +rand_generic(r::AbstractRNG, T::Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}) = + rand(r, UInt52Raw()) % T[] + +rand_generic(r::AbstractRNG, ::Type{UInt64}) = + rand(r, UInt52Raw()) << 32 ⊻ rand(r, UInt52Raw()) + +rand_generic(r::AbstractRNG, ::Type{UInt128}) = _rand128(r, rng_native_52(r)) + +_rand128(r::AbstractRNG, ::Type{UInt64}) = + ((rand(r, UInt64) % UInt128) << 64) ⊻ rand(r, UInt64) + +function _rand128(r::AbstractRNG, ::Type{Float64}) + xor(rand(r, UInt52Raw(UInt128)) << 96, + rand(r, UInt52Raw(UInt128)) << 48, + rand(r, UInt52Raw(UInt128))) +end + +rand_generic(r::AbstractRNG, ::Type{Int128}) = rand(r, UInt128) % Int128 
+rand_generic(r::AbstractRNG, ::Type{Int64}) = rand(r, UInt64) % Int64 + ### random complex numbers rand(r::AbstractRNG, ::SamplerType{Complex{T}}) where {T<:Real} = @@ -149,25 +176,27 @@ end #### helper functions -uint_sup(::Type{<:Union{Bool,BitInteger}}) = UInt32 +# disabled temporarily for reproducibility on 32 & 64 bit architectures +# uint_sup(::Type{<:Union{BitInteger}}) = UInt32 +uint_sup(::Type{<:BitInteger}) = UInt64 uint_sup(::Type{<:Union{Int64,UInt64}}) = UInt64 uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128 #### Fast -struct SamplerRangeFast{U<:BitUnsigned,T<:Union{BitInteger,Bool}} <: Sampler +struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler a::T # first element of the range bw::UInt # bit width m::U # range length - 1 mask::U # mask generated values before threshold rejection end -SamplerRangeFast(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} = +SamplerRangeFast(r::AbstractUnitRange{T}) where T<:BitInteger = SamplerRangeFast(r, uint_sup(T)) function SamplerRangeFast(r::AbstractUnitRange{T}, ::Type{U}) where {T,U} isempty(r) && throw(ArgumentError("range must be non-empty")) - m = (last(r) - first(r)) % U + m = (last(r) - first(r)) % unsigned(T) % U # % unsigned(T) to not propagate sign bit bw = (sizeof(U) << 3 - leading_zeros(m)) % UInt # bit-width mask = (1 % U << bw) - (1 % U) SamplerRangeFast{U,T}(first(r), bw, m, mask) @@ -215,7 +244,7 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} = unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} = div(sup, k + (k == 0))*k - one(k) -struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler +struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler a::T # first element of the range bw::Int # bit width k::U # range length or zero for full range @@ -223,13 +252,13 @@ struct SamplerRangeInt{T<:Union{Bool,Integer},U<:Unsigned} <: Sampler end -SamplerRangeInt(r::AbstractUnitRange{T}) where T<:Union{Bool,BitInteger} = +SamplerRangeInt(r::AbstractUnitRange{T}) 
where T<:BitInteger = SamplerRangeInt(r, uint_sup(T)) function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U} isempty(r) && throw(ArgumentError("range must be non-empty")) a = first(r) - m = (last(r) - first(r)) % U + m = (last(r) - first(r)) % unsigned(T) % U k = m + one(U) bw = (sizeof(U) << 3 - leading_zeros(m)) % Int mult = if U === UInt32 @@ -247,11 +276,11 @@ function SamplerRangeInt(r::AbstractUnitRange{T}, ::Type{U}) where {T,U} end Sampler(::AbstractRNG, r::AbstractUnitRange{T}, - ::Repetition) where {T<:Union{Bool,BitInteger}} = SamplerRangeInt(r) + ::Repetition) where {T<:BitInteger} = SamplerRangeInt(r) -rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:Union{Bool,BitInteger}} = - (unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, uniform(UInt32))), sp.k)) % T +rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt32}) where {T<:BitInteger} = + (unsigned(sp.a) + rem_knuth(rand(rng, LessThan(sp.u, uniform(UInt32))), sp.k)) % T # this function uses 52 bit entropy for small ranges of length <= 2^52 function rand(rng::AbstractRNG, sp::SamplerRangeInt{T,UInt64}) where T<:BitInteger diff --git a/test/random.jl b/test/random.jl index 88eb56a5a34d9b..1cc4a2a1591ce6 100644 --- a/test/random.jl +++ b/test/random.jl @@ -3,7 +3,7 @@ isdefined(Main, :TestHelpers) || @eval Main include(joinpath(dirname(@__FILE__), "TestHelpers.jl")) using Main.TestHelpers.OAs -using Base.Random: Sampler, SamplerRangeFast, SamplerRangeInt, dSFMT +using Base.Random: Sampler, SamplerRangeFast, SamplerRangeInt, dSFMT, MT_CACHE_F, MT_CACHE_I @testset "Issue #6573" begin srand(0) @@ -38,7 +38,7 @@ let A = zeros(2, 2) end let A = zeros(2, 2) @test_throws ArgumentError rand!(MersenneTwister(0), A, 5) - @test rand(MersenneTwister(0), Int64, 1) == [4439861565447045202] + @test rand(MersenneTwister(0), Int64, 1) == [5986602421100169002] end let A = zeros(Int64, 2, 2) rand!(MersenneTwister(0), A) @@ -278,11 +278,10 @@ let mt = MersenneTwister(0) B = 
Vector{T}(uninitialized, 31) rand!(mt, A) rand!(mt, B) - @test A[end] == Any[21,0x7b,17385,0x3086,-1574090021,0xadcb4460,6797283068698303107,0x4e91c9c4d4f5f759, - -3482609696641744459568613291754091152,Float16(0.03125),0.68733835f0][i] - - @test B[end] == Any[49,0x65,-3725,0x719d,814246081,0xdf61843a,-3010919637398300844,0x61b367cf8810985d, - -33032345278809823492812856023466859769,Float16(0.95),0.51829386f0][i] + @test A[end] == Any[21, 0x4e, -3158, 0x0ded, 2132370312, 0x5e76d222, 1701112237820550475, 0xde7c8e58fb113739, + -17260403799139981754163727590537874047, Float16(0.90234), 0.0909704f0][i] + @test B[end] == Any[94, 0xb8, 3111, 0xefa4, 411531475, 0xd8089c1d, -7344871485543005232, 0xedb4b5c61c037a43, + -118178167582054157562031602894265066400, Float16(0.91211), 0.2516626f0][i] end srand(mt, 0) @@ -561,10 +560,18 @@ end # MersenneTwister initialization with invalid values @test_throws DomainError Base.dSFMT.DSFMT_state(zeros(Int32, rand(0:Base.dSFMT.JN32-1))) + +@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(), + zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0) + @test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(), - zeros(Float64, 10), 0) + zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0) + +@test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(), + zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0) + @test_throws DomainError MersenneTwister(zeros(UInt32, 1), Base.dSFMT.DSFMT_state(), - zeros(Float64, Base.Random.MTCacheLength), -1) + zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1) # seed is private to MersenneTwister let seed = rand(UInt32, 10)