diff --git a/base/random/RNGs.jl b/base/random/RNGs.jl index f3e6602411bd1..35ed9245e2166 100644 --- a/base/random/RNGs.jl +++ b/base/random/RNGs.jl @@ -67,17 +67,19 @@ mutable struct MersenneTwister <: AbstractRNG state::DSFMT_state vals::Vector{Float64} idx::Int + advance::BigInt + advance_last_vals::BigInt - function MersenneTwister(seed, state, vals, idx) + function MersenneTwister(seed, state, vals, idx, advance, advance_last_vals) length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength || throw(DomainError((length(vals), idx), "`length(vals)` and `idx` must be consistent with $MTCacheLength")) - new(seed, state, vals, idx) + new(seed, state, vals, idx, advance, advance_last_vals) end end MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) = - MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength) + MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength, 0, 0) """ MersenneTwister(seed) @@ -134,7 +136,7 @@ copy(src::MersenneTwister) = hash(r::MersenneTwister, h::UInt) = foldr(hash, h, (r.seed, r.state, r.vals, r.idx)) show(io::IO, rng::MersenneTwister) = - print(io, "MersenneTwister RNG with seed 0x$(hex(from_seed(m.seed)))") + print(io, "MersenneTwister(0x$(hex(from_seed(rng.seed))), $(rng.advance), $(rng.advance_last_vals), $(rng.idx))") ### low level API @@ -146,7 +148,8 @@ mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1] function gen_rand(r::MersenneTwister) - @gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals)) + MPZ.set!(r.advance_last_vals, r.advance) + @gc_preserve r fill_array!(r, pointer(r.vals), length(r.vals), Close1Open2()) mt_setfull!(r) end @@ -337,6 +340,11 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter A end +function fill_array!(rng::MersenneTwister, A::Ptr{Float64}, n::Int, I) + MPZ.add_ui!(rng.advance, n) + #rng.advance += n + fill_array!(rng.state, A, n, I) +end fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) = dsfmt_fill_array_close_open!(s, A, n) @@ -361,10 +369,10 @@ function rand!(r::MersenneTwister, A::UnsafeView{Float64}, align = Csize_t(pA) % 16 if align > 0 pA2 = pA + 16 - align - fill_array!(r.state, pA2, n2, I[]) # generate the data in-place, but shifted + fill_array!(r, pA2, n2, I[]) # generate the data in-place, but shifted unsafe_copyto!(pA, pA2, n2) # move the data to the beginning of the array else - fill_array!(r.state, pA, n2, I[]) + fill_array!(r, pA, n2, I[]) end for i=n2+1:n A[i] = rand(r, I[]) @@ -501,7 +509,9 @@ function randjump(mt::MersenneTwister, jumps::Integer, jumppoly::AbstractString) push!(mts, mt) for i in 1:jumps-1 cmt = mts[end] - push!(mts, MersenneTwister(copy(cmt.seed), dSFMT.dsfmt_jump(cmt.state, jumppoly))) + newrng = MersenneTwister(copy(cmt.seed), dSFMT.dsfmt_jump(cmt.state, jumppoly)) + MPZ.set_ui!(newrng.advance, Int64(10)^20) + push!(mts, newrng) end return mts end diff --git a/base/random/random.jl b/base/random/random.jl index 6ed336d2e4412..0281d5f0cef7b 100644 --- a/base/random/random.jl +++ b/base/random/random.jl @@ -227,6 +227,8 @@ rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X function __init__() try srand() + GLOBAL_RNG.advance = big(0) + GLOBAL_RNG.advance_last_vals = big(0) catch ex Base.showerror_nostdio(ex, "WARNING: Error during initialization of module Random")