Few shot forecasting #2517

Merged (5 commits) on Dec 26, 2022
38 changes: 38 additions & 0 deletions src/gluonts/nursery/few_shot_prediction/README.md
@@ -0,0 +1,38 @@
# Probabilistic Few-Shot Forecasting
We propose a method for probabilistic few-shot time series forecasting.
In this setting, the model is trained on a large and diverse collection of source data sets. It can then make probabilistic predictions for previously unseen target data sets given only a small number of representative examples (the support set), without additional retraining or fine-tuning.
Performance can be improved by choosing the examples in the support set based on a notion of distance to the target time series during both training and prediction.
The codebase allows training and evaluating our method on a large collection of publicly available real-world data sets.
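
As a rough illustration of the interface (not the actual architecture added in this PR), the sketch below shows a toy forecaster that consumes a support set together with the observed past of a query series and emits quantile forecasts. All module and tensor names here are illustrative assumptions; only the `supps=` / `query=` call pattern mirrors the code in this change set.

```python
import torch


class ToyFewShotForecaster(torch.nn.Module):
    """Toy stand-in for the meta-model: support set + query past -> quantile forecast."""

    def __init__(self, prediction_length: int, num_quantiles: int):
        super().__init__()
        self.prediction_length = prediction_length
        self.num_quantiles = num_quantiles
        self.head = torch.nn.LazyLinear(prediction_length * num_quantiles)

    def forward(self, supps: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
        # condition on a mean-pooled support set and the observed past of the query
        context = torch.cat([supps.mean(dim=1), query], dim=-1)
        out = self.head(context)
        return out.view(-1, self.prediction_length, self.num_quantiles)


model = ToyFewShotForecaster(prediction_length=12, num_quantiles=3)
# 8 queries, each with a support set of 5 series of length 24 and 24 past observations
pred = model(supps=torch.randn(8, 5, 24), query=torch.randn(8, 24))
print(pred.shape)  # torch.Size([8, 12, 3])
```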

### Setup

```bash
# install pyenv
git clone https://github.com/pyenv/pyenv.git ~/.pyenv
cd ~/.pyenv && src/configure && make -C src
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.profile
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.profile
echo 'eval "$(pyenv init --path)"' >> ~/.profile
source ~/.profile
pyenv install 3.8.9

# install poetry
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -
source $HOME/.poetry/env
poetry config virtualenvs.in-project true

# create virtual env for your project
cd to_your_project
pyenv local 3.8.9
poetry config virtualenvs.in-project true
poetry install
```

### Generate and upload data
Run `scripts/data.py real` to download and process the GluonTS datasets supported by the few-shot predictor.

### Train a model
Run `scripts/train.py` with specific arguments to start a training run.



81 changes: 81 additions & 0 deletions src/gluonts/nursery/few_shot_prediction/pyproject.toml
@@ -0,0 +1,81 @@
[tool.poetry]
authors = ["Your Name <[email protected]>"]
description = ""
name = "meta-tfs"
packages = [
{include = "*", from = "src"},
]
version = "0.1.0"

[tool.poetry.dependencies]
boto3 = "^1.18.43"
botocore = "^1.21.43"
click = "^8.0.1"
geomloss = "^0.2.4"
gluonts = {git = "https://github.com/awslabs/gluon-ts.git"}
pandas = "^1.3.1"
python = "^3.8,<3.10"
pytorch-lightning = "^1.4.4"
sagemaker = "^2.40.0,<2.41.0"
scikit-learn = "^0.24.2"
torch = "^1.9.0"
sagemaker-training = "^3.9.2"
python-dotenv = "^0.19.0"
xlrd = "^2.0.1"
lightkit = "^0.3.5"
catch22 = "^0.2.0"
seaborn = "^0.11.2"

[tool.poetry.dev-dependencies]
black = "^21.7b0"
isort = "^5.9.3"
jupyter = "^1.0.0"
pylint = "^2.10.2"

[tool.poetry.scripts]
schedule = 'schedule:main'

[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core>=1.0.0"]

[tool.pylint.messages_control]
disable = [
"arguments-differ",
"duplicate-code",
"missing-module-docstring",
"invalid-name",
"no-self-use",
"too-few-public-methods",
"too-many-ancestors",
"too-many-arguments",
"too-many-branches",
"too-many-locals",
"too-many-instance-attributes",
]

[tool.pylint.typecheck]
generated-members = [
"math.*",
"torch.*",
]

[tool.pyright]
reportIncompatibleMethodOverride = false
reportMissingStubTypes = false
reportUnknownArgumentType = false
reportUnknownMemberType = false
reportUnknownVariableType = false
typeCheckingMode = "strict"

[tool.black]
line-length = 79

[tool.isort]
force_alphabetical_sort_within_sections = true
include_trailing_comma = true
known_first_party = ["embeddings"]
line_length = 99
lines_between_sections = 0
profile = "black"
skip_gitignore = true
12 changes: 12 additions & 0 deletions src/gluonts/nursery/few_shot_prediction/src/meta/__init__.py
@@ -0,0 +1,12 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
@@ -0,0 +1,34 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from .plot import (
ForecastPlotLoggerCallback,
ForecastSupportSetAttentionPlotLoggerCallback,
LossPlotLoggerCallback,
CheatLossPlotLoggerCallback,
MacroCRPSPlotCallback,
)
from .count import ParameterCountCallback
from .save import InitialSaveCallback
from .metric import QuantileMetricLoggerCallback

__all__ = [
"ForecastPlotLoggerCallback",
"ParameterCountCallback",
"InitialSaveCallback",
"ForecastSupportSetAttentionPlotLoggerCallback",
"LossPlotLoggerCallback",
"CheatLossPlotLoggerCallback",
"MacroCRPSPlotCallback",
"QuantileMetricLoggerCallback",
]
@@ -0,0 +1,29 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import Trainer
import numpy as np
from pathlib import Path


def get_save_dir_from_csvlogger(logger: CSVLogger) -> Path:
    """Return the directory the given ``CSVLogger`` writes its files to."""
    return Path(logger.save_dir) / logger.name / f"version_{logger.version}"


def get_loss_steps(loss_name: str, trainer: Trainer):
    """Collect the logged values of ``loss_name`` and the epochs at which they were recorded."""
    loss = trainer.logger.experiment.metrics
    ep = trainer.current_epoch
    out = [record[loss_name] for record in loss if loss_name in record]
    # validation metrics are first logged after the first epoch, training metrics from epoch 0
    start = 1 if "val" in loss_name else 0
    return out, np.linspace(start, ep, len(out))
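
A hypothetical usage sketch of these helpers, assuming they are in scope, that the Trainer is configured with a CSVLogger, and that the LightningModule logs a metric named `train_loss` (the metric name is an assumption):

```python
import matplotlib.pyplot as plt
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

logger = CSVLogger(save_dir="logs", name="few_shot")
trainer = Trainer(logger=logger, max_epochs=10)
# trainer.fit(module, datamodule)  # metrics are only populated after training

# logged values of the metric and the (fractional) epochs they correspond to
values, epochs = get_loss_steps("train_loss", trainer)
plt.plot(epochs, values)
plt.savefig(get_save_dir_from_csvlogger(logger) / "train_loss.png")
```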
@@ -0,0 +1,42 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from pytorch_lightning import Callback, Trainer, LightningModule


class ParameterCountCallback(Callback): # type: ignore
"""
This callback allows counting model parameters during training.
The output is printed to the console and can be retrieved from the log files.
"""

def __init__(self) -> None:
super().__init__()

def on_pretrain_routine_start(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
# compute number of model params
model_total_params = sum(
p.numel() for p in pl_module.model.parameters()
)
model_total_trainable_params = sum(
p.numel() for p in pl_module.model.parameters() if p.requires_grad
)

# log
print("\n" + f"model_total_params: {model_total_params},")
print(
"\n"
+ f"model_total_trainable_params: {model_total_trainable_params},"
)
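
A hypothetical usage sketch: the callback is attached to the Trainer like any other PyTorch Lightning callback. Note that it expects the LightningModule to expose its network as `pl_module.model`; `module` and `datamodule` are placeholders for the project's LightningModule and data module.

```python
from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[ParameterCountCallback()], max_epochs=5)
# trainer.fit(module, datamodule)  # parameter counts are printed before training starts
```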
@@ -0,0 +1,78 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Literal, Optional
from pytorch_lightning import Callback, Trainer
import torch
from gluonts.time_feature import get_seasonality

from meta.models.module import MetaLightningModule
from meta.common.torch import tensor_to_np
from meta.metrics.numpy import compute_metrics


class QuantileMetricLoggerCallback(Callback):
"""
A callback that computes additional metrics on a numpy representation of the dataset every n epochs.
The computed values are logged to the output file of the pytorch lightning logger.

Args:
quantiles: The quantiles that are predicted.
        split: Specifies the split the batch comes from (i.e. the training or validation split).
        every_n_epochs: Specifies how often the metrics are computed.
            Setting this to a larger value can save time since computing the metrics
            can be time consuming.
"""

def __init__(
self,
quantiles: List[str],
        split: Optional[Literal["train", "val"]] = None,
every_n_epochs: int = 1,
):
super().__init__()
self.quantiles = quantiles
self.split = split
self.every_n_epochs = every_n_epochs

def on_validation_epoch_end(
self, trainer: Trainer, pl_module: MetaLightningModule
) -> None:
if trainer.current_epoch % self.every_n_epochs:
return
dm_super = trainer.lightning_module.trainer.datamodule
dm_val = dm_super.data_modules_val[0]
dl = dm_val.val_dataloader()
split = dm_val.splits.val()
pred = []
for batch in dl:
batch = batch.to(pl_module.device)
pred.append(
pl_module.model(
supps=batch.support_set, query=batch.query_past
)
)

# redo standardization for evaluation
pred = split.data().rescale_dataset(torch.cat(pred, dim=0).cpu())
pred = tensor_to_np(pred)

# use only the length that should be included for evaluation
pred = pred[:, : dm_val.prediction_length, ...]
m = compute_metrics(
pred,
split.evaluation(),
quantiles=self.quantiles,
seasonality=get_seasonality(dm_val.meta.freq),
)
        # `Callback` has no `self.log`; log through the LightningModule instead.
        pl_module.log("metrics", m)
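
A hypothetical usage sketch matching the constructor above; the quantile levels and the `every_n_epochs` value are illustrative, and `module` / `datamodule` stand in for the project's LightningModule and data module.

```python
from pytorch_lightning import Trainer

metric_callback = QuantileMetricLoggerCallback(
    quantiles=["0.1", "0.5", "0.9"],
    every_n_epochs=5,
)
trainer = Trainer(callbacks=[metric_callback], max_epochs=100)
# trainer.fit(module, datamodule)  # metrics are recomputed every 5th validation epoch
```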