Skip to content

Commit

Permalink
add advance variable
Browse files Browse the repository at this point in the history
[ci skip]
[av skip]
  • Loading branch information
rfourquet committed Dec 17, 2017
1 parent e27419d commit da45200
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
26 changes: 18 additions & 8 deletions base/random/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,19 @@ mutable struct MersenneTwister <: AbstractRNG
state::DSFMT_state
vals::Vector{Float64}
idx::Int
advance::BigInt
advance_last_vals::BigInt

function MersenneTwister(seed, state, vals, idx)
function MersenneTwister(seed, state, vals, idx, advance, advance_last_vals)
length(vals) == MTCacheLength && 0 <= idx <= MTCacheLength ||
throw(DomainError((length(vals), idx),
"`length(vals)` and `idx` must be consistent with $MTCacheLength"))
new(seed, state, vals, idx)
new(seed, state, vals, idx, advance, advance_last_vals)
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength)
MersenneTwister(seed, state, zeros(Float64, MTCacheLength), MTCacheLength, 0, 0)

"""
MersenneTwister(seed)
Expand Down Expand Up @@ -134,7 +136,7 @@ copy(src::MersenneTwister) =
hash(r::MersenneTwister, h::UInt) = foldr(hash, h, (r.seed, r.state, r.vals, r.idx))

show(io::IO, rng::MersenneTwister) =
print(io, "MersenneTwister RNG with seed 0x$(hex(from_seed(m.seed)))")
print(io, "MersenneTwister(0x$(hex(from_seed(rng.seed))), $(rng.advance), $(rng.advance_last_vals), $(rng.idx))")


### low level API
Expand All @@ -146,7 +148,8 @@ mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength
mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1]

function gen_rand(r::MersenneTwister)
@gc_preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals))
MPZ.set!(r.advance_last_vals, r.advance)
@gc_preserve r fill_array!(r, pointer(r.vals), length(r.vals), Close1Open2())
mt_setfull!(r)
end

Expand Down Expand Up @@ -337,6 +340,11 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
A
end

function fill_array!(rng::MersenneTwister, A::Ptr{Float64}, n::Int, I)
MPZ.add_ui!(rng.advance, n)
#rng.advance += n
fill_array!(rng.state, A, n, I)
end

fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
dsfmt_fill_array_close_open!(s, A, n)
Expand All @@ -361,10 +369,10 @@ function rand!(r::MersenneTwister, A::UnsafeView{Float64},
align = Csize_t(pA) % 16
if align > 0
pA2 = pA + 16 - align
fill_array!(r.state, pA2, n2, I[]) # generate the data in-place, but shifted
fill_array!(r, pA2, n2, I[]) # generate the data in-place, but shifted
unsafe_copyto!(pA, pA2, n2) # move the data to the beginning of the array
else
fill_array!(r.state, pA, n2, I[])
fill_array!(r, pA, n2, I[])
end
for i=n2+1:n
A[i] = rand(r, I[])
Expand Down Expand Up @@ -501,7 +509,9 @@ function randjump(mt::MersenneTwister, jumps::Integer, jumppoly::AbstractString)
push!(mts, mt)
for i in 1:jumps-1
cmt = mts[end]
push!(mts, MersenneTwister(copy(cmt.seed), dSFMT.dsfmt_jump(cmt.state, jumppoly)))
newrng = MersenneTwister(copy(cmt.seed), dSFMT.dsfmt_jump(cmt.state, jumppoly))
MPZ.set_ui!(newrng.advance, Int64(10)^20)
push!(mts, newrng)
end
return mts
end
Expand Down
2 changes: 2 additions & 0 deletions base/random/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X
function __init__()
try
srand()
GLOBAL_RNG.advance = big(0)
GLOBAL_RNG.advance_last_vals = big(0)
catch ex
Base.showerror_nostdio(ex,
"WARNING: Error during initialization of module Random")
Expand Down

0 comments on commit da45200

Please sign in to comment.