Skip to content

Commit

Permalink
MersenneTwister: make constructors matching show
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet committed Sep 29, 2020
1 parent 80bd66d commit 1439840
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
102 changes: 102 additions & 0 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,105 @@ function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
s.adv_jump = adv_jump
s
end

# NON-PUBLIC
function jump(r::MersenneTwister, steps::Integer)
iseven(steps) || throw(ArgumentError("steps must be even"))
# steps >= 0 checked in calc_jump (`steps >> 1 < 0` if `steps < 0`)
j = _randjump(r, Random.DSFMT.calc_jump(steps >> 1))
j.adv_jump += steps
j
end

# NON-PUBLIC
jump!(r::MersenneTwister, steps::Integer) = (copy!(r, jump(r, steps)); r)


### constructors matching show (EXPERIMENTAL)

# parameters in the tuples are:
# 1: jump steps
# 2: number of generated floats at the DSFMT_state level since seeding, besides jumps
# 3, 4: counters to reconstruct the float chache (optional if 5-8 not shown)
# 5-8: counters to reconstruct the integer chache (optional)

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{8,Integer}) =
advance!(MersenneTwister(seed), advance...)

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{4,Integer}) =
MersenneTwister(seed, (advance..., -1, -1, -1, -1))

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{2,Integer}) =
MersenneTwister(seed, (advance..., 0, 0, -1, -1, -1, -1))

# advances raw state (per fill_array!) of r by n steps (Float64 values)
function _advance_n!(r::MersenneTwister, n::Int, work::Vector{Float64})
n == 0 && return
n < 0 && throw(DomainError(n, "can't advance $r to the specified state"))
ms = dsfmt_get_min_array_size() % Int
@assert n >= ms "$n >= $ms"
lw = ms + n % ms
resize!(work, lw)
GC.@preserve work fill_array!(r, pointer(work), lw, CloseOpen12())
c = lw
GC.@preserve work while n - c > 0
fill_array!(r, pointer(work), ms, CloseOpen12())
c += ms
end
@assert n - c == 0
end

function _advance_to!(r::MersenneTwister, adv::Int, work)
_advance_n!(r, adv - r.adv, work)
@assert r.adv == adv
end

function _advance_F!(r::MersenneTwister, adv_vals, idxF, work)
if adv_vals == idxF == 0
# this case happens only when integer cache was generated before float cache
# then (0, 0) is printed instead of (-1, MT_CACHE_F) which is somewhat confusing;
# in this case, nothing to do, the float cache mustn't be filled
if r.adv_vals == -1 && r.idxF == MT_CACHE_F
return
else
throw(DomainError(n, "can't advance $r to the specified state"))
end
end
if r.adv_vals != adv_vals
_advance_to!(r, adv_vals, work)
gen_rand(r)
@assert r.adv_vals == adv_vals
r.idxF = idxF
else
# advancing was done automatically while generating the integer cache
@assert r.idxF == idxF
end
nothing
end

function _advance_I!(r::MersenneTwister, adv_ints, idxI, work)
_advance_to!(r, adv_ints, work)
mt_setfull!(r, Int) # sets r.adv_ints
@assert r.adv_ints == adv_ints
r.idxI = 16*length(r.ints) - 8*idxI
end

function advance!(r::MersenneTwister, adv_jump, adv, adv_vals, idxF,
adv_ints, adv_vals_pre, adv_idxF_pre, idxI)
ms = dsfmt_get_min_array_size() % Int
work = sizehint!(Vector{Float64}(), 2ms)
jump!(r, adv_jump)
if adv_vals_pre != -1
_advance_F!(r, adv_vals_pre, adv_idxF_pre, work)
_advance_I!(r, adv_ints, idxI, work)

@assert r.adv_vals_pre == adv_vals_pre ||
r.adv_vals_pre == -1 && adv_vals_pre == 0
@assert r.adv_idxF_pre == adv_idxF_pre ||
r.adv_idxF_pre == 1002 && adv_idxF_pre == 0

end
_advance_F!(r, adv_vals, idxF, work)
_advance_to!(r, adv, work)
r
end
20 changes: 20 additions & 0 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,23 @@ end
@testset "RNGs broadcast as scalars: T" for T in (MersenneTwister, RandomDevice)
@test length.(rand.(T(), 1:3)) == 1:3
end

@testset "show" begin
m = MersenneTwister(123)
@test string(m) == "MersenneTwister(123)"
Random.jump!(m, 2*big(10)^20)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 0))"
@test m == MersenneTwister(123, (200000000000000000000, 0))
rand(m)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 1002, 0, 1))"
@test m == MersenneTwister(123, (200000000000000000000, 1002, 0, 1))
rand(m, Int)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 2002, 0, 255, 1002, 0, 1, 1))"
@test m == MersenneTwister(123, (200000000000000000000, 2002, 0, 255, 1002, 0, 1, 1))

m = MersenneTwister(0x0ecfd77f89dcd508caa37a17ebb7556b)
@test string(m) == "MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b)"
rand(m, Int)
@test string(m) == "MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 2002, 1000, 254, 0, 0, 0, 1))"
@test m == MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 2002, 1000, 254, 0, 0, 0, 1))
end

0 comments on commit 1439840

Please sign in to comment.