diff --git a/applications/models/src/models/train_tft_model.py b/applications/models/src/models/train_tft_model.py index 16152a8ff..cdaee01bf 100644 --- a/applications/models/src/models/train_tft_model.py +++ b/applications/models/src/models/train_tft_model.py @@ -7,8 +7,8 @@ import wandb from wandb import Run -from internal.dataset import TemporalFusionTransformerDataset -from internal.tft_model import Parameters, TemporalFusionTransformer +from internal.tft_dataset import TFTDataset +from internal.tft_model import Parameters, TFTModel configuration = { "architecture": "TFT", @@ -21,7 +21,7 @@ @task -def read_local_data(filepath: str) -> TemporalFusionTransformerDataset: +def read_local_data(filepath: str) -> TFTDataset: start_time = datetime.now(tz=timezone) logger.info(f"Reading data from {filepath}") data = pl.read_csv(filepath) @@ -30,17 +30,17 @@ def read_local_data(filepath: str) -> TemporalFusionTransformerDataset: logger.info(f"Data read successfully in {runtime_seconds} seconds") - return TemporalFusionTransformerDataset(data=data) + return TFTDataset(data=data) @task def train_model( - dataset: TemporalFusionTransformerDataset, + dataset: TFTDataset, wandb_run: Run, validation_split: float = 0.8, epoch_count: int = 10, learning_rate: float = 1e-3, -) -> TemporalFusionTransformer: +) -> TFTModel: start_time = datetime.now(tz=timezone) logger.info("Training temporal fusion transformer model") dimensions = dataset.get_dimensions() @@ -62,7 +62,7 @@ def train_model( output_length=7, ) - model = TemporalFusionTransformer(parameters=parameters) + model = TFTModel(parameters=parameters) batches = dataset.get_batches( data_type="train", @@ -93,8 +93,8 @@ def train_model( @task def validate_model( - data: TemporalFusionTransformerDataset, - model: TemporalFusionTransformer, + data: TFTDataset, + model: TFTModel, validation_split: float = 0.8, ) -> None: start_time = datetime.now(tz=timezone) @@ -119,7 +119,7 @@ def validate_model( @task -def save_model(model: TemporalFusionTransformer) -> None: +def save_model(model: TFTModel) -> None: start_time = datetime.now(tz=timezone) logger.info("Saving temporal fusion transformer model") diff --git a/libraries/python/src/internal/dataset.py b/libraries/python/src/internal/tft_dataset.py similarity index 99% rename from libraries/python/src/internal/dataset.py rename to libraries/python/src/internal/tft_dataset.py index 3016b252e..0ef2f61b4 100644 --- a/libraries/python/src/internal/dataset.py +++ b/libraries/python/src/internal/tft_dataset.py @@ -25,7 +25,9 @@ def inverse_transform(self, data: pl.DataFrame) -> pl.DataFrame: return data * self.standard_deviations + self.means -class TemporalFusionTransformerDataset: +class TFTDataset: + """Temporal fusion transformer dataset.""" + def __init__(self, data: pl.DataFrame) -> None: raw_columns = ( "ticker", diff --git a/libraries/python/src/internal/tft_model.py b/libraries/python/src/internal/tft_model.py index 47ad83657..7d30617c8 100644 --- a/libraries/python/src/internal/tft_model.py +++ b/libraries/python/src/internal/tft_model.py @@ -34,7 +34,7 @@ class Parameters(BaseModel): # https://arxiv.org/pdf/1912.09363 -class TemporalFusionTransformer: +class TFTModel: def __init__(self, parameters: Parameters) -> None: self.hidden_size = parameters.hidden_size self.batch_size = parameters.input_length diff --git a/libraries/python/tests/test_dataset.py b/libraries/python/tests/test_tft_dataset.py similarity index 92% rename from libraries/python/tests/test_dataset.py rename to libraries/python/tests/test_tft_dataset.py index 74dad37e0..4261889f8 100644 --- a/libraries/python/tests/test_dataset.py +++ b/libraries/python/tests/test_tft_dataset.py @@ -1,8 +1,8 @@ import polars as pl -from internal.dataset import TemporalFusionTransformerDataset +from internal.tft_dataset import TFTDataset -def test_dataset_load_data() -> None: +def test_tft_dataset_load_data() -> None: data = pl.DataFrame( { "timestamp": [ @@ -39,13 +39,13 @@ def test_dataset_load_data() -> None: } ) - dataset = TemporalFusionTransformerDataset(data=data) + dataset = TFTDataset(data=data) assert hasattr(dataset, "data") assert hasattr(dataset, "mappings") -def test_dataset_get_dimensions() -> None: +def test_tft_dataset_get_dimensions() -> None: data = pl.DataFrame( { "timestamp": [ @@ -64,7 +64,7 @@ def test_dataset_get_dimensions() -> None: } ) - dataset = TemporalFusionTransformerDataset(data=data) + dataset = TFTDataset(data=data) dimensions = dataset.get_dimensions() @@ -76,7 +76,7 @@ def test_dataset_get_dimensions() -> None: assert "static_continuous_features" in dimensions -def test_dataset_batches() -> None: +def test_tft_dataset_batches() -> None: data = pl.DataFrame( { "timestamp": [ @@ -100,7 +100,7 @@ def test_dataset_batches() -> None: } ) - dataset = TemporalFusionTransformerDataset(data=data) + dataset = TFTDataset(data=data) expected_input_length = 2 expected_output_length = 1