-
Notifications
You must be signed in to change notification settings - Fork 33.6k
[time series] Add Time series inputs tests #21846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| import unittest | ||
|
|
||
| from huggingface_hub import hf_hub_download | ||
| from parameterized import parameterized | ||
|
|
||
| from transformers import is_torch_available | ||
| from transformers.testing_utils import is_flaky, require_torch, slow, torch_device | ||
|
|
@@ -366,6 +367,90 @@ def test_attention_outputs(self): | |
| [self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length], | ||
| ) | ||
|
|
||
| @parameterized.expand( | ||
| [ | ||
| (1, 5, [1]), | ||
| (1, 5, [1, 10, 15]), | ||
| (1, 5, [3, 6, 9, 10]), | ||
| (2, 5, [1, 2, 7]), | ||
| (2, 5, [2, 3, 4, 6]), | ||
| (4, 5, [1, 5, 9, 11]), | ||
| (4, 5, [7, 8, 13, 14]), | ||
| ], | ||
| ) | ||
| def test_create_network_inputs(self, prediction_length, context_length, lags_sequence): | ||
| history_length = max(lags_sequence) + context_length | ||
|
|
||
| config = TimeSeriesTransformerConfig( | ||
| prediction_length=prediction_length, | ||
| context_length=context_length, | ||
| lags_sequence=lags_sequence, | ||
| scaling=False, | ||
| num_parallel_samples=10, | ||
| num_static_categorical_features=1, | ||
| cardinality=[1], | ||
| embedding_dimension=[2], | ||
| num_static_real_features=1, | ||
| ) | ||
| model = TimeSeriesTransformerModel(config) | ||
|
|
||
| batch = { | ||
| "static_categorical_features": torch.tensor([[0]], dtype=torch.int64), | ||
| "static_real_features": torch.tensor([[0.0]], dtype=torch.float32), | ||
| "past_time_features": torch.arange(history_length, dtype=torch.float32).view(1, history_length, 1), | ||
| "past_values": torch.arange(history_length, dtype=torch.float32).view(1, history_length), | ||
| "past_observed_mask": torch.arange(history_length, dtype=torch.float32).view(1, history_length), | ||
| } | ||
|
|
||
| # test with no future_target (only one step prediction) | ||
| batch["future_time_features"] = torch.arange(history_length, history_length + 1, dtype=torch.float32).view( | ||
| 1, 1, 1 | ||
| ) | ||
| transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch) | ||
|
|
||
| assert (scale == 1.0).all() | ||
| assert (loc == 0.0).all() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please update all assert statements, here and below, to use self.assertTrue or another function from unittest
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah sorry! missed that doing that now
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I may have not been clear enough in my comment before. WE ABSOLUTELY DON'T CARE AND USE BOTHS IN THE TEST CODEBASE.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I didn't see the previous comment as it was remarked as resolved. Ok for me! |
||
|
|
||
| ref = torch.arange(max(lags_sequence), history_length, dtype=torch.float32) | ||
|
|
||
| for idx, lag in enumerate(lags_sequence): | ||
| assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all() | ||
|
|
||
| # test with all future data | ||
| batch["future_time_features"] = torch.arange( | ||
| history_length, history_length + prediction_length, dtype=torch.float32 | ||
| ).view(1, prediction_length, 1) | ||
| batch["future_values"] = torch.arange( | ||
| history_length, history_length + prediction_length, dtype=torch.float32 | ||
| ).view(1, prediction_length) | ||
| transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch) | ||
|
|
||
| assert (scale == 1.0).all() | ||
| assert (loc == 0.0).all() | ||
|
|
||
| ref = torch.arange(max(lags_sequence), history_length + prediction_length, dtype=torch.float32) | ||
|
|
||
| for idx, lag in enumerate(lags_sequence): | ||
| assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all() | ||
|
|
||
| # test for generation | ||
| batch.pop("future_values") | ||
| transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch) | ||
|
|
||
| lagged_sequence = model.get_lagged_subsequences( | ||
| sequence=batch["past_values"], | ||
| subsequences_length=1, | ||
| shift=1, | ||
| ) | ||
| # assert that the last element of the lagged sequence is the one after the encoders input | ||
| assert transformer_inputs[0, ..., 0][-1] + 1 == lagged_sequence[0, ..., 0][-1] | ||
|
|
||
| future_values = torch.arange(history_length, history_length + prediction_length, dtype=torch.float32).view( | ||
| 1, prediction_length | ||
| ) | ||
| # assert that the first element of the future_values is offset by lag after the decoders input | ||
| assert lagged_sequence[0, ..., 0][-1] + lags_sequence[0] == future_values[0, ..., 0] | ||
|
|
||
| @is_flaky() | ||
| def test_retain_grad_hidden_states_attentions(self): | ||
| super().test_retain_grad_hidden_states_attentions() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.