Commit

Merge branch 'dev' into precommit
Borda authored May 7, 2024
2 parents 9ee0283 + b63fc05 commit ca88108
Showing 181 changed files with 956 additions and 711 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/lints.yml
@@ -0,0 +1,23 @@
+name: Ruff & Docformat
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        check: ["ruff", "docformatter"]
+
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+      - name: Install tools
+        run: pip install "ruff==0.2.2" "docformatter[tomli]==1.5.0"
+      - name: Ruff (Flake8)
+        if: matrix.check == 'ruff'
+        working-directory: src/
+        run: ruff check .
+      - name: Docformatter
+        if: matrix.check == 'docformatter'
+        run: docformatter --check -r src/
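For contributors who want to reproduce these checks before pushing, a rough local equivalent might look like the sketch below. It assumes the same pinned tool versions as the workflow; the subprocess wrapper is illustrative and not part of the repository.

# Hypothetical local reproduction of the two CI checks above.
# Run `pip install "ruff==0.2.2" "docformatter[tomli]==1.5.0"` first.
import subprocess

subprocess.run(["ruff", "check", "."], cwd="src", check=True)
subprocess.run(["docformatter", "--check", "-r", "src/"], check=True)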
2 changes: 1 addition & 1 deletion examples/persist_model.py
@@ -12,7 +12,7 @@
 # permissions and limitations under the License.

 """
-This example shows how to serialize and deserialize a model
+This example shows how to serialize and deserialize a model.
 """
 import os
 import pprint
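The docstring fix aside, the pattern this example file demonstrates is, to my reading of the GluonTS API, the serialize/deserialize round trip sketched below; the paths and the trained `predictor` are placeholders, and the file itself is authoritative.

# Hedged sketch of GluonTS model persistence.
from pathlib import Path
from gluonts.model.predictor import Predictor

# predictor.serialize(Path("/tmp/model"))                # write to disk
# predictor = Predictor.deserialize(Path("/tmp/model"))  # load it back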
4 changes: 3 additions & 1 deletion setup.py
@@ -41,7 +41,9 @@ def get_version_cmdclass(version_file) -> dict:


 class TypeCheckCommand(distutils.cmd.Command):
-    """A custom command to run MyPy on the project sources."""
+    """
+    A custom command to run MyPy on the project sources.
+    """

     description = "run MyPy on Python source files"
     user_options = []
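For context, a custom distutils command like the one whose docstring changes here generally follows the shape below; this is a generic sketch, not the project's actual implementation, and the `mypy src` invocation is an assumption.

import distutils.cmd
import subprocess

class TypeCheckCommand(distutils.cmd.Command):
    """A custom command to run MyPy on the project sources."""

    description = "run MyPy on Python source files"
    user_options = []  # no command-line options

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        # Assumption: sources live under src/; adjust as needed.
        subprocess.run(["mypy", "src"], check=True)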
8 changes: 4 additions & 4 deletions src/gluonts/dataset/artificial/recipe.py
@@ -1063,10 +1063,10 @@ def normalized_ar1(tau, x0=None, norm="minmax", sigma=1.0):
     r"""
     Returns an ar1 process with an auto correlation time of tau.

-    norm can be
-        None -> no normalization
-        'minmax' -> min_max_scaled
-        'standard' -> 0 mean, unit variance
+    norm can be:
+    - None -> no normalization
+    - 'minmax' -> min_max_scaled
+    - 'standard' -> 0 mean, unit variance
     """
     assert norm in [None, "minmax", "standard"]
     phi = lifted_numpy.exp(-1.0 / tau)
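A plain-numpy sketch of the process this docstring describes: phi = exp(-1/tau) gives an autocorrelation time of tau, and the norm options rescale the result. The real function builds a lifted recipe; this standalone version is only for intuition.

import numpy as np

def ar1_sketch(tau, length=100, sigma=1.0, norm="minmax", seed=0):
    rng = np.random.default_rng(seed)
    phi = np.exp(-1.0 / tau)  # autocorrelation time of tau
    x = np.zeros(length)
    for t in range(1, length):
        x[t] = phi * x[t - 1] + sigma * rng.standard_normal()
    if norm == "minmax":
        x = (x - x.min()) / (x.max() - x.min())
    elif norm == "standard":
        x = (x - x.mean()) / x.std()
    return x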
7 changes: 4 additions & 3 deletions src/gluonts/dataset/common.py
@@ -128,10 +128,11 @@ def infer_file_type(path):


 def _rglob(path: Path, pattern="*", levels=1):
-    """Like ``path.rglob(pattern)`` except this limits the number of sub
-    directories that are traversed. ``levels = 0`` is thus the same as
-    ``path.glob(pattern)``.
+    """
+    Like ``path.rglob(pattern)`` except this limits the number of sub
+    directories that are traversed.
+    ``levels = 0`` is thus the same as ``path.glob(pattern)``.
     """
     if levels is not None:
         levels -= 1
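The behavior the reworded docstring pins down, approximated in a standalone sketch; the real implementation may differ in detail.

from pathlib import Path

def rglob_limited(path: Path, pattern="*", levels=1):
    # levels=0 matches path.glob(pattern); levels=None means unlimited.
    yield from path.glob(pattern)
    if levels is None or levels > 0:
        next_levels = None if levels is None else levels - 1
        for subdir in path.iterdir():
            if subdir.is_dir():
                yield from rglob_limited(subdir, pattern, next_levels)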
1 change: 1 addition & 0 deletions src/gluonts/dataset/jsonl.py
@@ -146,6 +146,7 @@ def __len__(self):
     def _line_starts(self):
         """
         Calculate the position for each line in the file.
+
         This information can be used with ``file.seek`` to directly jump to a
         specific line in the file.
         """
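The seek pattern the docstring refers to, illustrated with a hypothetical helper that records each line's byte offset.

def line_starts(path):
    offsets, pos = [], 0
    with open(path, "rb") as f:
        for line in f:
            offsets.append(pos)
            pos += len(line)
    return offsets

# offsets = line_starts("data.jsonl")
# with open("data.jsonl", "rb") as f:
#     f.seek(offsets[42])   # jump straight to line 42
#     print(f.readline())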
5 changes: 3 additions & 2 deletions src/gluonts/dataset/multivariate_grouper.py
@@ -92,8 +92,9 @@ def _preprocess(self, dataset: Dataset) -> None:
         The preprocess function iterates over the dataset to gather data that
         is necessary for alignment.

-        This includes 1) Storing first/last timestamp in the dataset 2)
-        Storing the frequency of the dataset
+        This includes:
+        1. Storing first/last timestamp in the dataset
+        2. Storing the frequency of the dataset
         """
         for data in dataset:
             timestamp = data[FieldName.START]
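What the now-numbered list amounts to, in a toy standalone form; the "start" and "target" field names follow GluonTS conventions, but the sketch is not the actual method.

import pandas as pd

def gather_alignment_info(dataset):
    # 1. track first/last timestamp, 2. record the frequency
    first, last, freq = None, None, None
    for entry in dataset:
        start = entry["start"]                  # a pd.Period
        end = start + len(entry["target"]) - 1
        first = start if first is None else min(first, start)
        last = end if last is None else max(last, end)
        freq = start.freq
    return first, last, freq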
11 changes: 8 additions & 3 deletions src/gluonts/dataset/schema/translate.py
@@ -44,7 +44,9 @@ def fields(self):

 @dataclass
 class Get(Op):
-    """Extracts the field ``name`` from the input."""
+    """
+    Extracts the field ``name`` from the input.
+    """

     name: str

@@ -69,7 +71,9 @@ def fields(self):

 @dataclass
 class GetAttr(Op):
-    """Invokes ``obj.name``"""
+    """
+    Invokes ``obj.name``.
+    """

     obj: Op
     name: str

@@ -298,7 +302,8 @@ def parse(x: Union[str, list]) -> Op:

 @dataclass
 class Translator:
-    """Simple translation for GluonTS Datasets.
+    """
+    Simple translation for GluonTS Datasets.

     A given translator transforms an input dictionary (data-entry) into an
     output dictionary.
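A toy rendering of the two Ops whose docstrings change here, to show the dict-in/value-out contract; this is simplified, since the real classes subclass ``Op`` and integrate with ``Translator``.

from dataclasses import dataclass

@dataclass
class Get:
    """Extracts the field ``name`` from the input."""
    name: str

    def __call__(self, data: dict):
        return data[self.name]

@dataclass
class GetAttr:
    """Invokes ``obj.name``."""
    obj: Get
    name: str

    def __call__(self, data: dict):
        return getattr(self.obj(data), self.name)

# GetAttr(Get("start"), "freq")({"start": some_period}) -> the period's freq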
4 changes: 2 additions & 2 deletions src/gluonts/dataset/split.py
@@ -86,8 +86,8 @@ def periods_between(
     end: pd.Period,
 ) -> int:
     """
-    Count how many periods fit between ``start`` and ``end``
-    (inclusive). The frequency is taken from ``start``.
+    Count how many periods fit between ``start`` and ``end`` (inclusive). The
+    frequency is taken from ``start``.

     For example:
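What ``periods_between`` counts, spelled out with pandas directly; an illustration, not the library code.

import pandas as pd

start = pd.Period("2021-01", freq="M")  # frequency is taken from `start`
end = pd.Period("2021-04", freq="M")
count = (end - start).n + 1             # inclusive -> 4 periods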
6 changes: 4 additions & 2 deletions src/gluonts/ev/aggregations.py
@@ -34,7 +34,8 @@ def get(self) -> np.ndarray:

 @dataclass
 class Sum(Aggregation):
-    """Map-reduce way of calculating the sum of a stream of values.
+    """
+    Map-reduce way of calculating the sum of a stream of values.

     `partial_result` represents one of two things, depending on the axis:
     Case 1 - axis 0 is aggregated (axis is None or 0):

@@ -75,7 +76,8 @@ def get(self) -> np.ndarray:

 @dataclass
 class Mean(Aggregation):
-    """Map-reduce way of calculating the mean of a stream of values.
+    """
+    Map-reduce way of calculating the mean of a stream of values.

     `partial_result` represents one of two things, depending on the axis:
     Case 1 - axis 0 is aggregated (axis is None or 0):
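The map-reduce idea both docstrings describe, reduced to a minimal standalone form for the axis-None case: accumulate partial results per batch, combine only when asked.

import numpy as np

class StreamingMean:
    """Keep a running sum and count; combine only when asked."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def step(self, values: np.ndarray) -> None:
        self.total += values.sum()
        self.count += values.size

    def get(self) -> float:
        return self.total / self.count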
67 changes: 49 additions & 18 deletions src/gluonts/ev/metrics.py
@@ -50,15 +50,19 @@ class MetricCollection:
     metrics: List[Metric]

     def update(self, data: Mapping[str, np.ndarray]) -> Self:
-        """Update metrics using a single data instance."""
+        """
+        Update metrics using a single data instance.
+        """

         for metric in self.metrics:
             metric.update(data)

         return self

     def update_all(self, stream: Iterator[Mapping[str, np.ndarray]]) -> Self:
-        """Update metrics using a stream of data instances."""
+        """
+        Update metrics using a stream of data instances.
+        """

         for element in stream:
             self.update(element)

@@ -74,12 +78,16 @@ class Metric:
     name: str

     def update(self, data: Mapping[str, np.ndarray]) -> Self:
-        """Update metric using a single data instance."""
+        """
+        Update metric using a single data instance.
+        """

         raise NotImplementedError

     def update_all(self, stream: Iterator[Mapping[str, np.ndarray]]) -> Self:
-        """Update metric using a stream of data instances."""
+        """
+        Update metric using a stream of data instances.
+        """

         for element in stream:
             self.update(element)

@@ -92,7 +100,9 @@ def get(self) -> np.ndarray:

 @dataclass
 class DirectMetric(Metric):
-    """A Metric which uses a single function and aggregation strategy."""
+    """
+    A Metric which uses a single function and aggregation strategy.
+    """

     stat: Callable
     aggregate: Aggregation

@@ -108,10 +118,11 @@ def get(self) -> np.ndarray:

 @dataclass
 class DerivedMetric(Metric):
-    """A Metric that is computed using other metrics.
+    """
+    A Metric that is computed using other metrics.

-    A derived metric updates multiple, simpler metrics independently and in
-    the end combines their results as defined in `post_process`.
+    A derived metric updates multiple, simpler metrics independently and in the
+    end combines their results as defined in `post_process`.
     """

     metrics: Dict[str, Metric]

@@ -237,7 +248,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

 @dataclass
 class MAE(BaseMetricDefinition):
-    """Mean Absolute Error"""
+    """
+    Mean Absolute Error.
+    """

     forecast_type: str = "0.5"

@@ -254,7 +267,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

 @dataclass
 class MSE(BaseMetricDefinition):
-    """Mean Squared Error"""
+    """
+    Mean Squared Error.
+    """

     forecast_type: str = "mean"

@@ -295,7 +310,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

 @dataclass
 class MAPE(BaseMetricDefinition):
-    """Mean Absolute Percentage Error"""
+    """
+    Mean Absolute Percentage Error.
+    """

     forecast_type: str = "0.5"

@@ -314,7 +331,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

 @dataclass
 class SMAPE(BaseMetricDefinition):
-    """Symmetric Mean Absolute Percentage Error"""
+    """
+    Symmetric Mean Absolute Percentage Error.
+    """

     forecast_type: str = "0.5"

@@ -334,7 +353,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

 @dataclass
 class MSIS(BaseMetricDefinition):
-    """Mean Scaled Interval Score"""
+    """
+    Mean Scaled Interval Score.
+    """

     alpha: float = 0.05

@@ -351,7 +372,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

 @dataclass
 class MASE(BaseMetricDefinition):
-    """Mean Absolute Scaled Error"""
+    """
+    Mean Absolute Scaled Error.
+    """

     forecast_type: str = "0.5"

@@ -382,7 +405,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

 @dataclass
 class ND(BaseMetricDefinition):
-    """Normalized Deviation"""
+    """
+    Normalized Deviation.
+    """

     forecast_type: str = "0.5"

@@ -410,7 +435,9 @@ def __call__(self, axis: Optional[int] = None) -> DerivedMetric:

 @dataclass
 class RMSE(BaseMetricDefinition):
-    """Root Mean Squared Error"""
+    """
+    Root Mean Squared Error.
+    """

     forecast_type: str = "mean"

@@ -435,7 +462,9 @@ def __call__(self, axis: Optional[int] = None) -> DerivedMetric:

 @dataclass
 class NRMSE(BaseMetricDefinition):
-    """RMSE, normalized by the mean absolute label"""
+    """
+    RMSE, normalized by the mean absolute label.
+    """

     forecast_type: str = "mean"

@@ -582,7 +611,9 @@ def __call__(self, axis: Optional[int] = None) -> DerivedMetric:

 @dataclass
 class OWA(BaseMetricDefinition):
-    """Overall Weighted Average"""
+    """
+    Overall Weighted Average.
+    """

     forecast_type: str = "0.5"
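The direct/derived split these docstrings describe, in miniature with illustrative numbers, not library code: MSE needs only one statistic and one aggregation, while RMSE post-processes MSE's result.

import numpy as np

label = np.array([1.0, 2.0, 3.0])
forecast = np.array([1.5, 2.0, 2.0])

mse = np.mean((label - forecast) ** 2)  # direct: stat + mean aggregation
rmse = np.sqrt(mse)                     # derived: post-processes MSE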
6 changes: 4 additions & 2 deletions src/gluonts/ev/ts_stats.py
@@ -17,9 +17,11 @@
 def seasonal_error(
     time_series: np.ndarray, seasonality: int, time_axis=0
 ) -> np.ndarray:
-    """The mean abs. difference of a time series, shifted by its seasonality.
+    """
+    The mean abs. difference of a time series, shifted by its seasonality.

-    Some metrics use the seasonal error for normalization."""
+    Some metrics use the seasonal error for normalization.
+    """

     time_length = time_series.shape[time_axis]
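A literal reading of the reflowed docstring, as a sketch; the real function also handles a time axis and the case where the seasonality exceeds the series length.

import numpy as np

def seasonal_error_sketch(ts: np.ndarray, seasonality: int) -> float:
    # mean absolute difference between the series and itself,
    # shifted by `seasonality` steps (assumes seasonality < len(ts))
    return np.mean(np.abs(ts[seasonality:] - ts[:-seasonality]))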
8 changes: 5 additions & 3 deletions src/gluonts/evaluation/_base.py
@@ -108,8 +108,9 @@ def aggregate_valid(
 def validate_forecast(
     forecast: Forecast, quantiles: Iterable[Quantile]
 ) -> bool:
-    """Validates a Forecast object by checking it for `NaN` values.
-    The supplied quantiles and mean (if available) are checked.
+    """
+    Validates a Forecast object by checking it for `NaN` values. The supplied
+    quantiles and mean (if available) are checked.

     Parameters
     ----------

@@ -767,7 +768,8 @@ def __call__(
         fcst_iterator: Iterable[Forecast],
         num_series=None,
     ) -> Tuple[Dict[str, float], pd.DataFrame]:
-        """Compute accuracy metrics for multivariate forecasts.
+        """
+        Compute accuracy metrics for multivariate forecasts.

         Parameters
         ----------
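Roughly the check the rewritten summary describes: ``Forecast.quantile`` is the real GluonTS accessor, but the helper itself is a hypothetical sketch, not the library function.

import numpy as np

def forecast_is_valid(forecast, quantiles) -> bool:
    # NaN in any requested quantile (or the mean, if available) fails it
    for q in quantiles:
        if np.isnan(forecast.quantile(q)).any():
            return False
    mean = getattr(forecast, "mean", None)
    return mean is None or not np.isnan(mean).any()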