From 9876e0e25eefddd1ea825b8fedc9a789bfb61a89 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Sat, 19 Oct 2019 06:01:58 +0000 Subject: [PATCH] add type switch to weight tensor --- src/operator/numpy/random/np_choice_op.h | 20 +++++++++++++------- tests/python/unittest/test_numpy_op.py | 21 +++++++++++---------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/operator/numpy/random/np_choice_op.h b/src/operator/numpy/random/np_choice_op.h index 335cc2741759..a6a7cecfefd5 100644 --- a/src/operator/numpy/random/np_choice_op.h +++ b/src/operator/numpy/random/np_choice_op.h @@ -118,15 +118,17 @@ struct random_indices { // Weighted sample without replacement. // Use perturbed Gumbel variates as keys. +template struct generate_keys { - MSHADOW_XINLINE static void Map(index_t i, float *uniforms, float *weights) { + MSHADOW_XINLINE static void Map(index_t i, float *uniforms, IType *weights) { uniforms[i] = -logf(-logf(uniforms[i])) + logf(weights[i]); } }; // Weighted sample with replacement. +template struct categorical_sampling { - MSHADOW_XINLINE static void Map(index_t i, float *weights, size_t length, + MSHADOW_XINLINE static void Map(index_t i, IType *weights, size_t length, float *uniforms, int64_t *outs) { outs[i] = 0; float acc = 0.0; @@ -179,15 +181,19 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, prnd->SampleUniform(&random_numbers, 0, 1); workspace_ptr += ((random_tensor_size * sizeof(float) / 7 + 1) * 8); if (replace) { - Kernel::Launch( - s, output_size, inputs[weight_index].dptr(), input_size, - random_numbers.dptr_, outputs[0].dptr()); + MSHADOW_REAL_TYPE_SWITCH(inputs[weight_index].type_flag_, IType, { + Kernel, xpu>::Launch( + s, output_size, inputs[weight_index].dptr(), input_size, + random_numbers.dptr_, outputs[0].dptr()); + }); } else { Tensor indices = Tensor( reinterpret_cast(workspace_ptr), Shape1(indices_size), s); indices = expr::range((int64_t)0, input_size); - Kernel::Launch(s, input_size, random_numbers.dptr_, - inputs[weight_index].dptr()); + MSHADOW_REAL_TYPE_SWITCH(inputs[weight_index].type_flag_, IType, { + Kernel, xpu>::Launch(s, input_size, random_numbers.dptr_, + inputs[weight_index].dptr()); + }); _sort(random_numbers.dptr_, indices.dptr_, input_size); Copy(outputs[0].FlatTo1D(s), indices.Slice(0, output_size), s); } diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2303c9cee29c..b0328935b68e 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2368,16 +2368,17 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): # test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5, weight) # Test hypridize mode: - for hybridize in [True, False]: - for replace in [True, False]: - test_choice = TestUniformChoice(num_classes // 2, replace) - test_choice_weighted = TestWeightedChoice(num_classes // 2, replace) - if hybridize: - test_choice.hybridize() - test_choice_weighted.hybridize() - weight = np.array(_np.random.dirichlet([1.0] * num_classes)) - test_indexing_mode(test_choice, num_classes, num_classes // 2, replace, None) - test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) + for wtype in ['float16', 'float32', 'float64']: + for hybridize in [True, False]: + for replace in [True, False]: + test_choice = TestUniformChoice(num_classes // 2, replace) + test_choice_weighted = TestWeightedChoice(num_classes // 2, replace) + if hybridize: + test_choice.hybridize() + test_choice_weighted.hybridize() + weight = np.array(_np.random.dirichlet([1.0] * num_classes)).astype(wtype) + test_indexing_mode(test_choice, num_classes, num_classes // 2, replace, None) + test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) @with_seed()