From 4328e30dc53da2983fa747876e9b546dea726ba5 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 4 Dec 2018 17:45:03 -0800 Subject: [PATCH 1/6] check for bucket instead of index --- python/mxnet/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 14875601cd25..bf75c0f4c745 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1916,7 +1916,7 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000): sample_bucket_ids = sample_bucket_ids // 2 obs_freq = np.zeros(shape=len(buckets), dtype=np.int) for i in range(len(buckets)): - obs_freq[i] = (sample_bucket_ids == i).sum() + obs_freq[i] = (sample_bucket_ids == buckets[i]).sum() _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq) return p, obs_freq, expected_freq From c989695c4efc6c747d54b56efcbe2bdcd5c4f09c Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 4 Dec 2018 18:03:57 -0800 Subject: [PATCH 2/6] enumerate instead of range(len()) --- python/mxnet/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index bf75c0f4c745..41a1fa7a5a1d 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1915,8 +1915,8 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000): if continuous_dist: sample_bucket_ids = sample_bucket_ids // 2 obs_freq = np.zeros(shape=len(buckets), dtype=np.int) - for i in range(len(buckets)): - obs_freq[i] = (sample_bucket_ids == buckets[i]).sum() + for i, _ in enumerate(buckets): + obs_freq[i] = (sample_bucket_ids == buckets[i]).astype('uint8').sum() _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq) return p, obs_freq, expected_freq From d16dd07cfb88e7959343f9f0adb104dea227ec30 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 4 Dec 2018 22:08:09 -0800 Subject: [PATCH 3/6] count instead of sum to solve attribute error --- python/mxnet/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 41a1fa7a5a1d..7ac6b74820da 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1916,7 +1916,7 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000): sample_bucket_ids = sample_bucket_ids // 2 obs_freq = np.zeros(shape=len(buckets), dtype=np.int) for i, _ in enumerate(buckets): - obs_freq[i] = (sample_bucket_ids == buckets[i]).astype('uint8').sum() + obs_freq[i] = sample_bucket_ids.count(buckets[i]) _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq) return p, obs_freq, expected_freq From 128e9f964c054f28c2464aa5966386f90622bc6b Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Tue, 4 Dec 2018 23:38:54 -0800 Subject: [PATCH 4/6] revert to sum --- python/mxnet/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 7ac6b74820da..a33df1b17d53 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1916,7 +1916,7 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000): sample_bucket_ids = sample_bucket_ids // 2 obs_freq = np.zeros(shape=len(buckets), dtype=np.int) for i, _ in enumerate(buckets): - obs_freq[i] = sample_bucket_ids.count(buckets[i]) + obs_freq[i] = (sample_bucket_ids == buckets[i]).sum() _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq) return p, obs_freq, expected_freq From 7df22dd98046c0e823518066d7fd6643a4ed7710 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 5 Dec 2018 10:24:15 -0800 Subject: [PATCH 5/6] seperate discrete and continuous --- python/mxnet/test_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index a33df1b17d53..26f7762ca9b5 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1911,12 +1911,15 @@ def chi_square_check(generator, buckets, probs, nsamples=1000000): if continuous_dist: sample_bucket_ids = np.searchsorted(buckets_npy, samples, side='right') else: - sample_bucket_ids = samples + sample_bucket_ids = np.array(samples) if continuous_dist: sample_bucket_ids = sample_bucket_ids // 2 obs_freq = np.zeros(shape=len(buckets), dtype=np.int) for i, _ in enumerate(buckets): - obs_freq[i] = (sample_bucket_ids == buckets[i]).sum() + if continuous_dist: + obs_freq[i] = (sample_bucket_ids == i).sum() + else: + obs_freq[i] = (sample_bucket_ids == buckets[i]).sum() _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq) return p, obs_freq, expected_freq From e3ec9f28254bb10a351c00312b3479219e6f38c6 Mon Sep 17 00:00:00 2001 From: ChaiBapchya Date: Wed, 5 Dec 2018 15:06:18 -0800 Subject: [PATCH 6/6] Trigger CI