diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 14875601cd25..26f7762ca9b5 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1911,12 +1911,15 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000): if continuous_dist: sample_bucket_ids = np.searchsorted(buckets_npy, samples, side='right') else: - sample_bucket_ids = samples + sample_bucket_ids = np.array(samples) if continuous_dist: sample_bucket_ids = sample_bucket_ids // 2 obs_freq = np.zeros(shape=len(buckets), dtype=np.int) - for i in range(len(buckets)): - obs_freq[i] = (sample_bucket_ids == i).sum() + for i, _ in enumerate(buckets): + if continuous_dist: + obs_freq[i] = (sample_bucket_ids == i).sum() + else: + obs_freq[i] = (sample_bucket_ids == buckets[i]).sum() _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq) return p, obs_freq, expected_freq