Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1589,14 +1589,11 @@ def get_lagged_subsequences(
sequence_length = sequence.shape[1]
indices = [lag - shift for lag in self.config.lags_sequence]

try:
assert max(indices) + subsequences_length <= sequence_length, (
if max(indices) + subsequences_length > sequence_length:
raise ValueError(
f"lags cannot go further than history length, found lag {max(indices)} "
f"while history length is only {sequence_length}"
)
except AssertionError as e:
e.args += (max(indices), sequence_length)
raise

lagged_values = []
for lag_index in indices:
Expand Down Expand Up @@ -1642,23 +1639,6 @@ def create_network_inputs(
else (past_values - loc) / scale
)

inputs_length = (
self._past_length + self.config.prediction_length if future_values is not None else self._past_length
)
try:
assert inputs.shape[1] == inputs_length, (
f"input length {inputs.shape[1]} and dynamic feature lengths {inputs_length} does not match",
)
except AssertionError as e:
e.args += (inputs.shape[1], inputs_length)
raise

subsequences_length = (
self.config.context_length + self.config.prediction_length
if future_values is not None
else self.config.context_length
)

# static features
log_abs_loc = loc.abs().log1p() if self.config.input_size == 1 else loc.squeeze(1).abs().log1p()
log_scale = scale.log() if self.config.input_size == 1 else scale.squeeze(1).log()
Expand All @@ -1675,10 +1655,21 @@ def create_network_inputs(
features = torch.cat((expanded_static_feat, time_feat), dim=-1)

# lagged features
subsequences_length = (
self.config.context_length + self.config.prediction_length
if future_values is not None
else self.config.context_length
)
lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length)
lags_shape = lagged_sequence.shape
reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)

if reshaped_lagged_sequence.shape[1] != time_feat.shape[1]:
raise ValueError(
f"input length {reshaped_lagged_sequence.shape[1]} and time feature lengths {time_feat.shape[1]} does not match"
)

# transformer inputs
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

return transformer_inputs, loc, scale, static_feat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest

from huggingface_hub import hf_hub_download
from parameterized import parameterized

from transformers import is_torch_available
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
Expand Down Expand Up @@ -366,6 +367,90 @@ def test_attention_outputs(self):
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_seq_length],
)

@parameterized.expand(
[
(1, 5, [1]),
(1, 5, [1, 10, 15]),
(1, 5, [3, 6, 9, 10]),
(2, 5, [1, 2, 7]),
(2, 5, [2, 3, 4, 6]),
(4, 5, [1, 5, 9, 11]),
(4, 5, [7, 8, 13, 14]),
],
)
def test_create_network_inputs(self, prediction_length, context_length, lags_sequence):
history_length = max(lags_sequence) + context_length

config = TimeSeriesTransformerConfig(
prediction_length=prediction_length,
context_length=context_length,
lags_sequence=lags_sequence,
scaling=False,
num_parallel_samples=10,
num_static_categorical_features=1,
cardinality=[1],
embedding_dimension=[2],
num_static_real_features=1,
)
model = TimeSeriesTransformerModel(config)

batch = {
"static_categorical_features": torch.tensor([[0]], dtype=torch.int64),
"static_real_features": torch.tensor([[0.0]], dtype=torch.float32),
"past_time_features": torch.arange(history_length, dtype=torch.float32).view(1, history_length, 1),
"past_values": torch.arange(history_length, dtype=torch.float32).view(1, history_length),
"past_observed_mask": torch.arange(history_length, dtype=torch.float32).view(1, history_length),
}

# test with no future_target (only one step prediction)
batch["future_time_features"] = torch.arange(history_length, history_length + 1, dtype=torch.float32).view(
1, 1, 1
)
transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)

self.assertTrue((scale == 1.0).all())
assert (loc == 0.0).all()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update all assert statements, here and below, to use self.assertTrue or another function from unittest

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah sorry! missed that doing that now

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may have not been clear enough in my comment before. WE ABSOLUTELY DON'T CARE AND USE BOTHS IN THE TEST CODEBASE.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't see the previous comment as it was remarked as resolved. Ok for me!


ref = torch.arange(max(lags_sequence), history_length, dtype=torch.float32)

for idx, lag in enumerate(lags_sequence):
assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all()

# test with all future data
batch["future_time_features"] = torch.arange(
history_length, history_length + prediction_length, dtype=torch.float32
).view(1, prediction_length, 1)
batch["future_values"] = torch.arange(
history_length, history_length + prediction_length, dtype=torch.float32
).view(1, prediction_length)
transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)

assert (scale == 1.0).all()
assert (loc == 0.0).all()

ref = torch.arange(max(lags_sequence), history_length + prediction_length, dtype=torch.float32)

for idx, lag in enumerate(lags_sequence):
assert torch.isclose(ref - lag, transformer_inputs[0, :, idx]).all()

# test for generation
batch.pop("future_values")
transformer_inputs, loc, scale, _ = model.create_network_inputs(**batch)

lagged_sequence = model.get_lagged_subsequences(
sequence=batch["past_values"],
subsequences_length=1,
shift=1,
)
# assert that the last element of the lagged sequence is the one after the encoders input
assert transformer_inputs[0, ..., 0][-1] + 1 == lagged_sequence[0, ..., 0][-1]

future_values = torch.arange(history_length, history_length + prediction_length, dtype=torch.float32).view(
1, prediction_length
)
# assert that the first element of the future_values is offset by lag after the decoders input
assert lagged_sequence[0, ..., 0][-1] + lags_sequence[0] == future_values[0, ..., 0]

@is_flaky()
def test_retain_grad_hidden_states_attentions(self):
super().test_retain_grad_hidden_states_attentions()
Expand Down