From cf6e8cbd035bf315b3e8280416468a629c780d03 Mon Sep 17 00:00:00 2001 From: Chaitanya Prakash Bapat Date: Wed, 5 Dec 2018 23:20:08 -0800 Subject: [PATCH] Chi_square_check for discrete distribution fix (#13543) * check for bucket instead of index * enumerate instead of range(len()) * count instead of sum to solve attribute error * revert to sum * seperate discrete and continuous * Trigger CI --- python/mxnet/test_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 14875601cd25..26f7762ca9b5 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1911,12 +1911,15 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000): if continuous_dist: sample_bucket_ids = np.searchsorted(buckets_npy, samples, side='right') else: - sample_bucket_ids = samples + sample_bucket_ids = np.array(samples) if continuous_dist: sample_bucket_ids = sample_bucket_ids // 2 obs_freq = np.zeros(shape=len(buckets), dtype=np.int) - for i in range(len(buckets)): - obs_freq[i] = (sample_bucket_ids == i).sum() + for i, _ in enumerate(buckets): + if continuous_dist: + obs_freq[i] = (sample_bucket_ids == i).sum() + else: + obs_freq[i] = (sample_bucket_ids == buckets[i]).sum() _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq) return p, obs_freq, expected_freq