Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] Random.choice implemented (#16089)
Browse files Browse the repository at this point in the history
* imperative choice done

* unit test done

* expose take to np internal

* style fixed

* style fixed

* style problems fixed

* remove out parameter and fix style

* fix syntax error
  • Loading branch information
xidulu authored and haojin2 committed Sep 9, 2019
1 parent 9a9c5f8 commit e260f13
Show file tree
Hide file tree
Showing 8 changed files with 679 additions and 2 deletions.
78 changes: 77 additions & 1 deletion python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ...base import numeric_types


__all__ = ['randint', 'uniform', 'normal']
__all__ = ['randint', 'uniform', 'normal', "choice"]


def randint(low, high=None, size=None, dtype=None, **kwargs):
Expand Down Expand Up @@ -243,3 +243,79 @@ def multinomial(n, pvals, size=None):
if any(isinstance(i, list) for i in pvals):
raise ValueError('object too deep for desired array')
return _npi.multinomial(n=n, pvals=pvals, size=size)


def choice(a, size=None, replace=True, p=None, **kwargs):
"""Generates a random sample from a given 1-D array
Parameters
-----------
a : 1-D array-like or int
If an ndarray, a random sample is generated from its elements.
If an int, the random sample is generated as if a were np.arange(a)
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. Default is None, in which case a
single value is returned.
replace : boolean, optional
Whether the sample is with or without replacement
p : 1-D array-like, optional
The probabilities associated with each entry in a.
If not given the sample assumes a uniform distribution over all
entries in a.
ctx : Context, optional
Device context of output. Default is current context.
Returns
--------
samples : ndarray
The generated random samples
Examples
---------
Generate a uniform random sample from np.arange(5) of size 3:
>>> np.random.choice(5, 3)
array([0, 3, 4])
>>> #This is equivalent to np.random.randint(0,5,3)
Generate a non-uniform random sample from np.arange(5) of size 3:
>>> np.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
array([3, 3, 0])
Generate a uniform random sample from np.arange(5) of size 3 without
replacement:
>>> np.random.choice(5, 3, replace=False)
array([3,1,0])
>>> #This is equivalent to np.random.permutation(np.arange(5))[:3]
Generate a non-uniform random sample from np.arange(5) of size
3 without replacement:
>>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
array([2, 3, 0])
"""
from ...numpy import ndarray as np_ndarray
ctx = kwargs.pop('ctx', None)
if ctx is None:
ctx = current_context()
out = kwargs.pop('out', None)
if size == ():
size = None
if isinstance(a, np_ndarray):
ctx = None
if p is None:
indices = _npi.choice(a, a=None, size=size,
replace=replace, ctx=ctx, weighted=False)
return _npi.take(a, indices)
else:
indices = _npi.choice(a, p, a=None, size=size,
replace=replace, ctx=ctx, weighted=True)
return _npi.take(a, indices)
else:
if p is None:
return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out)
else:
return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)
57 changes: 56 additions & 1 deletion python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..ndarray import numpy as _mx_nd_np


__all__ = ["randint", "uniform", "normal"]
__all__ = ["randint", "uniform", "normal", "choice"]


def randint(low, high=None, size=None, dtype=None, **kwargs):
Expand Down Expand Up @@ -180,3 +180,58 @@ def multinomial(n, pvals, size=None, **kwargs):
array([32, 68])
"""
return _mx_nd_np.random.multinomial(n, pvals, size, **kwargs)


def choice(a, size=None, replace=True, p=None, **kwargs):
"""Generates a random sample from a given 1-D array
Parameters
-----------
a : 1-D array-like or int
If an ndarray, a random sample is generated from its elements.
If an int, the random sample is generated as if a were np.arange(a)
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. Default is None, in which case a
single value is returned.
replace : boolean, optional
Whether the sample is with or without replacement
p : 1-D array-like, optional
The probabilities associated with each entry in a.
If not given the sample assumes a uniform distribution over all
entries in a.
ctx : Context, optional
Device context of output. Default is current context.
Returns
--------
samples : ndarray
The generated random samples
Examples
---------
Generate a uniform random sample from np.arange(5) of size 3:
>>> np.random.choice(5, 3)
array([0, 3, 4])
>>> #This is equivalent to np.random.randint(0,5,3)
Generate a non-uniform random sample from np.arange(5) of size 3:
>>> np.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
array([3, 3, 0])
Generate a uniform random sample from np.arange(5) of size 3 without
replacement:
>>> np.random.choice(5, 3, replace=False)
array([3,1,0])
>>> #This is equivalent to np.random.permutation(np.arange(5))[:3]
Generate a non-uniform random sample from np.arange(5) of size
3 without replacement:
>>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
array([2, 3, 0])
"""
return _mx_nd_np.random.choice(a, size, replace, p, **kwargs)
77 changes: 77 additions & 0 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,80 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs):
raise NotImplementedError('np.random.normal only supports loc and scale of '
'numeric types for now')
return _npi.random_normal(loc, scale, shape=size, dtype=dtype, ctx=ctx, out=out, **kwargs)


def choice(a, size=None, replace=True, p=None, **kwargs):
"""Generates a random sample from a given 1-D array
Parameters
-----------
a : 1-D array-like or int
If an ndarray, a random sample is generated from its elements.
If an int, the random sample is generated as if a were np.arange(a)
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., ``(m, n, k)``, then
``m * n * k`` samples are drawn. Default is None, in which case a
single value is returned.
replace : boolean, optional
Whether the sample is with or without replacement
p : 1-D array-like, optional
The probabilities associated with each entry in a.
If not given the sample assumes a uniform distribution over all
entries in a.
ctx : Context, optional
Device context of output. Default is current context.
Returns
--------
samples : _Symbol
The generated random samples
Examples
---------
Generate a uniform random sample from np.arange(5) of size 3:
>>> np.random.choice(5, 3)
array([0, 3, 4])
>>> #This is equivalent to np.random.randint(0,5,3)
Generate a non-uniform random sample from np.arange(5) of size 3:
>>> np.random.choice(5, 3, p=[0.1, 0, 0.3, 0.6, 0])
array([3, 3, 0])
Generate a uniform random sample from np.arange(5) of size 3 without
replacement:
>>> np.random.choice(5, 3, replace=False)
array([3,1,0])
>>> #This is equivalent to np.random.permutation(np.arange(5))[:3]
Generate a non-uniform random sample from np.arange(5) of size
3 without replacement:
>>> np.random.choice(5, 3, replace=False, p=[0.1, 0, 0.3, 0.6, 0])
array([2, 3, 0])
"""
from ._symbol import _Symbol as np_symbol
ctx = kwargs.pop('ctx', None)
if ctx is None:
ctx = current_context()
out = kwargs.pop('out', None)
if size == ():
size = None

if isinstance(a, np_symbol):
ctx = None
if p is None:
indices = _npi.choice(a, a=None, size=size,
replace=replace, ctx=ctx, weighted=False)
return _npi.take(a, indices)
else:
indices = _npi.choice(a, p, a=None, size=size,
replace=replace, ctx=ctx, weighted=True)
return _npi.take(a, indices)
else:
if p is None:
return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out)
else:
return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)
81 changes: 81 additions & 0 deletions src/operator/numpy/random/np_choice_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_choice_op.cc
* \brief Operator for random subset sampling
*/

#include "./np_choice_op.h"
#include <algorithm>

namespace mxnet {
namespace op {

template <>
void _sort<cpu>(float* key, int64_t* data, index_t length) {
std::sort(data, data + length,
[key](int64_t const& i, int64_t const& j) -> bool {
return key[i] > key[j];
});
}

DMLC_REGISTER_PARAMETER(NumpyChoiceParam);

NNVM_REGISTER_OP(_npi_choice)
.describe("random choice")
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
int num_input = 0;
const NumpyChoiceParam& param = nnvm::get<NumpyChoiceParam>(attrs.parsed);
if (param.weighted) num_input += 1;
if (!param.a.has_value()) num_input += 1;
return num_input;
})
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>(
"FListInputNames",
[](const NodeAttrs& attrs) {
int num_input = 0;
const NumpyChoiceParam& param =
nnvm::get<NumpyChoiceParam>(attrs.parsed);
if (param.weighted) num_input += 1;
if (!param.a.has_value()) num_input += 1;
if (num_input == 0) return std::vector<std::string>();
if (num_input == 1) return std::vector<std::string>{"input1"};
return std::vector<std::string>{"input1", "input2"};
})
.set_attr_parser(ParamParser<NumpyChoiceParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyChoiceOpShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyChoiceOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<ResourceRequest>{
ResourceRequest::kRandom,
ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyChoiceForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_argument("input2", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyChoiceParam::__FIELDS__());

} // namespace op
} // namespace mxnet
46 changes: 46 additions & 0 deletions src/operator/numpy/random/np_choice_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_choice_op.cu
* \brief Operator for random subset sampling
*/

#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/swap.h>
#include "./np_choice_op.h"

namespace mxnet {
namespace op {

template <>
void _sort<gpu>(float* key, int64_t* data, index_t length) {
thrust::device_ptr<float> dev_key(key);
thrust::device_ptr<int64_t> dev_data(data);
thrust::sort_by_key(dev_key, dev_key + length, dev_data,
thrust::greater<float>());
}

NNVM_REGISTER_OP(_npi_choice)
.set_attr<FCompute>("FCompute<gpu>", NumpyChoiceForward<gpu>);

} // namespace op
} // namespace mxnet
Loading

0 comments on commit e260f13

Please sign in to comment.