From c50226f23385b090b29e25b884419b071f2e8ff4 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Tue, 3 Sep 2019 06:18:28 +0000 Subject: [PATCH 1/8] imperative choice done --- python/mxnet/ndarray/numpy/random.py | 25 +- python/mxnet/numpy/random.py | 6 +- src/operator/numpy/random/np_choice_op.cc | 57 +++++ src/operator/numpy/random/np_choice_op.cu | 25 ++ src/operator/numpy/random/np_choice_op.h | 274 ++++++++++++++++++++++ tests/python/unittest/test_numpy_op.py | 36 +++ 6 files changed, 421 insertions(+), 2 deletions(-) create mode 100644 src/operator/numpy/random/np_choice_op.cc create mode 100644 src/operator/numpy/random/np_choice_op.cu create mode 100644 src/operator/numpy/random/np_choice_op.h diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index be918615bfd9..99bc1e8c92e7 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -20,7 +20,7 @@ from ...context import current_context from . import _internal as _npi -__all__ = ['uniform'] +__all__ = ['uniform', 'choice'] def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): @@ -79,3 +79,26 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): raise ValueError( "Distribution parameters must be either mxnet.numpy.ndarray or numbers") + + +def choice(a, size=None, replace=True, p=None, ctx=None, out=None): + from ...numpy import ndarray as np_ndarray + if ctx is None: + ctx = current_context() + if out is not None: + size = out.shape + if size == (): + size = None + + if isinstance(a, np_ndarray): + if p is None: + indices = _npi.choice(a, a=None, size=size, replace=replace, ctx=ctx, weighted=False) + return a[indices] + else: + indices = _npi.choice(a, p, a=None, size=size, replace=replace, ctx=ctx, weighted=True) + return a[indices] + else: + if p is None: + return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False) + else: + return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index f85936345b7f..6bd9f176e45c 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -20,7 +20,7 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np -__all__ = ['uniform'] +__all__ = ['uniform', 'choice'] def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): @@ -55,3 +55,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): Drawn samples from the parameterized uniform distribution. """ return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out) + + +def choice(a, size=None, replace=True, p=None, ctx=None, out=None): + return _mx_nd_np.random.choice(a, size, replace, p, ctx, out) diff --git a/src/operator/numpy/random/np_choice_op.cc b/src/operator/numpy/random/np_choice_op.cc new file mode 100644 index 000000000000..4a3c5c190629 --- /dev/null +++ b/src/operator/numpy/random/np_choice_op.cc @@ -0,0 +1,57 @@ +#include "./np_choice_op.h" +#include + +namespace mxnet { +namespace op { + +template<> +void _swap(int64_t& a, int64_t& b) { + std::swap(a, b); +} + +template<> +void _sort(float* key, int64_t* data, index_t length) { + std::sort(data, data + length, + [key](int64_t const& i, int64_t const& j) -> bool { + return key[i] > key[j]; + }); +} + + +DMLC_REGISTER_PARAMETER(NumpyChoiceParam); + + +NNVM_REGISTER_OP(_npi_choice) +.describe("random choice") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + int num_input = 0; + const NumpyChoiceParam& param = nnvm::get(attrs.parsed); + if (param.weighted) num_input += 1; + if (!param.a.has_value()) num_input += 1; + return num_input; + } +) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"input1", "input2"}; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyChoiceOpShape) +.set_attr("FInferType", NumpyChoiceOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyChoiceForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_argument("input2", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyChoiceParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet + + diff --git a/src/operator/numpy/random/np_choice_op.cu b/src/operator/numpy/random/np_choice_op.cu new file mode 100644 index 000000000000..f7967346a2a8 --- /dev/null +++ b/src/operator/numpy/random/np_choice_op.cu @@ -0,0 +1,25 @@ +#include "./np_choice_op.h" +#include +#include +#include + +namespace mxnet { +namespace op { + +template<> +void _swap(int64_t& a, int64_t& b) { + thrust::swap(a, b); +} + +template<> +void _sort(float* key, int64_t* data, index_t length) { + thrust::device_ptr dev_key(key); + thrust::device_ptr dev_data(data); + thrust::sort_by_key(dev_key, dev_key + length, dev_data, thrust::greater()); +} + +NNVM_REGISTER_OP(_npi_choice) +.set_attr("FCompute", NumpyChoiceForward); + +} // namespace op +} // namespace mxnet \ No newline at end of file diff --git a/src/operator/numpy/random/np_choice_op.h b/src/operator/numpy/random/np_choice_op.h new file mode 100644 index 000000000000..9c8d83ac4e47 --- /dev/null +++ b/src/operator/numpy/random/np_choice_op.h @@ -0,0 +1,274 @@ +#ifndef NP_CHOICE_OP_H +#define NP_CHOICE_OP_H + +#include +#include +#include +#include +#include +#include "../../elemwise_op_common.h" +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../tensor/elemwise_binary_broadcast_op.h" + +namespace mxnet { +namespace op { + +struct NumpyChoiceParam: public dmlc::Parameter { + dmlc::optional a; + // int64_t a; + std::string ctx; + dmlc::optional> size; + // int64_t size; + bool replace; + bool weighted; + DMLC_DECLARE_PARAMETER(NumpyChoiceParam) { + DMLC_DECLARE_FIELD(a); + DMLC_DECLARE_FIELD(size); + DMLC_DECLARE_FIELD(ctx) + .set_default("cpu"); + DMLC_DECLARE_FIELD(replace).set_default(true); + DMLC_DECLARE_FIELD(weighted).set_default(false); + } +}; + +inline bool NumpyChoiceOpType(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + (*out_attrs)[0] = mshadow::kInt64; + return true; +} + +inline bool NumpyChoiceOpShape(const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + // begin + const NumpyChoiceParam ¶m = nnvm::get(attrs.parsed); + int64_t a; + if (param.size.has_value()) { + // Size declared. + std::vector oshape_vec; + const mxnet::Tuple &size = param.size.value(); + for (int i = 0; i < size.ndim(); ++i) { + oshape_vec.emplace_back(size[i]); + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec)); + } else { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(0, -1)) + } + return true; +} + +template +void _swap(int64_t& a, int64_t& b); + +template +void _sort(float* key, int64_t* data, index_t length); + +namespace mxnet_op { + +// Uniform sample without replacement. +struct generate_samples { + MSHADOW_XINLINE static void Map(index_t i, int64_t k, + unsigned *rands) { + // printf("%d:%u,%d\n", i, rands[i], k); + rands[i] = rands[i] % (i + k + 1); + // printf("sample[%d]:%u\n", i, rands[i]); + } +}; + +template +struct generate_reservoir { + MSHADOW_XINLINE static void Map(index_t dummy_index, + int64_t *indices, unsigned *samples, + int64_t nb_iterations, int64_t k) { + for (int64_t i = 0; i < nb_iterations; i++) { + int64_t z = samples[i]; + // printf("z:%d\n", z); + // printf("k:%d\n", k); + if (z < k) { + // _swap(indices[z], indices[i + k]); + int64_t t = indices[z]; + indices[z] = indices[i + k]; + indices[i + k] = t; + } + } + } +}; + +// Uniform sample with replacement. +struct random_indices { + MSHADOW_XINLINE static void Map(index_t i, unsigned *samples, int64_t *outs, int64_t k) { + outs[i] = samples[i] % k; + } +}; + +// Weighted sample without replacement. +// Use perturbed Gumbel variates as keys. +struct generate_keys { + MSHADOW_XINLINE static void Map(index_t i, float* uniforms, float* weights) { + uniforms[i] = -logf(-logf(uniforms[i])) + logf(weights[i]); + } +}; + +// Weighted sample with replacement. +struct categorical_sampling { + MSHADOW_XINLINE static void Map(index_t i, float* weights, size_t length, float* uniforms, int64_t *outs) { + outs[i] = 0; + float acc = 0.0; + float threshold = uniforms[i]; + for (size_t k = 0; k < length; k++) { + acc += weights[k]; + if (acc < threshold) { + outs[i] += 1; + } + } + } +}; + +} // namespace mxnet_op + + +template +void weighted_reservoir_sampling() { + +} + +template +void reservoir_sampling() { + // allocate +} + +template +void sampling_with_replacement() { + +} + +template +void weighted_sampling_with_replacement() { + +} + + +template +void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + // forward + using namespace mshadow; + using namespace mxnet_op; + const NumpyChoiceParam ¶m = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + bool replace = param.replace; + bool weighted = param.weighted; + int64_t input_size = 0; + int weight_index = 0; + if (param.a.has_value()) { + input_size = param.a.value(); + } else { + input_size = inputs[0].Size(); + weight_index += 1; + } + int64_t output_size = outputs[0].Size(); + // printf("%p\n", workspace_ptr); + if (weighted) { + Random *prnd = ctx.requested[0].get_random(s); + int64_t random_tensor_size = replace ? output_size : input_size; + int64_t indices_size = replace ? 0 : input_size; + Tensor workspace = + (ctx.requested[1].get_space_typed( + Shape1(indices_size * sizeof(int64_t) + + (random_tensor_size * sizeof(float) / 7 + 1) * 8), + s)); + // slice workspace + char *workspace_ptr = workspace.dptr_; + Tensor random_numbers = + Tensor(reinterpret_cast(workspace_ptr), + Shape1(random_tensor_size), s); + prnd->SampleUniform(&random_numbers, 0, 1); + workspace_ptr += ((random_tensor_size * sizeof(float) / 7 + 1)* 8); + if (replace) { + Kernel::Launch(s, output_size, + inputs[weight_index].dptr(), + input_size, + random_numbers.dptr_, + outputs[0].dptr()); + } else { + Tensor indices = Tensor( + reinterpret_cast(workspace_ptr), + Shape1(indices_size), + s); + indices = expr::range((int64_t)0, input_size); + Kernel::Launch(s, input_size, + random_numbers.dptr_, + inputs[weight_index].dptr()); + _sort(random_numbers.dptr_, indices.dptr_, input_size); + Copy( + outputs[0].FlatTo1D(s), + indices.Slice(0, output_size), + s + ); + + } + } else { + Random *prnd = ctx.requested[0].get_random(s); + int64_t random_tensor_size = + (replace ? output_size + : std::min(output_size, input_size - output_size)); + int64_t indices_size = replace ? 0 : input_size; + Tensor workspace = + (ctx.requested[1].get_space_typed( + Shape1(indices_size * sizeof(int64_t) + + ((random_tensor_size * sizeof(unsigned) / 7 + 1) * 8)), + s)); + // slice workspace + char *workspace_ptr = workspace.dptr_; + Tensor random_numbers = + Tensor(reinterpret_cast(workspace_ptr), + Shape1(random_tensor_size), s); + prnd->GetRandInt(random_numbers); + workspace_ptr += ((random_tensor_size * sizeof(unsigned) / 7 + 1) * 8); + if (replace) { + Kernel::Launch(s, output_size, random_numbers.dptr_, + outputs[0].dptr(), + input_size); + } else { + Tensor indices = Tensor( + reinterpret_cast(workspace_ptr), + Shape1(indices_size), + s); + indices = expr::range((int64_t)0, input_size); + int64_t nb_iterations = random_tensor_size; + int64_t split = input_size - nb_iterations; + Kernel::Launch(s, random_tensor_size, split, + random_numbers.dptr_); + // reservoir sampling + Kernel, xpu>::Launch( + s, 1, indices.dptr_, random_numbers.dptr_, nb_iterations, split); + index_t begin; + index_t end; + if (2 * output_size < input_size) { + begin = input_size - output_size; + end = input_size; + } else { + begin = 0; + end = output_size; + } + Copy( + outputs[0].FlatTo1D(s), + indices.Slice(begin, end), + s + ); + } + } +} + + + +} // namespace op +} // namespace mxnet + +#endif /* NP_CHOICE_OP_H */ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 399cdead6177..be197116bc3b 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1080,6 +1080,42 @@ def hybrid_forward(self, F, a, *args): assert same(mx_out.asnumpy(), np_out) +@with_seed() +@use_np +def test_np_choice(): + + def test_sample_with_replacement(num_classes, shape, weight=None): + samples = np.random.choice(num_classes, shape, p=weight) + generated_density = np.histogram(samples, np.arange(num_classes + 1), density=True) + expected_density = (weight if weight is not None else + np.array([1 / num_classes] * num_classes)) + # test almost equal + print(generated_density[0] - expected_density) + # test shape + assert (samples.shape == shape) + + def test_sample_without_replacement(num_classes, shape, num_trials, weight=None): + samples = np.random.choice(num_classes, shape, replace=False, p=weight) + # Check shape and uniqueness + assert samples.shape == shape + assert len(np.unique(samples)) == samples.size + # Check distribution + bins = np.zeros((num_classes)) + expected_freq = (weight if weight is not None else + np.array([1 / num_classes] * num_classes)) + for i in range(num_trials): + out = np.random.choice(num_classes, 1, replace=False, p=weight) + bins[out] += 1 + bins /= num_trials + print(bins - expected_freq) + # assert_almost_equal(bins, expected_freq) + ctx = n + num_classes = 20 + num_samples = 10 ** 5 + # Sample with replacement: + + # Sample without replacment: + if __name__ == '__main__': import nose nose.runmodule() From 33a40fa4fab6b16a2b60c95063bfd2eb5d83e384 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 4 Sep 2019 04:38:33 +0000 Subject: [PATCH 2/8] unit test done --- python/mxnet/ndarray/numpy/random.py | 4 +- python/mxnet/symbol/numpy/random.py | 23 +++++ src/operator/numpy/random/np_choice_op.h | 33 +------- tests/python/unittest/test_numpy_op.py | 103 ++++++++++++++++++----- 4 files changed, 108 insertions(+), 55 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 5a77dcf838a7..1bc7e8675a2b 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -203,10 +203,10 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): if isinstance(a, np_ndarray): if p is None: indices = _npi.choice(a, a=None, size=size, replace=replace, ctx=ctx, weighted=False) - return a[indices] + return _npi.take(a, indices) else: indices = _npi.choice(a, p, a=None, size=size, replace=replace, ctx=ctx, weighted=True) - return a[indices] + return _npi.take(a, indices) else: if p is None: return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index c5b8e1dc4906..3da3b8afbddd 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -190,3 +190,26 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): raise NotImplementedError('np.random.normal only supports loc and scale of ' 'numeric types for now') return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) + + +def choice(a, size=None, replace=True, p=None, ctx=None, out=None): + from ._symbol import _Symbol as np_symbol + if ctx is None: + ctx = current_context() + if out is not None: + size = out.shape + if size == (): + size = None + + if isinstance(a, np_symbol): + if p is None: + indices = _npi.choice(a, a=None, size=size, replace=replace, ctx=ctx, weighted=False) + return _npi.take(a, indices) + else: + indices = _npi.choice(a, p, a=None, size=size, replace=replace, ctx=ctx, weighted=True) + return _npi.take(a, indices) + else: + if p is None: + return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False) + else: + return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) \ No newline at end of file diff --git a/src/operator/numpy/random/np_choice_op.h b/src/operator/numpy/random/np_choice_op.h index 9c8d83ac4e47..c14970bf5c0e 100644 --- a/src/operator/numpy/random/np_choice_op.h +++ b/src/operator/numpy/random/np_choice_op.h @@ -17,10 +17,8 @@ namespace op { struct NumpyChoiceParam: public dmlc::Parameter { dmlc::optional a; - // int64_t a; std::string ctx; dmlc::optional> size; - // int64_t size; bool replace; bool weighted; DMLC_DECLARE_PARAMETER(NumpyChoiceParam) { @@ -43,7 +41,6 @@ inline bool NumpyChoiceOpType(const nnvm::NodeAttrs &attrs, inline bool NumpyChoiceOpShape(const nnvm::NodeAttrs &attrs, std::vector *in_attrs, std::vector *out_attrs) { - // begin const NumpyChoiceParam ¶m = nnvm::get(attrs.parsed); int64_t a; if (param.size.has_value()) { @@ -72,9 +69,7 @@ namespace mxnet_op { struct generate_samples { MSHADOW_XINLINE static void Map(index_t i, int64_t k, unsigned *rands) { - // printf("%d:%u,%d\n", i, rands[i], k); rands[i] = rands[i] % (i + k + 1); - // printf("sample[%d]:%u\n", i, rands[i]); } }; @@ -85,10 +80,7 @@ struct generate_reservoir { int64_t nb_iterations, int64_t k) { for (int64_t i = 0; i < nb_iterations; i++) { int64_t z = samples[i]; - // printf("z:%d\n", z); - // printf("k:%d\n", k); if (z < k) { - // _swap(indices[z], indices[i + k]); int64_t t = indices[z]; indices[z] = indices[i + k]; indices[i + k] = t; @@ -129,28 +121,6 @@ struct categorical_sampling { } // namespace mxnet_op - -template -void weighted_reservoir_sampling() { - -} - -template -void reservoir_sampling() { - // allocate -} - -template -void sampling_with_replacement() { - -} - -template -void weighted_sampling_with_replacement() { - -} - - template void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -173,7 +143,6 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, weight_index += 1; } int64_t output_size = outputs[0].Size(); - // printf("%p\n", workspace_ptr); if (weighted) { Random *prnd = ctx.requested[0].get_random(s); int64_t random_tensor_size = replace ? output_size : input_size; @@ -245,7 +214,7 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, int64_t split = input_size - nb_iterations; Kernel::Launch(s, random_tensor_size, split, random_numbers.dptr_); - // reservoir sampling + // Reservoir sampling. Kernel, xpu>::Launch( s, 1, indices.dptr_, random_numbers.dptr_, nb_iterations, split); index_t begin; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 3caed8b5adf1..e719bea54950 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1650,38 +1650,99 @@ def hybrid_forward(self, F, a): @with_seed() @use_np def test_np_choice(): - - def test_sample_with_replacement(num_classes, shape, weight=None): - samples = np.random.choice(num_classes, shape, p=weight) - generated_density = np.histogram(samples, np.arange(num_classes + 1), density=True) - expected_density = (weight if weight is not None else - np.array([1 / num_classes] * num_classes)) + class TestUniformChoice(HybridBlock): + def __init__(self, sample_size, replace): + super(TestUniformChoice, self).__init__() + self.sample_size = sample_size + self.replace = replace + + def hybrid_forward(self, F, a): + # op = getattr(F.np.random, "choice", None) + # return a + op(size=self.sample_size, replace=self.replace) + return F.np.random.choice(a=a, size=self.sample_size, replace=self.replace, p=None) + + class TestWeightedChoice(HybridBlock): + def __init__(self, sample_size, replace): + super(TestWeightedChoice, self).__init__() + self.sample_size = sample_size + self.replace = replace + + def hybrid_forward(self, F, a, p): + op = getattr(F.np.random, "choice", None) + return F.np.random.choice(a, self.sample_size, self.replace, p) + + def test_sample_with_replacement(sampler, num_classes, shape, weight=None): + samples = sampler(num_classes, shape, replace=True, p=weight).asnumpy() + generated_density = _np.histogram(samples, _np.arange(num_classes + 1), density=True)[0] + expected_density = (weight.asnumpy() if weight is not None else + _np.array([1 / num_classes] * num_classes)) # test almost equal - print(generated_density[0] - expected_density) + assert_almost_equal(generated_density, expected_density, rtol=1e-1, atol=1e-2) # test shape assert (samples.shape == shape) - def test_sample_without_replacement(num_classes, shape, num_trials, weight=None): - samples = np.random.choice(num_classes, shape, replace=False, p=weight) + def test_sample_without_replacement(sampler, num_classes, shape, num_trials, weight=None): + samples = sampler(num_classes, shape, replace=False, p=weight).asnumpy() # Check shape and uniqueness assert samples.shape == shape - assert len(np.unique(samples)) == samples.size + assert len(_np.unique(samples)) == samples.size # Check distribution - bins = np.zeros((num_classes)) - expected_freq = (weight if weight is not None else - np.array([1 / num_classes] * num_classes)) + bins = _np.zeros((num_classes)) + expected_freq = (weight.asnumpy() if weight is not None else + _np.array([1 / num_classes] * num_classes)) for i in range(num_trials): - out = np.random.choice(num_classes, 1, replace=False, p=weight) + out = sampler(num_classes, 1, replace=False, p=weight).item() bins[out] += 1 bins /= num_trials - print(bins - expected_freq) - # assert_almost_equal(bins, expected_freq) - ctx = n - num_classes = 20 - num_samples = 10 ** 5 - # Sample with replacement: + assert_almost_equal(bins, expected_freq, rtol=1e-1, atol=1e-2) + + def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): + a = np.arange(set_size) + if weight is not None: + samples = sampler(a, p) + else: + samples = sampler(a) + assert len(samples) == samples_size + if not replace: + assert len(_np.unique(samples)) == samples_size + + num_classes = 10 + num_samples = 10 ** 8 + # for hybridize in [True, False]: + # test sample with replacement + shape_list1 = [ + (10 ** 8, 1), + (10 ** 5, 10 ** 3), + (10 ** 2, 10 ** 3, 10 ** 3) + ] + for shape in shape_list1: + test_sample_with_replacement(np.random.choice, num_classes, shape) + weight = np.array(_np.random.dirichlet([1.0] * num_classes)) + test_sample_with_replacement(np.random.choice, num_classes, shape, weight) - # Sample without replacment: + shape_list2 = [ + (6, 1), + (2, 3), + (1, 2, 3), + (2, 2), + ] + # for shape in shape_list2: + # test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5) + # weight = np.array(_np.random.dirichlet([1.0] * num_classes)) + # test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5, weight) + + # Test hypridize mode: + for hybridize in [True, False]: + for replace in [True, False]: + test_choice = TestUniformChoice(num_classes // 2, replace) + if hybridize: + test_choice.hybridize() + weight = np.array(_np.random.dirichlet([1.0] * num_classes)) + test_indexing_mode(test_choice, num_classes, num_classes // 2, None) + + + + if __name__ == '__main__': import nose From 0460c92bccd7ebb27cc28c2ad58ef84542d918a1 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 4 Sep 2019 04:40:34 +0000 Subject: [PATCH 3/8] expose take to np internal --- src/operator/tensor/indexing_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 147205505e24..38ec77bdd758 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -653,6 +653,7 @@ NNVM_REGISTER_OP(_backward_SparseEmbedding) .set_attr("FComputeEx", SparseEmbeddingOpBackwardEx); NNVM_REGISTER_OP(take) +.add_alias("_npi_take") .describe(R"code(Takes elements from an input array along the given axis. This function slices the input array along a particular axis with the provided indices. From 51012d79606e49cef7862ef8250b6f7aba94ca7c Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 4 Sep 2019 06:41:02 +0000 Subject: [PATCH 4/8] style fixed --- python/mxnet/ndarray/numpy/random.py | 61 ++++++++- python/mxnet/numpy/random.py | 55 ++++++++- python/mxnet/symbol/numpy/random.py | 61 ++++++++- src/operator/numpy/random/np_choice_op.cc | 99 +++++++++------ src/operator/numpy/random/np_choice_op.cu | 51 +++++--- src/operator/numpy/random/np_choice_op.h | 144 +++++++++++----------- tests/python/unittest/test_numpy_op.py | 42 +++---- 7 files changed, 361 insertions(+), 152 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 1bc7e8675a2b..6ccf28acba5e 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -192,6 +192,59 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): def choice(a, size=None, replace=True, p=None, ctx=None, out=None): + """Generates a random sample from a given 1-D array + + Parameters + ----------- + a : 1-D array-like or int + If an ndarray, a random sample is generated from its elements. + If an int, the random sample is generated as if a were np.arange(a) + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + replace : boolean, optional + Whether the sample is with or without replacement + p : 1-D array-like, optional + The probabilities associated with each entry in a. + If not given the sample assumes a uniform distribution over all + entries in a. + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + -------- + samples : ndarray + The generated random samples + + Examples + --------- + Generate a uniform random sample from np.arange(5) of size 3: + + >>> np.random.choice(5, 3) + array([0, 3, 4]) + >>> #This is equivalent to np.random.randint(0,5,3) + + Generate a non-uniform random sample from np.arange(5) of size 3: + + >>> np.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0]) + array([3, 3, 0]) + + Generate a uniform random sample from np.arange(5) of size 3 without + replacement: + + >>> np.random.choice(5, 3, replace=False) + array([3,1,0]) + >>> #This is equivalent to np.random.permutation(np.arange(5))[:3] + + Generate a non-uniform random sample from np.arange(5) of size + 3 without replacement: + + >>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0]) + array([2, 3, 0]) + """ from ...numpy import ndarray as np_ndarray if ctx is None: ctx = current_context() @@ -202,13 +255,15 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): if isinstance(a, np_ndarray): if p is None: - indices = _npi.choice(a, a=None, size=size, replace=replace, ctx=ctx, weighted=False) + indices = _npi.choice(a, a=None, size=size, + replace=replace, ctx=ctx, weighted=False) return _npi.take(a, indices) else: - indices = _npi.choice(a, p, a=None, size=size, replace=replace, ctx=ctx, weighted=True) + indices = _npi.choice(a, p, a=None, size=size, + replace=replace, ctx=ctx, weighted=True) return _npi.take(a, indices) else: if p is None: return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False) else: - return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) \ No newline at end of file + return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 29cabfb559a4..5cebf9b1aa3d 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -147,4 +147,57 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): def choice(a, size=None, replace=True, p=None, ctx=None, out=None): - return _mx_nd_np.random.choice(a, size, replace, p, ctx, out) \ No newline at end of file + """Generates a random sample from a given 1-D array + + Parameters + ----------- + a : 1-D array-like or int + If an ndarray, a random sample is generated from its elements. + If an int, the random sample is generated as if a were np.arange(a) + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + replace : boolean, optional + Whether the sample is with or without replacement + p : 1-D array-like, optional + The probabilities associated with each entry in a. + If not given the sample assumes a uniform distribution over all + entries in a. + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + -------- + samples : ndarray + The generated random samples + + Examples + --------- + Generate a uniform random sample from np.arange(5) of size 3: + + >>> np.random.choice(5, 3) + array([0, 3, 4]) + >>> #This is equivalent to np.random.randint(0,5,3) + + Generate a non-uniform random sample from np.arange(5) of size 3: + + >>> np.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0]) + array([3, 3, 0]) + + Generate a uniform random sample from np.arange(5) of size 3 without + replacement: + + >>> np.random.choice(5, 3, replace=False) + array([3,1,0]) + >>> #This is equivalent to np.random.permutation(np.arange(5))[:3] + + Generate a non-uniform random sample from np.arange(5) of size + 3 without replacement: + + >>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0]) + array([2, 3, 0]) + """ + return _mx_nd_np.random.choice(a, size, replace, p, ctx, out) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 3da3b8afbddd..2b1a08306e1f 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -193,6 +193,59 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): def choice(a, size=None, replace=True, p=None, ctx=None, out=None): + """Generates a random sample from a given 1-D array + + Parameters + ----------- + a : 1-D array-like or int + If an ndarray, a random sample is generated from its elements. + If an int, the random sample is generated as if a were np.arange(a) + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then + ``m * n * k`` samples are drawn. Default is None, in which case a + single value is returned. + replace : boolean, optional + Whether the sample is with or without replacement + p : 1-D array-like, optional + The probabilities associated with each entry in a. + If not given the sample assumes a uniform distribution over all + entries in a. + ctx : Context, optional + Device context of output. Default is current context. + out : ``ndarray``, optional + Store output to an existing ``ndarray``. + + Returns + -------- + samples : _Symbol + The generated random samples + + Examples + --------- + Generate a uniform random sample from np.arange(5) of size 3: + + >>> np.random.choice(5, 3) + array([0, 3, 4]) + >>> #This is equivalent to np.random.randint(0,5,3) + + Generate a non-uniform random sample from np.arange(5) of size 3: + + >>> np.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0]) + array([3, 3, 0]) + + Generate a uniform random sample from np.arange(5) of size 3 without + replacement: + + >>> np.random.choice(5, 3, replace=False) + array([3,1,0]) + >>> #This is equivalent to np.random.permutation(np.arange(5))[:3] + + Generate a non-uniform random sample from np.arange(5) of size + 3 without replacement: + + >>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0]) + array([2, 3, 0]) + """ from ._symbol import _Symbol as np_symbol if ctx is None: ctx = current_context() @@ -203,13 +256,15 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): if isinstance(a, np_symbol): if p is None: - indices = _npi.choice(a, a=None, size=size, replace=replace, ctx=ctx, weighted=False) + indices = _npi.choice(a, a=None, size=size, + replace=replace, ctx=ctx, weighted=False) return _npi.take(a, indices) else: - indices = _npi.choice(a, p, a=None, size=size, replace=replace, ctx=ctx, weighted=True) + indices = _npi.choice(a, p, a=None, size=size, + replace=replace, ctx=ctx, weighted=True) return _npi.take(a, indices) else: if p is None: return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False) else: - return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) \ No newline at end of file + return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) diff --git a/src/operator/numpy/random/np_choice_op.cc b/src/operator/numpy/random/np_choice_op.cc index 4a3c5c190629..f68573e4b400 100644 --- a/src/operator/numpy/random/np_choice_op.cc +++ b/src/operator/numpy/random/np_choice_op.cc @@ -1,15 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_choice_op.cc + * \brief Operator for random subset sampling + */ + #include "./np_choice_op.h" #include namespace mxnet { namespace op { -template<> -void _swap(int64_t& a, int64_t& b) { - std::swap(a, b); -} - -template<> +template <> void _sort(float* key, int64_t* data, index_t length) { std::sort(data, data + length, [key](int64_t const& i, int64_t const& j) -> bool { @@ -17,41 +37,44 @@ void _sort(float* key, int64_t* data, index_t length) { }); } - DMLC_REGISTER_PARAMETER(NumpyChoiceParam); - NNVM_REGISTER_OP(_npi_choice) -.describe("random choice") -.set_num_inputs( - [](const nnvm::NodeAttrs& attrs) { - int num_input = 0; - const NumpyChoiceParam& param = nnvm::get(attrs.parsed); - if (param.weighted) num_input += 1; - if (!param.a.has_value()) num_input += 1; - return num_input; - } -) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"input1", "input2"}; - }) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", NumpyChoiceOpShape) -.set_attr("FInferType", NumpyChoiceOpType) -.set_attr("FResourceRequest", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{ - ResourceRequest::kRandom, ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", NumpyChoiceForward) -.set_attr("FGradient", MakeZeroGradNodes) -.add_argument("input1", "NDArray-or-Symbol", "Source input") -.add_argument("input2", "NDArray-or-Symbol", "Source input") -.add_arguments(NumpyChoiceParam::__FIELDS__()); + .describe("random choice") + .set_num_inputs([](const nnvm::NodeAttrs& attrs) { + int num_input = 0; + const NumpyChoiceParam& param = nnvm::get(attrs.parsed); + if (param.weighted) num_input += 1; + if (!param.a.has_value()) num_input += 1; + return num_input; + }) + .set_num_outputs(1) + .set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { + int num_input = 0; + const NumpyChoiceParam& param = + nnvm::get(attrs.parsed); + if (param.weighted) num_input += 1; + if (!param.a.has_value()) num_input += 1; + if (num_input == 0) return std::vector(); + if (num_input == 1) return std::vector{"input1"}; + return std::vector{"input1", "input2"}; + }) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", NumpyChoiceOpShape) + .set_attr("FInferType", NumpyChoiceOpType) + .set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, + ResourceRequest::kTempSpace}; + }) + .set_attr("FCompute", NumpyChoiceForward) + .set_attr("FGradient", MakeZeroGradNodes) + .add_argument("input1", "NDArray-or-Symbol", "Source input") + .add_argument("input2", "NDArray-or-Symbol", "Source input") + .add_arguments(NumpyChoiceParam::__FIELDS__()); } // namespace op } // namespace mxnet - - diff --git a/src/operator/numpy/random/np_choice_op.cu b/src/operator/numpy/random/np_choice_op.cu index f7967346a2a8..cf04b91c8d15 100644 --- a/src/operator/numpy/random/np_choice_op.cu +++ b/src/operator/numpy/random/np_choice_op.cu @@ -1,25 +1,46 @@ -#include "./np_choice_op.h" -#include -#include +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_choice_op.cu + * \brief Operator for random subset sampling + */ + #include +#include +#include +#include "./np_choice_op.h" namespace mxnet { namespace op { -template<> -void _swap(int64_t& a, int64_t& b) { - thrust::swap(a, b); -} - -template<> +template <> void _sort(float* key, int64_t* data, index_t length) { - thrust::device_ptr dev_key(key); - thrust::device_ptr dev_data(data); - thrust::sort_by_key(dev_key, dev_key + length, dev_data, thrust::greater()); + thrust::device_ptr dev_key(key); + thrust::device_ptr dev_data(data); + thrust::sort_by_key(dev_key, dev_key + length, dev_data, + thrust::greater()); } NNVM_REGISTER_OP(_npi_choice) -.set_attr("FCompute", NumpyChoiceForward); + .set_attr("FCompute", NumpyChoiceForward); -} // namespace op -} // namespace mxnet \ No newline at end of file +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_choice_op.h b/src/operator/numpy/random/np_choice_op.h index c14970bf5c0e..a22663b417ee 100644 --- a/src/operator/numpy/random/np_choice_op.h +++ b/src/operator/numpy/random/np_choice_op.h @@ -1,5 +1,30 @@ -#ifndef NP_CHOICE_OP_H -#define NP_CHOICE_OP_H +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_choice_op.h + * \brief Operator for random subset sampling + */ + +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_CHOICE_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_CHOICE_OP_H_ #include #include @@ -15,25 +40,24 @@ namespace mxnet { namespace op { -struct NumpyChoiceParam: public dmlc::Parameter { - dmlc::optional a; - std::string ctx; - dmlc::optional> size; - bool replace; - bool weighted; - DMLC_DECLARE_PARAMETER(NumpyChoiceParam) { - DMLC_DECLARE_FIELD(a); - DMLC_DECLARE_FIELD(size); - DMLC_DECLARE_FIELD(ctx) - .set_default("cpu"); - DMLC_DECLARE_FIELD(replace).set_default(true); - DMLC_DECLARE_FIELD(weighted).set_default(false); - } +struct NumpyChoiceParam : public dmlc::Parameter { + dmlc::optional a; + std::string ctx; + dmlc::optional> size; + bool replace; + bool weighted; + DMLC_DECLARE_PARAMETER(NumpyChoiceParam) { + DMLC_DECLARE_FIELD(a); + DMLC_DECLARE_FIELD(size); + DMLC_DECLARE_FIELD(ctx).set_default("cpu"); + DMLC_DECLARE_FIELD(replace).set_default(true); + DMLC_DECLARE_FIELD(weighted).set_default(false); + } }; inline bool NumpyChoiceOpType(const nnvm::NodeAttrs &attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { (*out_attrs)[0] = mshadow::kInt64; return true; } @@ -42,7 +66,6 @@ inline bool NumpyChoiceOpShape(const nnvm::NodeAttrs &attrs, std::vector *in_attrs, std::vector *out_attrs) { const NumpyChoiceParam ¶m = nnvm::get(attrs.parsed); - int64_t a; if (param.size.has_value()) { // Size declared. std::vector oshape_vec; @@ -58,26 +81,22 @@ inline bool NumpyChoiceOpShape(const nnvm::NodeAttrs &attrs, } template -void _swap(int64_t& a, int64_t& b); - -template -void _sort(float* key, int64_t* data, index_t length); +void _sort(float *key, int64_t *data, index_t length); namespace mxnet_op { // Uniform sample without replacement. struct generate_samples { - MSHADOW_XINLINE static void Map(index_t i, int64_t k, - unsigned *rands) { + MSHADOW_XINLINE static void Map(index_t i, int64_t k, unsigned *rands) { rands[i] = rands[i] % (i + k + 1); } }; template struct generate_reservoir { - MSHADOW_XINLINE static void Map(index_t dummy_index, - int64_t *indices, unsigned *samples, - int64_t nb_iterations, int64_t k) { + MSHADOW_XINLINE static void Map(index_t dummy_index, int64_t *indices, + unsigned *samples, int64_t nb_iterations, + int64_t k) { for (int64_t i = 0; i < nb_iterations; i++) { int64_t z = samples[i]; if (z < k) { @@ -91,7 +110,8 @@ struct generate_reservoir { // Uniform sample with replacement. struct random_indices { - MSHADOW_XINLINE static void Map(index_t i, unsigned *samples, int64_t *outs, int64_t k) { + MSHADOW_XINLINE static void Map(index_t i, unsigned *samples, int64_t *outs, + int64_t k) { outs[i] = samples[i] % k; } }; @@ -99,17 +119,18 @@ struct random_indices { // Weighted sample without replacement. // Use perturbed Gumbel variates as keys. struct generate_keys { - MSHADOW_XINLINE static void Map(index_t i, float* uniforms, float* weights) { + MSHADOW_XINLINE static void Map(index_t i, float *uniforms, float *weights) { uniforms[i] = -logf(-logf(uniforms[i])) + logf(weights[i]); } }; // Weighted sample with replacement. struct categorical_sampling { - MSHADOW_XINLINE static void Map(index_t i, float* weights, size_t length, float* uniforms, int64_t *outs) { + MSHADOW_XINLINE static void Map(index_t i, float *weights, size_t length, + float *uniforms, int64_t *outs) { outs[i] = 0; float acc = 0.0; - float threshold = uniforms[i]; + float threshold = uniforms[i]; for (size_t k = 0; k < length; k++) { acc += weights[k]; if (acc < threshold) { @@ -122,12 +143,10 @@ struct categorical_sampling { } // namespace mxnet_op template -void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - // forward +void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; using namespace mxnet_op; const NumpyChoiceParam ¶m = nnvm::get(attrs.parsed); @@ -154,48 +173,39 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, s)); // slice workspace char *workspace_ptr = workspace.dptr_; - Tensor random_numbers = + Tensor random_numbers = Tensor(reinterpret_cast(workspace_ptr), - Shape1(random_tensor_size), s); + Shape1(random_tensor_size), s); prnd->SampleUniform(&random_numbers, 0, 1); - workspace_ptr += ((random_tensor_size * sizeof(float) / 7 + 1)* 8); + workspace_ptr += ((random_tensor_size * sizeof(float) / 7 + 1) * 8); if (replace) { - Kernel::Launch(s, output_size, - inputs[weight_index].dptr(), - input_size, - random_numbers.dptr_, - outputs[0].dptr()); + Kernel::Launch( + s, output_size, inputs[weight_index].dptr(), input_size, + random_numbers.dptr_, outputs[0].dptr()); } else { Tensor indices = Tensor( - reinterpret_cast(workspace_ptr), - Shape1(indices_size), - s); + reinterpret_cast(workspace_ptr), Shape1(indices_size), s); indices = expr::range((int64_t)0, input_size); - Kernel::Launch(s, input_size, - random_numbers.dptr_, + Kernel::Launch(s, input_size, random_numbers.dptr_, inputs[weight_index].dptr()); _sort(random_numbers.dptr_, indices.dptr_, input_size); - Copy( - outputs[0].FlatTo1D(s), - indices.Slice(0, output_size), - s - ); - + Copy(outputs[0].FlatTo1D(s), indices.Slice(0, output_size), + s); } } else { Random *prnd = ctx.requested[0].get_random(s); - int64_t random_tensor_size = + int64_t random_tensor_size = (replace ? output_size : std::min(output_size, input_size - output_size)); int64_t indices_size = replace ? 0 : input_size; Tensor workspace = (ctx.requested[1].get_space_typed( Shape1(indices_size * sizeof(int64_t) + - ((random_tensor_size * sizeof(unsigned) / 7 + 1) * 8)), + (random_tensor_size * sizeof(unsigned) / 7 + 1) * 8), s)); // slice workspace char *workspace_ptr = workspace.dptr_; - Tensor random_numbers = + Tensor random_numbers = Tensor(reinterpret_cast(workspace_ptr), Shape1(random_tensor_size), s); prnd->GetRandInt(random_numbers); @@ -206,9 +216,7 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, input_size); } else { Tensor indices = Tensor( - reinterpret_cast(workspace_ptr), - Shape1(indices_size), - s); + reinterpret_cast(workspace_ptr), Shape1(indices_size), s); indices = expr::range((int64_t)0, input_size); int64_t nb_iterations = random_tensor_size; int64_t split = input_size - nb_iterations; @@ -226,18 +234,12 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, begin = 0; end = output_size; } - Copy( - outputs[0].FlatTo1D(s), - indices.Slice(begin, end), - s - ); + Copy(outputs[0].FlatTo1D(s), indices.Slice(begin, end), s); } } } - - } // namespace op } // namespace mxnet -#endif /* NP_CHOICE_OP_H */ +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_CHOICE_OP_H_ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index e719bea54950..5fa330aaffcd 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1657,8 +1657,6 @@ def __init__(self, sample_size, replace): self.replace = replace def hybrid_forward(self, F, a): - # op = getattr(F.np.random, "choice", None) - # return a + op(size=self.sample_size, replace=self.replace) return F.np.random.choice(a=a, size=self.sample_size, replace=self.replace, p=None) class TestWeightedChoice(HybridBlock): @@ -1699,7 +1697,7 @@ def test_sample_without_replacement(sampler, num_classes, shape, num_trials, wei def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): a = np.arange(set_size) if weight is not None: - samples = sampler(a, p) + samples = sampler(a, weight) else: samples = sampler(a) assert len(samples) == samples_size @@ -1708,24 +1706,23 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): num_classes = 10 num_samples = 10 ** 8 - # for hybridize in [True, False]: - # test sample with replacement - shape_list1 = [ - (10 ** 8, 1), - (10 ** 5, 10 ** 3), - (10 ** 2, 10 ** 3, 10 ** 3) - ] - for shape in shape_list1: - test_sample_with_replacement(np.random.choice, num_classes, shape) - weight = np.array(_np.random.dirichlet([1.0] * num_classes)) - test_sample_with_replacement(np.random.choice, num_classes, shape, weight) + # Density tests are commented out due to their huge time comsumption. + # shape_list1 = [ + # (10 ** 8, 1), + # (10 ** 5, 10 ** 3), + # (10 ** 2, 10 ** 3, 10 ** 3) + # ] + # for shape in shape_list1: + # test_sample_with_replacement(np.random.choice, num_classes, shape) + # weight = np.array(_np.random.dirichlet([1.0] * num_classes)) + # test_sample_with_replacement(np.random.choice, num_classes, shape, weight) - shape_list2 = [ - (6, 1), - (2, 3), - (1, 2, 3), - (2, 2), - ] + # shape_list2 = [ + # (6, 1), + # (2, 3), + # (1, 2, 3), + # (2, 2), + # ] # for shape in shape_list2: # test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5) # weight = np.array(_np.random.dirichlet([1.0] * num_classes)) @@ -1735,10 +1732,13 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): for hybridize in [True, False]: for replace in [True, False]: test_choice = TestUniformChoice(num_classes // 2, replace) + test_choice_weighted = TestWeightedChoice(num_classes // 2, replace) if hybridize: test_choice.hybridize() + test_choice_weighted.hybridize() weight = np.array(_np.random.dirichlet([1.0] * num_classes)) - test_indexing_mode(test_choice, num_classes, num_classes // 2, None) + test_indexing_mode(test_choice, num_classes, num_classes // 2, replace, None) + test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) From 5d96fafcbdcee717067290fa7eb805703e0ae1d7 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 4 Sep 2019 06:45:26 +0000 Subject: [PATCH 5/8] style fixed --- python/mxnet/numpy/random.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 00c5b1fa3248..9749eecfcd2c 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -237,4 +237,3 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): array([2, 3, 0]) """ return _mx_nd_np.random.choice(a, size, replace, p, ctx, out) - From 5325be10a74ed6619ed9e9e3b3b8a9965f3a4284 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Wed, 4 Sep 2019 07:10:31 +0000 Subject: [PATCH 6/8] style problems fixed --- src/operator/numpy/random/np_choice_op.cc | 65 ++++++++++++----------- src/operator/numpy/random/np_choice_op.cu | 2 +- tests/python/unittest/test_numpy_op.py | 10 ++-- 3 files changed, 39 insertions(+), 38 deletions(-) diff --git a/src/operator/numpy/random/np_choice_op.cc b/src/operator/numpy/random/np_choice_op.cc index f68573e4b400..328d7e264861 100644 --- a/src/operator/numpy/random/np_choice_op.cc +++ b/src/operator/numpy/random/np_choice_op.cc @@ -40,41 +40,42 @@ void _sort(float* key, int64_t* data, index_t length) { DMLC_REGISTER_PARAMETER(NumpyChoiceParam); NNVM_REGISTER_OP(_npi_choice) - .describe("random choice") - .set_num_inputs([](const nnvm::NodeAttrs& attrs) { +.describe("random choice") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + int num_input = 0; + const NumpyChoiceParam& param = nnvm::get(attrs.parsed); + if (param.weighted) num_input += 1; + if (!param.a.has_value()) num_input += 1; + return num_input; +}) +.set_num_outputs(1) +.set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { int num_input = 0; - const NumpyChoiceParam& param = nnvm::get(attrs.parsed); + const NumpyChoiceParam& param = + nnvm::get(attrs.parsed); if (param.weighted) num_input += 1; if (!param.a.has_value()) num_input += 1; - return num_input; - }) - .set_num_outputs(1) - .set_attr( - "FListInputNames", - [](const NodeAttrs& attrs) { - int num_input = 0; - const NumpyChoiceParam& param = - nnvm::get(attrs.parsed); - if (param.weighted) num_input += 1; - if (!param.a.has_value()) num_input += 1; - if (num_input == 0) return std::vector(); - if (num_input == 1) return std::vector{"input1"}; - return std::vector{"input1", "input2"}; - }) - .set_attr_parser(ParamParser) - .set_attr("FInferShape", NumpyChoiceOpShape) - .set_attr("FInferType", NumpyChoiceOpType) - .set_attr("FResourceRequest", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{ - ResourceRequest::kRandom, - ResourceRequest::kTempSpace}; - }) - .set_attr("FCompute", NumpyChoiceForward) - .set_attr("FGradient", MakeZeroGradNodes) - .add_argument("input1", "NDArray-or-Symbol", "Source input") - .add_argument("input2", "NDArray-or-Symbol", "Source input") - .add_arguments(NumpyChoiceParam::__FIELDS__()); + if (num_input == 0) return std::vector(); + if (num_input == 1) return std::vector{"input1"}; + return std::vector{"input1", "input2"}; +}) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyChoiceOpShape) +.set_attr("FInferType", NumpyChoiceOpType) +.set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{ + ResourceRequest::kRandom, + ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyChoiceForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("input1", "NDArray-or-Symbol", "Source input") +.add_argument("input2", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyChoiceParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/random/np_choice_op.cu b/src/operator/numpy/random/np_choice_op.cu index cf04b91c8d15..0f42a2e76df0 100644 --- a/src/operator/numpy/random/np_choice_op.cu +++ b/src/operator/numpy/random/np_choice_op.cu @@ -40,7 +40,7 @@ void _sort(float* key, int64_t* data, index_t length) { } NNVM_REGISTER_OP(_npi_choice) - .set_attr("FCompute", NumpyChoiceForward); +.set_attr("FCompute", NumpyChoiceForward); } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 5fa330aaffcd..1c9ebbb32866 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1675,7 +1675,7 @@ def test_sample_with_replacement(sampler, num_classes, shape, weight=None): expected_density = (weight.asnumpy() if weight is not None else _np.array([1 / num_classes] * num_classes)) # test almost equal - assert_almost_equal(generated_density, expected_density, rtol=1e-1, atol=1e-2) + assert_almost_equal(generated_density, expected_density, rtol=1e-1, atol=1e-1) # test shape assert (samples.shape == shape) @@ -1692,7 +1692,7 @@ def test_sample_without_replacement(sampler, num_classes, shape, num_trials, wei out = sampler(num_classes, 1, replace=False, p=weight).item() bins[out] += 1 bins /= num_trials - assert_almost_equal(bins, expected_freq, rtol=1e-1, atol=1e-2) + assert_almost_equal(bins, expected_freq, rtol=1e-1, atol=1e-1) def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): a = np.arange(set_size) @@ -1707,6 +1707,7 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): num_classes = 10 num_samples = 10 ** 8 # Density tests are commented out due to their huge time comsumption. + # Tests passed locally. # shape_list1 = [ # (10 ** 8, 1), # (10 ** 5, 10 ** 3), @@ -1717,6 +1718,8 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): # weight = np.array(_np.random.dirichlet([1.0] * num_classes)) # test_sample_with_replacement(np.random.choice, num_classes, shape, weight) + # Tests passed locally, + # commented out for the same reason as above. # shape_list2 = [ # (6, 1), # (2, 3), @@ -1741,9 +1744,6 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) - - - if __name__ == '__main__': import nose nose.runmodule() From 35ca74130072b45bbec569b5e30bff5609314176 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 6 Sep 2019 04:01:18 +0000 Subject: [PATCH 7/8] remove out parameter and fix style --- python/mxnet/ndarray/numpy/random.py | 14 ++++++-------- python/mxnet/numpy/random.py | 6 ++---- python/mxnet/symbol/numpy/random.py | 13 ++++++------- src/operator/numpy/random/np_choice_op.h | 9 ++++----- 4 files changed, 18 insertions(+), 24 deletions(-) diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index ddc0b16169c0..9372beaf1e92 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -245,7 +245,7 @@ def multinomial(n, pvals, size=None): return _npi.multinomial(n=n, pvals=pvals, size=size) -def choice(a, size=None, replace=True, p=None, ctx=None, out=None): +def choice(a, size=None, replace=True, p=None, **kwargs): """Generates a random sample from a given 1-D array Parameters @@ -265,8 +265,6 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): entries in a. ctx : Context, optional Device context of output. Default is current context. - out : ``ndarray``, optional - Store output to an existing ``ndarray``. Returns -------- @@ -300,14 +298,14 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): array([2, 3, 0]) """ from ...numpy import ndarray as np_ndarray + ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() - if out is not None: - size = out.shape + out = kwargs.pop('out', None) if size == (): size = None - if isinstance(a, np_ndarray): + ctx = None if p is None: indices = _npi.choice(a, a=None, size=size, replace=replace, ctx=ctx, weighted=False) @@ -318,6 +316,6 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): return _npi.take(a, indices) else: if p is None: - return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False) + return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out) else: - return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) + return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 9749eecfcd2c..aace767c8d55 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -182,7 +182,7 @@ def multinomial(n, pvals, size=None, **kwargs): return _mx_nd_np.random.multinomial(n, pvals, size, **kwargs) -def choice(a, size=None, replace=True, p=None, ctx=None, out=None): +def choice(a, size=None, replace=True, p=None, **kwargs): """Generates a random sample from a given 1-D array Parameters @@ -202,8 +202,6 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): entries in a. ctx : Context, optional Device context of output. Default is current context. - out : ``ndarray``, optional - Store output to an existing ``ndarray``. Returns -------- @@ -236,4 +234,4 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): >>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0]) array([2, 3, 0]) """ - return _mx_nd_np.random.choice(a, size, replace, p, ctx, out) + return _mx_nd_np.random.choice(a, size, replace, p, **kwargs) diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 2b1a08306e1f..523983bac20a 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -192,7 +192,7 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs) -def choice(a, size=None, replace=True, p=None, ctx=None, out=None): +def choice(a, size=None, replace=True, p=None, **kwargs): """Generates a random sample from a given 1-D array Parameters @@ -212,8 +212,6 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): entries in a. ctx : Context, optional Device context of output. Default is current context. - out : ``ndarray``, optional - Store output to an existing ``ndarray``. Returns -------- @@ -247,14 +245,15 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): array([2, 3, 0]) """ from ._symbol import _Symbol as np_symbol + ctx = kwargs.pop('ctx', None) if ctx is None: ctx = current_context() - if out is not None: - size = out.shape + out = kwargs.pop('out', None) if size == (): size = None if isinstance(a, np_symbol): + ctx = None if p is None: indices = _npi.choice(a, a=None, size=size, replace=replace, ctx=ctx, weighted=False) @@ -265,6 +264,6 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None): return _npi.take(a, indices) else: if p is None: - return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False) + return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out) else: - return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True) + return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out) diff --git a/src/operator/numpy/random/np_choice_op.h b/src/operator/numpy/random/np_choice_op.h index a22663b417ee..f6882c5d3283 100644 --- a/src/operator/numpy/random/np_choice_op.h +++ b/src/operator/numpy/random/np_choice_op.h @@ -167,7 +167,7 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, int64_t random_tensor_size = replace ? output_size : input_size; int64_t indices_size = replace ? 0 : input_size; Tensor workspace = - (ctx.requested[1].get_space_typed( + ctx.requested[1].get_space_typed( Shape1(indices_size * sizeof(int64_t) + (random_tensor_size * sizeof(float) / 7 + 1) * 8), s)); @@ -189,8 +189,7 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, Kernel::Launch(s, input_size, random_numbers.dptr_, inputs[weight_index].dptr()); _sort(random_numbers.dptr_, indices.dptr_, input_size); - Copy(outputs[0].FlatTo1D(s), indices.Slice(0, output_size), - s); + Copy(outputs[0].FlatTo1D(s), indices.Slice(0, output_size), s); } } else { Random *prnd = ctx.requested[0].get_random(s); @@ -199,10 +198,10 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, : std::min(output_size, input_size - output_size)); int64_t indices_size = replace ? 0 : input_size; Tensor workspace = - (ctx.requested[1].get_space_typed( + ctx.requested[1].get_space_typed( Shape1(indices_size * sizeof(int64_t) + (random_tensor_size * sizeof(unsigned) / 7 + 1) * 8), - s)); + s); // slice workspace char *workspace_ptr = workspace.dptr_; Tensor random_numbers = From 0cd681e37ad44746b56a5e1880f19c02dbea63ba Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 6 Sep 2019 04:07:33 +0000 Subject: [PATCH 8/8] fix syntax error --- src/operator/numpy/random/np_choice_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/numpy/random/np_choice_op.h b/src/operator/numpy/random/np_choice_op.h index f6882c5d3283..335cc2741759 100644 --- a/src/operator/numpy/random/np_choice_op.h +++ b/src/operator/numpy/random/np_choice_op.h @@ -170,7 +170,7 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, ctx.requested[1].get_space_typed( Shape1(indices_size * sizeof(int64_t) + (random_tensor_size * sizeof(float) / 7 + 1) * 8), - s)); + s); // slice workspace char *workspace_ptr = workspace.dptr_; Tensor random_numbers =