Skip to content

Commit

Permalink
MersenneTwister: new constructors matching show
Browse files Browse the repository at this point in the history
E.g.
```
julia> m = MersenneTwister(0); rand(m); m
MersenneTwister(0, (0, 1002, 0, 1))

julia> m == MersenneTwister(0, (0, 1002, 0, 1))
true
```
  • Loading branch information
rfourquet committed Oct 4, 2020
1 parent be733c0 commit 7a19359
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
104 changes: 104 additions & 0 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,3 +721,107 @@ function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
s.adv_jump = adv_jump
s
end

# NON-PUBLIC
function jump(r::MersenneTwister, steps::Integer)
iseven(steps) || throw(DomainError(steps, "steps must be even"))
# steps >= 0 checked in calc_jump (`steps >> 1 < 0` if `steps < 0`)
j = _randjump(r, Random.DSFMT.calc_jump(steps >> 1))
j.adv_jump += steps
j
end

# NON-PUBLIC
jump!(r::MersenneTwister, steps::Integer) = copy!(r, jump(r, steps))


### constructors matching show (EXPERIMENTAL)

# parameters in the tuples are:
# 1: .adv_jump (jump steps)
# 2: .adv (number of generated floats at the DSFMT_state level since seeding, besides jumps)
# 3, 4: .adv_vals, .idxF (counters to reconstruct the float chache, optional if 5-8 not shown))
# 5-8: .adv_ints, .adv_vals_pre, .adv_idxF_pre, .idxI (counters to reconstruct the integer chache, optional)

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{8,Integer}) =
advance!(MersenneTwister(seed), advance...)

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{4,Integer}) =
MersenneTwister(seed, (advance..., -1, -1, -1, -1))

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{2,Integer}) =
MersenneTwister(seed, (advance..., 0, 0, -1, -1, -1, -1))

# advances raw state (per fill_array!) of r by n steps (Float64 values)
function _advance_n!(r::MersenneTwister, n::Int64, work::Vector{Float64})
n == 0 && return
n < 0 && throw(DomainError(n, "can't advance $r to the specified state"))
ms = dsfmt_get_min_array_size() % Int64
@assert n >= ms
lw = ms + n % ms
resize!(work, lw)
GC.@preserve work fill_array!(r, pointer(work), lw, CloseOpen12())
c::Int64 = lw
GC.@preserve work while n > c
fill_array!(r, pointer(work), ms, CloseOpen12())
c += ms
end
@assert n == c
end

function _advance_to!(r::MersenneTwister, adv::Int64, work)
_advance_n!(r, adv - r.adv, work)
@assert r.adv == adv
end

function _advance_F!(r::MersenneTwister, adv_vals, idxF, work)
if adv_vals == idxF == 0
# this case happens only when integer cache was generated before float cache
# then (0, 0) is printed instead of (-1, MT_CACHE_F) which is somewhat confusing;
# in this case, nothing to do, the float cache mustn't be filled
if r.adv_vals == -1 && r.idxF == MT_CACHE_F
return
else
throw(DomainError(n, "can't advance $r to the specified state"))
end
end
if r.adv_vals != adv_vals
_advance_to!(r, adv_vals, work)
gen_rand(r)
@assert r.adv_vals == adv_vals
end # otherwise, advancing was done automatically while generating the integer cache

r.idxF = idxF
nothing
end

function _advance_I!(r::MersenneTwister, adv_ints, idxI, work)
_advance_to!(r, adv_ints, work)
mt_setfull!(r, Int) # sets r.adv_ints
@assert r.adv_ints == adv_ints
r.idxI = 16*length(r.ints) - 8*idxI
end

function advance!(r::MersenneTwister, adv_jump, adv, adv_vals, idxF,
adv_ints, adv_vals_pre, adv_idxF_pre, idxI)
adv_jump = BigInt(adv_jump)
adv, adv_vals, adv_ints, adv_vals_pre = Int64.((adv, adv_vals, adv_ints, adv_vals_pre))
idxF, adv_idxF_pre, idxI = Int.((idxF, adv_idxF_pre, idxI))

ms = dsfmt_get_min_array_size() % Int
work = sizehint!(Vector{Float64}(), 2ms)
jump!(r, adv_jump)
if adv_vals_pre != -1
_advance_F!(r, adv_vals_pre, adv_idxF_pre, work)
_advance_I!(r, adv_ints, idxI, work)

@assert r.adv_vals_pre == adv_vals_pre ||
r.adv_vals_pre == -1 && adv_vals_pre == 0
@assert r.adv_idxF_pre == adv_idxF_pre ||
r.adv_idxF_pre == 1002 && adv_idxF_pre == 0

end
_advance_F!(r, adv_vals, idxF, work)
_advance_to!(r, adv, work)
r
end
25 changes: 25 additions & 0 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -833,3 +833,28 @@ end
end
@test length(s) == n
end

@testset "show" begin
m = MersenneTwister(123)
@test string(m) == "MersenneTwister(123)"
Random.jump!(m, 2*big(10)^20)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 0))"
@test m == MersenneTwister(123, (200000000000000000000, 0))
rand(m)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 1002, 0, 1))"
@test m == MersenneTwister(123, (200000000000000000000, 1002, 0, 1))
rand(m, Int64)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 2002, 0, 255, 1002, 0, 1, 1))"
@test m == MersenneTwister(123, (200000000000000000000, 2002, 0, 255, 1002, 0, 1, 1))

m = MersenneTwister(0x0ecfd77f89dcd508caa37a17ebb7556b)
@test string(m) == "MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b)"
rand(m, Int64)
@test string(m) == "MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 2002, 1000, 254, 0, 0, 0, 1))"
@test m == MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 2002, 1000, 254, 0, 0, 0, 1))

# test when floats advancing is done by initializing ints, and (few) floats are then generated
m = MersenneTwister(0); rand(m, Int64); rand(m)
@test string(m) == "MersenneTwister(0, (0, 2002, 1000, 255, 0, 0, 0, 1))"
@test m == MersenneTwister(0, (0, 2002, 1000, 255, 0, 0, 0, 1))
end

0 comments on commit 7a19359

Please sign in to comment.