Skip to content

Commit

Permalink
Update samplers (#2913)
Browse files Browse the repository at this point in the history
* Add samplers

* Fix range

* Updated name

* Update init

* Update ExpectedNumInstanceSampler, Remove ExpectedNumInstanceWithMinSampler

* Use randint instead of choice

* Fix

---------

Co-authored-by: Abdul Fatir Ansari <[email protected]>
  • Loading branch information
abdulfatir and Abdul Fatir Ansari authored Jun 12, 2023
1 parent e430cd2 commit 1cc17c1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/gluonts/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"ContinuousTimePredictionSampler",
"ExpandDimArray",
"ExpectedNumInstanceSampler",
"NumInstanceSampler",
"FilterTransformation",
"FlatMapTransformation",
"Identity",
Expand Down Expand Up @@ -117,6 +118,7 @@
ContinuousTimeUniformSampler,
ContinuousTimePredictionSampler,
ExpectedNumInstanceSampler,
NumInstanceSampler,
InstanceSampler,
TestSplitSampler,
ValidationSplitSampler,
Expand Down
34 changes: 32 additions & 2 deletions src/gluonts/transform/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ def __call__(self, ts: np.ndarray) -> np.ndarray:
raise NotImplementedError()


class NumInstanceSampler(InstanceSampler):
"""
Samples N time points from each series.
Parameters
----------
N
number of time points to sample from each time series.
"""

N: int

def __call__(self, ts: np.ndarray) -> np.ndarray:
a, b = self._get_bounds(ts)
if a > b:
return np.array([], dtype=int)

return np.random.randint(a, b + 1, size=self.N)


class UniformSplitSampler(InstanceSampler):
"""
Samples each point with the same fixed probability.
Expand Down Expand Up @@ -116,10 +136,13 @@ class ExpectedNumInstanceSampler(InstanceSampler):
----------
num_instances
number of training examples generated per time series on average
number of time points to sample per time series on average
min_instances
minimum number of time points to sample per time series
"""

num_instances: float
min_instances: int = 0
total_length: int = 0
n: int = 0

Expand All @@ -139,7 +162,14 @@ def __call__(self, ts: np.ndarray) -> np.ndarray:

p = self.num_instances / avg_length
(indices,) = np.where(np.random.random_sample(window_size) < p)
return indices + a
indices += a
if len(indices) < self.min_instances:
prefix = np.random.randint(
a, b + 1, size=self.min_instances - len(indices)
)
return np.concatenate([prefix, indices])

return indices


class BucketInstanceSampler(InstanceSampler):
Expand Down

0 comments on commit 1cc17c1

Please sign in to comment.