-
Notifications
You must be signed in to change notification settings - Fork 748
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
55 changed files
with
7,952 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Probabilistic Few-Shot Forecasting | ||
We propose a method for probabilistic few-shot time series forecasting. | ||
In this setting, the model is trained on a large and diverse collection of source data sets and can subsequently make probabilistic predictions for previously unseen target data sets given a small number of representative examples (the support set), without additional retraining or fine-tuning. | ||
Performance can be improved by choosing the examples in the support set based on a notion of distance to the target time series during both training and prediction. | ||
The codebase allows training and evaluating our method on a large collection of publicly available real-world data sets. | ||
|
||
### Set up | ||
|
||
```bash | ||
# install pyenv | ||
git clone https://github.com/pyenv/pyenv.git ~/.pyenv | ||
cd ~/.pyenv && src/configure && make -C src | ||
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.profile | ||
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.profile | ||
echo 'eval "$(pyenv init --path)"' >> ~/.profile | ||
source ~/.profile | ||
pyenv install 3.8.9 | ||
|
||
# install poetry | ||
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python - | ||
source $HOME/.poetry/env | ||
poetry config virtualenvs.in-project true | ||
|
||
# create virtual env for your project | ||
cd to_your_project | ||
pyenv local 3.8.9 | ||
poetry config virtualenvs.in-project true | ||
poetry install | ||
``` | ||
|
||
### Generate and upload data | ||
Run `scripts/data.py real` to download and process the GluonTS datasets supported be the few-shot predictor. | ||
|
||
### Train a model | ||
Run `scripts/train.py` with specific arguments to start a training. | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
[tool.poetry] | ||
authors = ["Your Name <[email protected]>"] | ||
description = "" | ||
name = "meta-tfs" | ||
packages = [ | ||
{include = "*", from = "src"}, | ||
] | ||
version = "0.1.0" | ||
|
||
[tool.poetry.dependencies] | ||
boto3 = "^1.18.43" | ||
botocore = "^1.21.43" | ||
click = "^8.0.1" | ||
geomloss = "^0.2.4" | ||
gluonts = {git = "https://github.com/awslabs/gluon-ts.git"} | ||
pandas = "^1.3.1" | ||
python = "^3.8,<3.10" | ||
pytorch-lightning = "^1.4.4" | ||
sagemaker = "^2.40.0,<2.41.0" | ||
scikit-learn = "^0.24.2" | ||
torch = "^1.9.0" | ||
sagemaker-training = "^3.9.2" | ||
python-dotenv = "^0.19.0" | ||
xlrd = "^2.0.1" | ||
lightkit = "^0.3.5" | ||
catch22 = "^0.2.0" | ||
seaborn = "^0.11.2" | ||
|
||
[tool.poetry.dev-dependencies] | ||
black = "^21.7b0" | ||
isort = "^5.9.3" | ||
jupyter = "^1.0.0" | ||
pylint = "^2.10.2" | ||
|
||
[tool.poetry.scripts] | ||
schedule = 'schedule:main' | ||
|
||
[build-system] | ||
build-backend = "poetry.core.masonry.api" | ||
requires = ["poetry-core>=1.0.0"] | ||
|
||
[tool.pylint.messages_control] | ||
disable = [ | ||
"arguments-differ", | ||
"duplicate-code", | ||
"missing-module-docstring", | ||
"invalid-name", | ||
"no-self-use", | ||
"too-few-public-methods", | ||
"too-many-ancestors", | ||
"too-many-arguments", | ||
"too-many-branches", | ||
"too-many-locals", | ||
"too-many-instance-attributes", | ||
] | ||
|
||
[tool.pylint.typecheck] | ||
generated-members = [ | ||
"math.*", | ||
"torch.*", | ||
] | ||
|
||
[tool.pyright] | ||
reportIncompatibleMethodOverride = false | ||
reportMissingStubTypes = false | ||
reportUnknownArgumentType = false | ||
reportUnknownMemberType = false | ||
reportUnknownVariableType = false | ||
typeCheckingMode = "strict" | ||
|
||
[tool.black] | ||
line-length = 79 | ||
|
||
[tool.isort] | ||
force_alphabetical_sort_within_sections = true | ||
include_trailing_comma = true | ||
known_first_party = ["embeddings"] | ||
line_length = 99 | ||
lines_between_sections = 0 | ||
profile = "black" | ||
skip_gitignore = true |
12 changes: 12 additions & 0 deletions
12
src/gluonts/nursery/few_shot_prediction/src/meta/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. |
34 changes: 34 additions & 0 deletions
34
src/gluonts/nursery/few_shot_prediction/src/meta/callbacks/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
from .plot import ( | ||
ForecastPlotLoggerCallback, | ||
ForecastSupportSetAttentionPlotLoggerCallback, | ||
LossPlotLoggerCallback, | ||
CheatLossPlotLoggerCallback, | ||
MacroCRPSPlotCallback, | ||
) | ||
from .count import ParameterCountCallback | ||
from .save import InitialSaveCallback | ||
from .metric import QuantileMetricLoggerCallback | ||
|
||
__all__ = [ | ||
"ForecastPlotLoggerCallback", | ||
"ParameterCountCallback", | ||
"InitialSaveCallback", | ||
"ForecastSupportSetAttentionPlotLoggerCallback", | ||
"LossPlotLoggerCallback", | ||
"CheatLossPlotLoggerCallback", | ||
"MacroCRPSPlotCallback", | ||
"QuantileMetricLoggerCallback", | ||
] |
29 changes: 29 additions & 0 deletions
29
src/gluonts/nursery/few_shot_prediction/src/meta/callbacks/common.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
from pytorch_lightning.loggers import CSVLogger | ||
from pytorch_lightning import Trainer | ||
import numpy as np | ||
from pathlib import Path | ||
|
||
|
||
def get_save_dir_from_csvlogger(logger: CSVLogger) -> Path: | ||
return Path(logger.save_dir) / logger.name / f"version_{logger.version}" | ||
|
||
|
||
def get_loss_steps(loss_name: str, trainer: Trainer): | ||
loss = trainer.logger.experiment.metrics | ||
ep = trainer.current_epoch | ||
out = [l[loss_name] for l in loss if loss_name in l] | ||
start = 1 if "val" in loss_name else 0 | ||
return out, np.linspace(start, ep, len(out)) |
42 changes: 42 additions & 0 deletions
42
src/gluonts/nursery/few_shot_prediction/src/meta/callbacks/count.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
from pytorch_lightning import Callback, Trainer, LightningModule | ||
|
||
|
||
class ParameterCountCallback(Callback): # type: ignore | ||
""" | ||
This callback allows counting model parameters during training. | ||
The output is printed to the console and can be retrieved from the log files. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def on_pretrain_routine_start( | ||
self, trainer: Trainer, pl_module: LightningModule | ||
) -> None: | ||
# compute number of model params | ||
model_total_params = sum( | ||
p.numel() for p in pl_module.model.parameters() | ||
) | ||
model_total_trainable_params = sum( | ||
p.numel() for p in pl_module.model.parameters() if p.requires_grad | ||
) | ||
|
||
# log | ||
print("\n" + f"model_total_params: {model_total_params},") | ||
print( | ||
"\n" | ||
+ f"model_total_trainable_params: {model_total_trainable_params}," | ||
) |
78 changes: 78 additions & 0 deletions
78
src/gluonts/nursery/few_shot_prediction/src/meta/callbacks/metric.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
|
||
from typing import Optional, Union, List | ||
from pytorch_lightning import Callback, Trainer | ||
import torch | ||
from gluonts.time_feature import get_seasonality | ||
|
||
from meta.models.module import MetaLightningModule | ||
from meta.common.torch import tensor_to_np | ||
from meta.metrics.numpy import compute_metrics | ||
|
||
|
||
class QuantileMetricLoggerCallback(Callback): | ||
""" | ||
A callback that computes additional metrics on a numpy representation of the dataset every n epochs. | ||
The computed values are logged to the output file of the pytorch lightning logger. | ||
Args: | ||
quantiles: The quantiles that are predicted. | ||
split: Specifies the split the batch comes from (i.e. from training or validation split) | ||
every_n_epochs: Specifies how often the plots are generated. | ||
Setting this to a large value can save time since plotting can be time consuming | ||
(especially when small datasets are used). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
quantiles: List[str], | ||
split: Optional[Union["train", "val"]] = None, | ||
every_n_epochs: int = 1, | ||
): | ||
super().__init__() | ||
self.quantiles = quantiles | ||
self.split = split | ||
self.every_n_epochs = every_n_epochs | ||
|
||
def on_validation_epoch_end( | ||
self, trainer: Trainer, pl_module: MetaLightningModule | ||
) -> None: | ||
if trainer.current_epoch % self.every_n_epochs: | ||
return | ||
dm_super = trainer.lightning_module.trainer.datamodule | ||
dm_val = dm_super.data_modules_val[0] | ||
dl = dm_val.val_dataloader() | ||
split = dm_val.splits.val() | ||
pred = [] | ||
for batch in dl: | ||
batch = batch.to(pl_module.device) | ||
pred.append( | ||
pl_module.model( | ||
supps=batch.support_set, query=batch.query_past | ||
) | ||
) | ||
|
||
# redo standardization for evaluation | ||
pred = split.data().rescale_dataset(torch.cat(pred, dim=0).cpu()) | ||
pred = tensor_to_np(pred) | ||
|
||
# use only the length that should be included for evaluation | ||
pred = pred[:, : dm_val.prediction_length, ...] | ||
m = compute_metrics( | ||
pred, | ||
split.evaluation(), | ||
quantiles=self.quantiles, | ||
seasonality=get_seasonality(dm_val.meta.freq), | ||
) | ||
self.log("metrics", m) |
Oops, something went wrong.