Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions applications/models/src/models/train_tft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import wandb
from wandb import Run

from internal.dataset import TemporalFusionTransformerDataset
from internal.tft_model import Parameters, TemporalFusionTransformer
from internal.tft_dataset import TFTDataset
from internal.tft_model import Parameters, TFTModel

configuration = {
"architecture": "TFT",
Expand All @@ -21,7 +21,7 @@


@task
def read_local_data(filepath: str) -> TemporalFusionTransformerDataset:
def read_local_data(filepath: str) -> TFTDataset:
start_time = datetime.now(tz=timezone)
logger.info(f"Reading data from {filepath}")
data = pl.read_csv(filepath)
Expand All @@ -30,17 +30,17 @@ def read_local_data(filepath: str) -> TemporalFusionTransformerDataset:

logger.info(f"Data read successfully in {runtime_seconds} seconds")

return TemporalFusionTransformerDataset(data=data)
return TFTDataset(data=data)


@task
def train_model(
dataset: TemporalFusionTransformerDataset,
dataset: TFTDataset,
wandb_run: Run,
validation_split: float = 0.8,
epoch_count: int = 10,
learning_rate: float = 1e-3,
) -> TemporalFusionTransformer:
) -> TFTModel:
start_time = datetime.now(tz=timezone)
logger.info("Training temporal fusion transformer model")
dimensions = dataset.get_dimensions()
Expand All @@ -62,7 +62,7 @@ def train_model(
output_length=7,
)

model = TemporalFusionTransformer(parameters=parameters)
model = TFTModel(parameters=parameters)

batches = dataset.get_batches(
data_type="train",
Expand Down Expand Up @@ -93,8 +93,8 @@ def train_model(

@task
def validate_model(
data: TemporalFusionTransformerDataset,
model: TemporalFusionTransformer,
data: TFTDataset,
model: TFTModel,
validation_split: float = 0.8,
) -> None:
start_time = datetime.now(tz=timezone)
Expand All @@ -119,7 +119,7 @@ def validate_model(


@task
def save_model(model: TemporalFusionTransformer) -> None:
def save_model(model: TFTModel) -> None:
start_time = datetime.now(tz=timezone)
logger.info("Saving temporal fusion transformer model")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def inverse_transform(self, data: pl.DataFrame) -> pl.DataFrame:
return data * self.standard_deviations + self.means


class TemporalFusionTransformerDataset:
class TFTDataset:
"""Temporal fusion transformer dataset."""

Comment thread
forstmeier marked this conversation as resolved.
def __init__(self, data: pl.DataFrame) -> None:
raw_columns = (
"ticker",
Expand Down
2 changes: 1 addition & 1 deletion libraries/python/src/internal/tft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Parameters(BaseModel):


# https://arxiv.org/pdf/1912.09363
class TemporalFusionTransformer:
class TFTModel:
def __init__(self, parameters: Parameters) -> None:
self.hidden_size = parameters.hidden_size
self.batch_size = parameters.input_length
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import polars as pl
from internal.dataset import TemporalFusionTransformerDataset
from internal.tft_dataset import TFTDataset


def test_dataset_load_data() -> None:
def test_tft_dataset_load_data() -> None:
data = pl.DataFrame(
{
"timestamp": [
Expand Down Expand Up @@ -39,13 +39,13 @@ def test_dataset_load_data() -> None:
}
)

dataset = TemporalFusionTransformerDataset(data=data)
dataset = TFTDataset(data=data)

assert hasattr(dataset, "data")
assert hasattr(dataset, "mappings")


def test_dataset_get_dimensions() -> None:
def test_tft_dataset_get_dimensions() -> None:
data = pl.DataFrame(
{
"timestamp": [
Expand All @@ -64,7 +64,7 @@ def test_dataset_get_dimensions() -> None:
}
)

dataset = TemporalFusionTransformerDataset(data=data)
dataset = TFTDataset(data=data)

dimensions = dataset.get_dimensions()

Expand All @@ -76,7 +76,7 @@ def test_dataset_get_dimensions() -> None:
assert "static_continuous_features" in dimensions


def test_dataset_batches() -> None:
def test_tft_dataset_batches() -> None:
data = pl.DataFrame(
{
"timestamp": [
Expand All @@ -100,7 +100,7 @@ def test_dataset_batches() -> None:
}
)

dataset = TemporalFusionTransformerDataset(data=data)
dataset = TFTDataset(data=data)

expected_input_length = 2
expected_output_length = 1
Expand Down