Skip to content

Commit

Permalink
add log error metrics to ev pkg (awslabs#2621)
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-mcdo committed Feb 9, 2023
1 parent 8477bb3 commit 97a5125
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 2 deletions.
14 changes: 14 additions & 0 deletions src/gluonts/ev/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
error,
absolute_error,
squared_error,
log_error,
absolute_log_error,
squared_log_error,
quantile_loss,
coverage,
absolute_percentage_error,
Expand All @@ -30,6 +33,8 @@
sum_absolute_label,
SumAbsoluteError,
MSE,
MALE,
MSLE,
SumQuantileLoss,
Coverage,
MAPE,
Expand All @@ -38,6 +43,8 @@
MASE,
ND,
RMSE,
EMALE,
ERMSLE,
NRMSE,
WeightedSumQuantileLoss,
MAECoverage,
Expand All @@ -54,6 +61,9 @@
"error",
"absolute_error",
"squared_error",
"log_error",
"absolute_log_error",
"squared_log_error",
"quantile_loss",
"coverage",
"absolute_percentage_error",
Expand All @@ -64,6 +74,8 @@
"sum_absolute_label",
"SumAbsoluteError",
"MSE",
"MALE",
"MSLE",
"SumQuantileLoss",
"Coverage",
"MAPE",
Expand All @@ -72,6 +84,8 @@
"MASE",
"ND",
"RMSE",
"EMALE",
"ERMSLE",
"NRMSE",
"WeightedSumQuantileLoss",
"MAECoverage",
Expand Down
80 changes: 80 additions & 0 deletions src/gluonts/ev/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
error,
absolute_error,
absolute_label,
absolute_log_error,
absolute_percentage_error,
absolute_scaled_error,
coverage,
quantile_loss,
scaled_interval_score,
squared_error,
squared_log_error,
symmetric_absolute_percentage_error,
)

Expand Down Expand Up @@ -94,6 +96,36 @@ def __call__(self, axis: Optional[int] = None) -> DirectEvaluator:
)


@dataclass
class MALE:
"""Mean Absolute Log Error"""

forecast_type: str = "0.5"

def __call__(self, axis: Optional[int] = None) -> DirectEvaluator:
return DirectEvaluator(
name="MALE",
stat=partial(absolute_log_error, forecast_type=self.forecast_type),
aggregate=Mean(axis=axis),
)


@dataclass
class MSLE:
"""Mean Squared Log Error"""

forecast_type: str = "0.5"
# technically, the forecast type should be "geometric mean",
# but forecast objects do not provide this (for now)

def __call__(self, axis: Optional[int] = None) -> DirectEvaluator:
return DirectEvaluator(
name="MSLE",
stat=partial(squared_log_error, forecast_type=self.forecast_type),
aggregate=Mean(axis=axis),
)


@dataclass
class SumQuantileLoss:
q: float
Expand Down Expand Up @@ -228,6 +260,54 @@ def __call__(self, axis: Optional[int] = None) -> DerivedEvaluator:
)


@dataclass
class EMALE:
"""Exponential Mean Absolute Log Error"""

forecast_type: str = "0.5"

@staticmethod
def exponential_mean_absolute_log_error(
mean_absolute_log_error: np.ndarray,
) -> np.ndarray:
return np.exp(mean_absolute_log_error)

def __call__(self, axis: Optional[int] = None) -> DerivedEvaluator:
return DerivedEvaluator(
name="EMALE",
evaluators={
"mean_absolute_log_error": MALE(
forecast_type=self.forecast_type
)(axis=axis)
},
post_process=self.exponential_mean_absolute_log_error,
)


@dataclass
class ERMSLE:
"""Exponential Root Mean Squared Log Error"""

forecast_type: str = "0.5"

@staticmethod
def exponential_root_mean_squared_log_error(
mean_squared_log_error: np.ndarray,
) -> np.ndarray:
return np.exp(np.sqrt(mean_squared_log_error))

def __call__(self, axis: Optional[int] = None) -> DerivedEvaluator:
return DerivedEvaluator(
name="ERMSLE",
evaluators={
"mean_squared_log_error": MSLE(
forecast_type=self.forecast_type
)(axis=axis)
},
post_process=self.exponential_root_mean_squared_log_error,
)


@dataclass
class NRMSE:
"""RMSE, normalized by the mean absolute label"""
Expand Down
16 changes: 16 additions & 0 deletions src/gluonts/ev/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def squared_error(
return np.square(error(data, forecast_type))


def log_error(data: Dict[str, np.ndarray], forecast_type: str) -> np.ndarray:
return np.log(data["label"] / data[forecast_type])


def absolute_log_error(
data: Dict[str, np.ndarray], forecast_type: str
) -> np.ndarray:
return np.abs(log_error(data, forecast_type))


def squared_log_error(
data: Dict[str, np.ndarray], forecast_type: str
) -> np.ndarray:
return np.square(log_error(data, forecast_type))


def quantile_loss(data: Dict[str, np.ndarray], q: float) -> np.ndarray:
forecast_type = str(q)
prediction = data[forecast_type]
Expand Down
20 changes: 18 additions & 2 deletions test/ev/test_metrics_compared_to_previous_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
MAPE,
MASE,
MSE,
MALE,
MSLE,
MSIS,
SMAPE,
Coverage,
Expand All @@ -45,6 +47,8 @@
ND,
NRMSE,
RMSE,
EMALE,
ERMSLE,
MAECoverage,
MeanSumQuantileLoss,
MeanWeightedSumQuantileLoss,
Expand Down Expand Up @@ -158,6 +162,8 @@ def get_new_metrics(test_data, predictor, quantile_levels):
*(SumQuantileLoss(q=quantile.value) for quantile in quantiles),
mean_absolute_label,
MSE(),
MALE(),
MSLE(),
MASE(),
MAPE(),
SMAPE(),
Expand All @@ -168,6 +174,8 @@ def get_new_metrics(test_data, predictor, quantile_levels):
RMSE(),
NRMSE(),
ND(),
EMALE(),
ERMSLE(),
*(WeightedSumQuantileLoss(q=quantile.value) for quantile in quantiles),
MeanSumQuantileLoss([quantile.value for quantile in quantiles]),
MeanWeightedSumQuantileLoss(
Expand Down Expand Up @@ -237,10 +245,18 @@ def get_new_metrics(test_data, predictor, quantile_levels):
"MAE_Coverage": aggregated_metrics["MAE_coverage"],
}

for metric_name in ["MSE", "MASE", "MAPE", "sMAPE", "MSIS"]:
for metric_name in [
"MSE",
"MALE",
"MSLE",
"MASE",
"MAPE",
"sMAPE",
"MSIS",
]:
all_metrics[metric_name] = np.ma.mean(item_metrics[metric_name])

for metric_name in ["RMSE", "NRMSE", "ND", "OWA"]:
for metric_name in ["RMSE", "NRMSE", "ND", "EMALE", "ERMSLE", "OWA"]:
all_metrics[metric_name] = aggregated_metrics[metric_name]

return all_metrics
Expand Down
33 changes: 33 additions & 0 deletions test/ev/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@

from gluonts.ev import (
absolute_error,
absolute_log_error,
absolute_label,
absolute_percentage_error,
coverage,
error,
log_error,
quantile_loss,
squared_error,
squared_log_error,
symmetric_absolute_percentage_error,
scaled_interval_score,
absolute_scaled_error,
Expand Down Expand Up @@ -88,6 +91,36 @@ def test_squared_error():
np.testing.assert_almost_equal(actual, expected)


def test_log_error():
for label in TIME_SERIES:
for forecast in TIME_SERIES:
data = {"label": label, "0.5": forecast}
actual = log_error(data, forecast_type="0.5")
expected = np.log(label / forecast)

np.testing.assert_almost_equal(actual, expected)


def test_abs_log_error():
for label in TIME_SERIES:
for forecast in TIME_SERIES:
data = {"label": label, "0.5": forecast}
actual = absolute_log_error(data, forecast_type="0.5")
expected = np.abs(np.log(label / forecast))

np.testing.assert_almost_equal(actual, expected)


def test_squared_log_error():
for label in TIME_SERIES:
for forecast in TIME_SERIES:
data = {"label": label, "0.5": forecast}
actual = squared_log_error(data, forecast_type="0.5")
expected = np.square(np.log(label / forecast))

np.testing.assert_almost_equal(actual, expected)


def test_quantile_loss():
for label in TIME_SERIES:
for forecast in TIME_SERIES:
Expand Down

0 comments on commit 97a5125

Please sign in to comment.