Implement weighted sampling API #518
I'm less convinced about adding the `IntoIteratorRandom` trait because (1) the only advantage over `IteratorRandom` is implicit conversions (e.g. direct usage on a `Vec`) which aren't always a good idea (as you note) and (2) it's an extra trait. Also (3) this would let us add `SliceRandom` methods by the same name later without name conflicts.
src/seq.rs
Outdated
/// use rand::prelude::*;
///
/// let choices = [('a', 2), ('b', 1), ('c', 1)];
/// // In rustc version XXX and newer, you can use a closure instead
Any idea what XXX is? 1.26 is the only version since 1.22 mentioning closures in the release announcement, and doesn't mention cloning. We don't support older than 1.22 anyway.
src/seq.rs
Outdated
@@ -91,6 +94,52 @@ pub trait SliceRandom {
    fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Item>
        where R: Rng + ?Sized;

    /// Similar to [`choose`], but each item in the slice don't have the same
    /// likelyhood of getting returned. The likelyhood of a given item getting
    /// returned is proportional to the value returned by the mapping function
I think it's spelled likelihood. Also, "don't have" is bad grammar and the sentence is indirectly describing the purpose. Try:
A variant of `choose` where the likelihood of each outcome may be specified. The specified function `func` maps items `x` to a relative likelihood `func(x)`. The probability of each item being selected is therefore `func(x) / S`, where `S` is the sum of all `func(x)`.
Perhaps also rename `func` → `weight` or `w` or `f`?
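To make the suggested wording concrete, here is a minimal sketch of the `func(x) / S` rule applied to the `choices` array from the doc example. The helper name `selection_probabilities` is hypothetical and not part of the PR; it only computes the probabilities and does no sampling.

```rust
/// Hypothetical helper: returns, for each item, `weight(x) / S`,
/// where `S` is the sum of `weight(x)` over all items.
fn selection_probabilities<T>(items: &[T], weight: impl Fn(&T) -> f64) -> Vec<f64> {
    // S: the sum of all relative likelihoods.
    let total: f64 = items.iter().map(|x| weight(x)).sum();
    // Each item's probability is its relative likelihood divided by S.
    items.iter().map(|x| weight(x) / total).collect()
}
```

For `[('a', 2), ('b', 1), ('c', 1)]` with the second tuple field as the weight, `S` is 4, so `'a'` is picked half the time and `'b'` and `'c'` a quarter of the time each.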
src/seq.rs
Outdated
/// Extension trait on IntoIterator, providing random sampling methods.
pub trait IntoIteratorRandom: IntoIterator + Sized {

    /// Return a the index of a random element from this iterator where the
'a the'
src/seq.rs
Outdated
Self::Item: SampleBorrow<X> {
    let mut iter = self.into_iter();
    let mut total_weight: X = iter.next()
        .expect("Can't create Distribution for empty set of weights")
I don't think we should panic in this case. Maybe return some type of error instead.
I feel like this should be a |
Why not put the weight function One thing on my mind here is avoiding adding too much code/API for relatively obscure functionality. We may be able to reduce the API to just Can you add some benchmarks? |
I pushed an updated API. It contains
In this version
I don't think we can make this a lot smaller. The bulk of the code is in |
Oh, but on the flip side, |
261d7b0
to
8538619
Compare
Some typos I noticed.
I'm still not sure about the approach though.
src/distributions/weighted.rs
Outdated
@@ -0,0 +1,182 @@
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
2018
src/distributions/weighted.rs
Outdated
// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature="std"))] use alloc::Vec;

/// A distribution using weighted sampling to pick an discretely selected item.
a discretely
src/distributions/weighted.rs
Outdated
/// of a random element from the iterator used when the `WeightedIndex` was
/// created. The chance of a given element being picked is proportional to the
/// value of the element. The weights can use any type `X` for which an
/// implementaiton of [`Uniform<X>`] exists.
implementation
src/distributions/weighted.rs
Outdated
/// A distribution using weighted sampling to pick an discretely selected item.
///
/// When a `WeightedIndex` is sampled from, it returns the index
Sampling a `WeightedIndex` distribution returns ...
src/distributions/weighted.rs
Outdated
impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
    /// Creates a new a `WeightedIndex` [`Distribution`] using the values
    /// in `weights`. The weights can use any type `X` for which an
    /// implementaiton of [`Uniform<X>`] exists.
same typo
I pushed those comment fixes as well as a couple of more tests. One thing that I wasn't sure which way to go on is what to do with negative weights. Definitely feels like if that happens that it's a pretty severe bug in the calling code, so panicking seems warranted. But also seems somewhat strange to have some errors return a |
I'm curious: why are we re-implementing binary search? |
Good point! Using |
let chosen_weight = self.weight_distribution.sample(rng);
// Find the first item which has a weight *higher* than the chosen weight.
self.cumulative_weights.binary_search_by(
    |w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
Doesn't this panic if the search returns `Ok`? Will that just not happen? (Won't happen)
Since the closure never returns In general |
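The lookup quoted above can be checked in isolation. The sketch below uses only the standard library; `cumulative` and `index_for` are hypothetical helper names, and `u32` stands in for the generic weight type `X`. Because the comparison closure never returns `Ordering::Equal`, `binary_search_by` can never report a match, so it always yields `Err(insertion_point)` and `unwrap_err` cannot panic.

```rust
use std::cmp::Ordering;

/// Hypothetical helper: running totals of the weights, e.g. [2, 1, 1] -> [2, 3, 4].
fn cumulative(weights: &[u32]) -> Vec<u32> {
    let mut running = 0;
    weights
        .iter()
        .map(|w| {
            running += w;
            running
        })
        .collect()
}

/// Hypothetical helper: index of the first cumulative weight that is
/// strictly greater than `chosen_weight`.
fn index_for(cumulative_weights: &[u32], chosen_weight: u32) -> usize {
    cumulative_weights
        // The closure only ever answers Less or Greater, never Equal,
        // so the search always ends in Err(insertion_point).
        .binary_search_by(|w| {
            if *w <= chosen_weight {
                Ordering::Less
            } else {
                Ordering::Greater
            }
        })
        .unwrap_err()
}
```

With weights `[2, 1, 1]`, chosen values 0 and 1 map to index 0, 2 maps to index 1, and 3 maps to index 2. A chosen value at or above the total would return the slice length, so the caller has to draw strictly below the total.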
A couple of little things (sorry I didn't catch this before)!
I also opened #535 but that can be a separate PR.
@@ -192,6 +197,8 @@ use Rng;
#[doc(inline)] pub use self::dirichlet::Dirichlet;

pub mod uniform;
#[cfg(feature="alloc")]
#[doc(hidden)] pub mod weighted;
We have no reason to make this module public so better make it private I think. The `doc(hidden)` stuff is supposedly there for backwards compatibility, though probably it's been used erroneously in a couple of cases.
src/distributions/weighted.rs
Outdated
///
/// # Panics
///
/// If a value in the iterator is `< 0`.
I don't think it makes sense to panic on some errors while handling others via `Result`.
src/distributions/weighted.rs
Outdated
let zero = <X as Default>::default();
let weights = iter.map(|w| {
    assert!(*w.borrow() >= zero, "Negative weight in WeightedIndex::new");
i.e. this can just return an error
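One way to read the suggestion is that every invalid input goes through a single `Result`, with no `assert!` or `expect`. The sketch below is an assumption about what that could look like, not the code that was merged: the `WeightError` enum, its variant names, and `cumulative_weights_checked` are all hypothetical, `i64` stands in for the generic weight type, and the all-zero check is an extra case not raised in the review.

```rust
/// Hypothetical error type; the variant names are illustrative only.
#[derive(Debug, PartialEq)]
enum WeightError {
    /// The weight list was empty.
    NoItem,
    /// A weight was below zero.
    NegativeWeight,
    /// Every weight was zero, so nothing can be selected (extra assumption).
    AllWeightsZero,
}

/// Builds the cumulative weights, reporting bad input as an error
/// instead of panicking.
fn cumulative_weights_checked(weights: &[i64]) -> Result<Vec<i64>, WeightError> {
    if weights.is_empty() {
        return Err(WeightError::NoItem);
    }
    let mut total = 0i64;
    let mut cumulative = Vec::with_capacity(weights.len());
    for &w in weights {
        if w < 0 {
            // Previously an assert!; now an ordinary error value.
            return Err(WeightError::NegativeWeight);
        }
        total += w;
        cumulative.push(total);
    }
    if total == 0 {
        return Err(WeightError::AllWeightsZero);
    }
    Ok(cumulative)
}
```

The caller can then decide whether a negative weight is a bug worth an `unwrap` or a condition to recover from.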
…::choose_weighted_mut
Pushed an updated PR that fixes those comments. I'll do a separate PR to address #535. |
It will possible use that as: most important things is It most useful in practical life. P.S. AFAIK numpy.random.choice used CDF algorithm. |
@mrLSD we have very different requirements from Python/numpy, so I don't think there is much to take from their approach. Besides which this has already been merged. This samples with replacement, with given weights. We also have code for sampling without replacement, but without user-defined weights: see I don't know why you would want to a |
It's very different functions About size. Very simple example: I have 1kk users with their weights. And I should select only 100 unique users. Do you really think that it's rare case? |
Fair point. Sounds like you have a simple feature request then, rather than wanting a different API or multi-dimensional matrices? Numpy embeds many things in one place, not always for the better. I opened #596 (which you could have done to start with). I have no plans to work on this now, but at least now there is a visible ticket open. |
@dhardy Thanks! |
Having experimented with various ways of doing weighted sampling, I think this API is the best I've thought of so far.
This PR adds the following APIs:
- `IntoIteratorRandom::choose_index_weighted`. This is the most low-level way of doing a single sampling. It consumes an iterable of weights and returns an index. This index can then be used to index into a slice, another iterator, etc. This is to weighted sampling what `gen_range` is to uniform sampling.
- `IntoIteratorRandom::into_weighted_index_distribution`. If you're sampling multiple times using the same set of weights, it is more efficient to build up an array of cumulative weights and then do a binary search to find the index corresponding to the randomly generated value. So this is the most low-level API for repeated sampling. This is to weighted sampling what `Uniform::new` is to uniform sampling.
- `SliceRandom::choose_weighted`. This is a convenience function on top of `IntoIteratorRandom::choose_index_weighted` which allows using a slice and a mapping function to get a weighted sample from the slice. This function is to weighted sampling what `SliceRandom::choose` is to uniform sampling.
- `SliceRandom::choose_weighted_mut`. Same as `SliceRandom::choose_weighted`, but returns a mutable reference. This function is to weighted sampling what `SliceRandom::choose_mut` is to uniform sampling.
There are still lots of other ways we could approach this. But this set of functions felt pretty good.
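For readers who want to see the single-sample case from the list above in code, here is a rough standard-library-only sketch. It is not the PR's implementation: the function name `choose_index_weighted_sketch` is hypothetical, weights are fixed to `u64`, and the `raw` parameter stands in for a value drawn from an RNG. It also shows why the weight iterator has to be `Clone`: one pass sums the weights and a second pass finds the index.

```rust
/// Hypothetical sketch of a one-off weighted index selection.
/// `raw` is assumed to be a uniformly random u64 supplied by the caller.
fn choose_index_weighted_sketch<I>(weights: I, raw: u64) -> Option<usize>
where
    I: IntoIterator<Item = u64>,
    I::IntoIter: Clone,
{
    let iter = weights.into_iter();
    // First pass (on a clone of the iterator): total weight.
    let total: u64 = iter.clone().sum();
    if total == 0 {
        // Empty input or all-zero weights: nothing to choose.
        return None;
    }
    // Map the raw value into [0, total). The modulo is slightly biased
    // and is used here only to keep the sketch short.
    let mut chosen = raw % total;
    // Second pass: walk the weights until the chosen value falls inside one.
    for (index, w) in iter.enumerate() {
        if chosen < w {
            return Some(index);
        }
        chosen -= w;
    }
    unreachable!("chosen is always below the total weight")
}
```

The returned index can then be used to index a slice or a parallel iterator, which is the role `choose_weighted` plays on top of the lower-level call.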
The last two felt like very pleasant APIs to use, which should cover most common use cases, especially on modern versions of rustc where closures automatically implement `Clone` when they can. And they nicely match `SliceRandom::choose` and `SliceRandom::choose_mut`.
The first two functions seem like pretty good low-level APIs for doing anything that `choose_weighted` and `choose_weighted_mut` don't cover. However I'm less sure that these functions live where they should and have the correct names. In particular, maybe `IntoIteratorRandom::into_weighted_index_distribution` should be a separate struct with a constructor which takes an iterator, more similar to `Uniform::new`.
And for `IntoIteratorRandom::choose_index_weighted`, `SliceRandom::choose_weighted` and `SliceRandom::choose_weighted_mut` there's the question of whether they should live on `SliceRandom` or `IntoIteratorRandom`.
The advantage of things living on `IntoIteratorRandom` is that they appear on both slices and on iterators, which is a good thing. The downside is that `my_vec.choose_index_weighted(...)` will actually silently clone the `Vec` and its contents, which is unlikely to be expected. However it will also consume the `Vec`, which will likely lead to compilation errors, allowing the developer to catch this. The fix is to write `(&my_vec).choose_index_weighted(...)` which will not clone the `Vec` or its contents, and will not consume the `Vec`. This problem is unique to `Vec` objects and does not happen when calling on a slice or an array.
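The by-value versus by-reference behaviour described above can be reproduced with a small stand-in trait. Everything in this sketch is hypothetical: `TotalWeight` and `total_weight` are illustrative names, the trait only sums weights, and it demonstrates the move of the `Vec`, not the internal clone the PR's method performs.

```rust
use std::borrow::Borrow;

/// Hypothetical stand-in for an extension trait on `IntoIterator`
/// whose methods take `self` by value.
trait TotalWeight: IntoIterator + Sized {
    fn total_weight(self) -> u64
    where
        Self::Item: Borrow<u64>,
    {
        // Works for both owned items (u64) and borrowed items (&u64).
        self.into_iter().map(|w| *Borrow::<u64>::borrow(&w)).sum()
    }
}

// Blanket impl: the method appears on Vec, &Vec, slices references, iterators, ...
impl<I: IntoIterator> TotalWeight for I {}

fn demo() -> (u64, u64, usize) {
    let my_vec = vec![2u64, 1, 1];
    // Receiver is `&Vec<u64>`: iterates by reference, `my_vec` stays usable.
    let borrowed_total = (&my_vec).total_weight();
    let len_after_borrow = my_vec.len();
    // Receiver is `Vec<u64>`: `my_vec` is moved into the call.
    let owned_total = my_vec.total_weight();
    // my_vec.len(); // would not compile: use of moved value
    (borrowed_total, owned_total, len_after_borrow)
}
```

The commented-out line is the compilation error the description refers to: after the by-value call the `Vec` is gone, which is what nudges the caller towards the `(&my_vec)` form.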