Skip to content

Commit

Permalink
Merge pull request #260 from hx2A/fix200
Browse files Browse the repository at this point in the history
Improve random_sample() for #200
  • Loading branch information
hx2A authored Mar 11, 2023
2 parents 82d30b9 + cbe912c commit 8eab021
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
4 changes: 2 additions & 2 deletions py5_docs/Reference/api_en/Sketch_random_sample.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ category = math
subcategory = random

@@ signatures
random_sample(objects: list[Any], size: int=1, replace: bool=True) -> npt.NDArray
random_sample(objects: list[Any], size: int=1, replace: bool=True) -> list[Any]

@@ variables
objects: list[Any] - list of objects to choose from
Expand All @@ -15,7 +15,7 @@ size: int=1 - number of random items to select
@@ description
Select random items from a list. The list items can be of any type. If multiple items are selected, this function will by default allow the same item to be selected multiple times. Set the `replace` parameter to `False` to prevent the same item from being selected multiple times.

The returned value will always be a numpy array, even if only one item is sampled. If the list of objects is empty, an empty numpy array will be returned.
The returned value will always be a sequence such as a list or numpy array, even if only one item is sampled. If you only want to sample one item, consider using [](sketch_random_choice) instead. If the list of objects is empty, an empty list will be returned.

This function's randomness can be influenced by [](sketch_random_seed), and makes calls to numpy to select the random items.

Expand Down
19 changes: 14 additions & 5 deletions py5_resources/py5_module/py5/mixins/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import warnings
import traceback
from pathlib import Path
import types
from typing import overload, Union, Any

import numpy as np
Expand Down Expand Up @@ -261,17 +262,25 @@ def random_int(self, *args: int) -> int:

def random_choice(self, objects: list[Any]) -> Any:
"""$class_Sketch_random_choice"""
if objects:
if len(objects):
return objects[self._rng.integers(0, len(objects))]
else:
return None

def random_sample(self, objects: list[Any], size: int=1, replace: bool=True) -> npt.NDArray:
def random_sample(self, objects: list[Any], size: int=1, replace: bool=True) -> list[Any]:
"""$class_Sketch_random_sample"""
if objects:
return self._rng.choice(objects, size=size, replace=replace)
if len(objects):
if isinstance(objects, types.GeneratorType):
objects = list(objects)
indices = self._rng.choice(range(len(objects)), size=size, replace=replace)
if not isinstance(objects, list):
try:
return objects[indices]
except:
pass
return [objects[idx] for idx in indices]
else:
return np.array([], dtype='O')
return []

@overload
def random_gaussian(self) -> float:
Expand Down

0 comments on commit 8eab021

Please sign in to comment.