Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix flaky test test_random:test_randint_generator #13498

Merged
merged 7 commits into from
Dec 7, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,7 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000):

Usually the user is required to specify the probs parameter.

After obtatining the p value, we could further use the standard p > 0.05 threshold to get \
After obtaining the p value, we could further use the standard p > 0.05 (alpha) threshold to get \
the final result.

Examples::
Expand Down Expand Up @@ -1920,7 +1920,7 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000):
_, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq)
return p, obs_freq, expected_freq

def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, success_rate=0.15):
def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, success_rate=0.25, alpha=0.05):
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
"""Verify whether the generator is correct using chi-square testing.

The test is repeated for "nrepeat" times and we check if the success rate is
Expand All @@ -1943,6 +1943,8 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc
The times to repeat the test
success_rate: float
The desired success rate
alpha: float
The desired threshold for type-I error i.e. when a true null hypothesis is rejected

Returns
-------
Expand All @@ -1958,7 +1960,7 @@ def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, suc
cs_ret_l.append(cs_ret)
obs_freq_l.append(obs_freq)
expected_freq_l.append(expected_freq)
success_num = (np.array(cs_ret_l) > 0.05).sum()
success_num = (np.array(cs_ret_l) > alpha).sum()
if success_num < nrepeat * success_rate:
raise AssertionError("Generator test fails, Chi-square p=%s, obs_freq=%s, expected_freq=%s."
"\nbuckets=%s, probs=%s"
Expand Down
11 changes: 6 additions & 5 deletions tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,8 +859,8 @@ def test_randint_extremes():
a = mx.nd.random.randint(dtype='int64', low=50000000, high=50000010, ctx=mx.context.current_context())
assert a>=50000000 and a<=50000010

@with_seed()
@unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/13446")
# Seed set because the test is not robust enough to operate on random data
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
@with_seed(1234)
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
def test_randint_generator():
ctx = mx.context.current_context()
for dtype in ['int32', 'int64']:
Expand All @@ -871,14 +871,15 @@ def test_randint_generator():
buckets = np.array(buckets, dtype=dtype).tolist()
probs = [(buckets[i][1] - buckets[i][0]) / float(scale) for i in range(5)]
generator_mx = lambda x: mx.nd.random.randint(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy()
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
verify_generator(generator=generator_mx, buckets=buckets, probs=probs, nrepeat=100, alpha=0.01)
# Scipy uses alpha = 0.01 for testing discrete distribution generator
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
generator_mx_same_seed = \
lambda x: np.concatenate(
[mx.nd.random.randint(low, high, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
for _ in range(10)])
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs, nrepeat=100, alpha=0.01)

with_seed()
@with_seed()
def test_randint_without_dtype():
a = mx.nd.random.randint(low=50000000, high=50000010, ctx=mx.context.current_context())
assert(a.dtype, 'int32')
Expand Down