Few shot forecasting (#2517)

RingoIngo committed Dec 26, 2022
1 parent 9059f26 commit 8c29bca
Showing 55 changed files with 7,952 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/gluonts/nursery/few_shot_prediction/README.md
@@ -0,0 +1,38 @@
# Probabilistic Few-Shot Forecasting
We propose a method for probabilistic few-shot time series forecasting.
In this setting, the model is trained on a large and diverse collection of source data sets. It can subsequently make probabilistic predictions for previously unseen target data sets given only a small number of representative examples (the support set), without additional retraining or fine-tuning.
Performance can be improved by choosing the examples in the support set based on a notion of distance to the target time series during both training and prediction.
The codebase allows training and evaluating our method on a large collection of publicly available real-world data sets.
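As an illustration of the distance-based selection above, here is a minimal, self-contained sketch (not the code in this package) that picks a support set by Euclidean distance between standardized series; the method itself may use a different, possibly learned, notion of distance:

```python
import numpy as np


def _standardize(x: np.ndarray) -> np.ndarray:
    # zero-mean, unit-variance per series; the epsilon guards against flat series
    return (x - x.mean(axis=-1, keepdims=True)) / (
        x.std(axis=-1, keepdims=True) + 1e-8
    )


def select_support_set(
    query: np.ndarray, candidates: np.ndarray, k: int
) -> np.ndarray:
    """Return the k candidate series closest to the query series.

    query has shape (T,), candidates has shape (N, T); Euclidean distance
    on standardized series stands in here for the "notion of distance".
    """
    dists = np.linalg.norm(
        _standardize(candidates) - _standardize(query), axis=-1
    )
    return candidates[np.argsort(dists)[:k]]
```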

### Setup

```bash
# install pyenv
git clone https://github.com/pyenv/pyenv.git ~/.pyenv
cd ~/.pyenv && src/configure && make -C src
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.profile
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.profile
echo 'eval "$(pyenv init --path)"' >> ~/.profile
source ~/.profile
pyenv install 3.8.9

# install poetry
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -
source $HOME/.poetry/env
poetry config virtualenvs.in-project true

# create a virtual env for your project
cd path/to/your/project
pyenv local 3.8.9
poetry install
```

### Generate and upload data
Run `scripts/data.py real` to download and process the GluonTS datasets supported by the few-shot predictor.
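For example, from within the Poetry environment created above:

```bash
poetry run python scripts/data.py real
```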

### Train a model
Run `scripts/train.py` with specific arguments to start a training run.
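For example, to list the supported arguments (assuming the script is built on `click`, a project dependency, which generates `--help` automatically):

```bash
poetry run python scripts/train.py --help
```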



81 changes: 81 additions & 0 deletions src/gluonts/nursery/few_shot_prediction/pyproject.toml
@@ -0,0 +1,81 @@
[tool.poetry]
authors = ["Your Name <[email protected]>"]
description = ""
name = "meta-tfs"
packages = [
{include = "*", from = "src"},
]
version = "0.1.0"

[tool.poetry.dependencies]
boto3 = "^1.18.43"
botocore = "^1.21.43"
click = "^8.0.1"
geomloss = "^0.2.4"
gluonts = {git = "https://github.com/awslabs/gluon-ts.git"}
pandas = "^1.3.1"
python = "^3.8,<3.10"
pytorch-lightning = "^1.4.4"
sagemaker = "^2.40.0,<2.41.0"
scikit-learn = "^0.24.2"
torch = "^1.9.0"
sagemaker-training = "^3.9.2"
python-dotenv = "^0.19.0"
xlrd = "^2.0.1"
lightkit = "^0.3.5"
catch22 = "^0.2.0"
seaborn = "^0.11.2"

[tool.poetry.dev-dependencies]
black = "^21.7b0"
isort = "^5.9.3"
jupyter = "^1.0.0"
pylint = "^2.10.2"

[tool.poetry.scripts]
schedule = 'schedule:main'

[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core>=1.0.0"]

[tool.pylint.messages_control]
disable = [
"arguments-differ",
"duplicate-code",
"missing-module-docstring",
"invalid-name",
"no-self-use",
"too-few-public-methods",
"too-many-ancestors",
"too-many-arguments",
"too-many-branches",
"too-many-locals",
"too-many-instance-attributes",
]

[tool.pylint.typecheck]
generated-members = [
"math.*",
"torch.*",
]

[tool.pyright]
reportIncompatibleMethodOverride = false
reportMissingStubTypes = false
reportUnknownArgumentType = false
reportUnknownMemberType = false
reportUnknownVariableType = false
typeCheckingMode = "strict"

[tool.black]
line-length = 79

[tool.isort]
force_alphabetical_sort_within_sections = true
include_trailing_comma = true
known_first_party = ["embeddings"]
line_length = 99
lines_between_sections = 0
profile = "black"
skip_gitignore = true
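With the `[tool.poetry.scripts]` entry above, Poetry also exposes a `schedule` entry point (assuming a `schedule` module with a `main` function is importable):

```bash
poetry run schedule
```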
12 changes: 12 additions & 0 deletions src/gluonts/nursery/few_shot_prediction/src/meta/__init__.py
@@ -0,0 +1,12 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
@@ -0,0 +1,34 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .plot import (
ForecastPlotLoggerCallback,
ForecastSupportSetAttentionPlotLoggerCallback,
LossPlotLoggerCallback,
CheatLossPlotLoggerCallback,
MacroCRPSPlotCallback,
)
from .count import ParameterCountCallback
from .save import InitialSaveCallback
from .metric import QuantileMetricLoggerCallback

__all__ = [
"ForecastPlotLoggerCallback",
"ParameterCountCallback",
"InitialSaveCallback",
"ForecastSupportSetAttentionPlotLoggerCallback",
"LossPlotLoggerCallback",
"CheatLossPlotLoggerCallback",
"MacroCRPSPlotCallback",
"QuantileMetricLoggerCallback",
]
@@ -0,0 +1,29 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pathlib import Path
from typing import List, Tuple

import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger


def get_save_dir_from_csvlogger(logger: CSVLogger) -> Path:
    """Return the directory the given CSV logger writes to."""
    return Path(logger.save_dir) / logger.name / f"version_{logger.version}"


def get_loss_steps(
    loss_name: str, trainer: Trainer
) -> Tuple[List[float], np.ndarray]:
    """Collect all values logged under ``loss_name`` together with the
    (fractional) epoch at which each value was recorded."""
    metrics = trainer.logger.experiment.metrics
    ep = trainer.current_epoch
    out = [m[loss_name] for m in metrics if loss_name in m]
    # validation metrics are first logged after the first epoch,
    # training metrics from epoch 0
    start = 1 if "val" in loss_name else 0
    return out, np.linspace(start, ep, len(out))
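A hypothetical usage of these helpers, e.g. from inside a plotting callback (the metric key `"train_loss"` and the use of a `CSVLogger` are assumptions; they depend on what the `LightningModule` actually logs):

```python
import matplotlib.pyplot as plt

# collect every value logged under "train_loss" and a matching epoch grid
losses, epochs = get_loss_steps("train_loss", trainer)

plt.plot(epochs, losses)
plt.xlabel("epoch")
plt.ylabel("train_loss")
plt.savefig(get_save_dir_from_csvlogger(trainer.logger) / "train_loss.png")
```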
@@ -0,0 +1,42 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pytorch_lightning import Callback, Trainer, LightningModule


class ParameterCountCallback(Callback):  # type: ignore
    """
    A callback that counts the model parameters at the start of training.
    The counts are printed to the console and can be retrieved from the log files.
    """

    def on_pretrain_routine_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        # compute the total and the trainable number of model parameters
        model_total_params = sum(
            p.numel() for p in pl_module.model.parameters()
        )
        model_total_trainable_params = sum(
            p.numel() for p in pl_module.model.parameters() if p.requires_grad
        )

        # log to the console
        print("\n" + f"model_total_params: {model_total_params},")
        print(
            "\n"
            + f"model_total_trainable_params: {model_total_trainable_params},"
        )
@@ -0,0 +1,78 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Literal, Optional
from pytorch_lightning import Callback, Trainer
import torch
from gluonts.time_feature import get_seasonality

from meta.models.module import MetaLightningModule
from meta.common.torch import tensor_to_np
from meta.metrics.numpy import compute_metrics


class QuantileMetricLoggerCallback(Callback):
"""
A callback that computes additional metrics on a numpy representation of the dataset every n epochs.
The computed values are logged to the output file of the pytorch lightning logger.
Args:
quantiles: The quantiles that are predicted.
split: Specifies the split the batch comes from (i.e. from training or validation split)
every_n_epochs: Specifies how often the plots are generated.
Setting this to a large value can save time since plotting can be time consuming
(especially when small datasets are used).
"""

def __init__(
self,
quantiles: List[str],
        split: Optional[Literal["train", "val"]] = None,
every_n_epochs: int = 1,
):
super().__init__()
self.quantiles = quantiles
self.split = split
self.every_n_epochs = every_n_epochs

def on_validation_epoch_end(
self, trainer: Trainer, pl_module: MetaLightningModule
) -> None:
if trainer.current_epoch % self.every_n_epochs:
return
dm_super = trainer.lightning_module.trainer.datamodule
dm_val = dm_super.data_modules_val[0]
dl = dm_val.val_dataloader()
split = dm_val.splits.val()
pred = []
for batch in dl:
batch = batch.to(pl_module.device)
pred.append(
pl_module.model(
supps=batch.support_set, query=batch.query_past
)
)

# redo standardization for evaluation
pred = split.data().rescale_dataset(torch.cat(pred, dim=0).cpu())
pred = tensor_to_np(pred)

# use only the length that should be included for evaluation
pred = pred[:, : dm_val.prediction_length, ...]
m = compute_metrics(
pred,
split.evaluation(),
quantiles=self.quantiles,
seasonality=get_seasonality(dm_val.meta.freq),
)
self.log("metrics", m)