Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions applications/portfoliomanager/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@ description = "Portfolio prediction and construction service"
requires-python = "==3.12.10"
dependencies = ["internal"]

[tool.uv]
package = true
src = ["src"]

[tool.uv.sources]
internal = { workspace = true }
20 changes: 20 additions & 0 deletions applications/portfoliomanager/src/portfoliomanager/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import polars as pl


def filter_equity_bars(data: pl.DataFrame) -> pl.DataFrame:
Comment thread
forstmeier marked this conversation as resolved.
Outdated
data = data.clone()

minimum_average_close_price = 10.0
minimum_average_volume = 1_000_000.0

Comment thread
forstmeier marked this conversation as resolved.
Outdated
return (
data.group_by("ticker")
.agg(
avg_close_price=pl.col("close_price").mean(),
avg_volume=pl.col("volume").mean(),
)
.filter(
(pl.col("avg_close_price") > minimum_average_close_price)
& (pl.col("avg_volume") > minimum_average_volume)
)
)
214 changes: 214 additions & 0 deletions applications/portfoliomanager/tests/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import polars as pl
import pytest
from portfoliomanager.preprocess import filter_equity_bars


def test_filter_equity_bars_above_thresholds() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL"],
"close_price": [15.0, 20.0, 25.0],
"volume": [
1_500_000.0,
2_000_000.0,
2_500_000.0,
],
}
)

result = filter_equity_bars(data)

assert len(result) == 1
assert result["ticker"][0] == "AAPL"
assert result["avg_close_price"][0] == 20.0 # noqa: PLR2004
assert result["avg_volume"][0] == 2_000_000.0 # noqa: PLR2004


def test_filter_equity_bars_below_price_threshold() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL"],
"close_price": [5.0, 8.0, 9.0],
"volume": [
1_500_000.0,
2_000_000.0,
2_500_000.0,
],
}
)

result = filter_equity_bars(data)

assert len(result) == 0


def test_filter_equity_bars_below_volume_threshold() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL"],
"close_price": [15.0, 20.0, 25.0],
"volume": [500_000.0, 600_000.0, 700_000.0],
}
)

result = filter_equity_bars(data)

assert len(result) == 0


def test_filter_equity_bars_below_both_thresholds() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL"],
"close_price": [5.0, 6.0, 7.0],
"volume": [500_000.0, 600_000.0, 700_000.0],
}
)

result = filter_equity_bars(data)

assert len(result) == 0


def test_filter_equity_bars_at_exact_thresholds() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL"],
"close_price": [
10.0,
10.0,
10.0,
],
"volume": [
1_000_000.0,
1_000_000.0,
1_000_000.0,
],
}
)

result = filter_equity_bars(data)

assert len(result) == 0


def test_filter_equity_bars_just_above_thresholds() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL"],
"close_price": [10.01, 10.01, 10.01],
"volume": [
1_000_001.0,
1_000_001.0,
1_000_001.0,
],
}
)

result = filter_equity_bars(data)

assert len(result) == 1
assert result["ticker"][0] == "AAPL"
assert result["avg_close_price"][0] == pytest.approx(10.01)
assert result["avg_volume"][0] == pytest.approx(1_000_001.0)


def test_filter_equity_bars_empty_dataframe() -> None:
data = pl.DataFrame(
{
"ticker": [],
"close_price": [],
"volume": [],
}
)

result = filter_equity_bars(data)

assert len(result) == 0


def test_filter_equity_bars_single_row() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL"],
"close_price": [15.0],
"volume": [1_500_000.0],
}
)

result = filter_equity_bars(data)

assert len(result) == 1
assert result["ticker"][0] == "AAPL"
assert result["avg_close_price"][0] == 15.0 # noqa: PLR2004
assert result["avg_volume"][0] == 1_500_000.0 # noqa: PLR2004


def test_filter_equity_bars_mixed_values() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL"],
"close_price": [5.0, 25.0],
"volume": [
500_000.0,
1_500_000.0,
],
}
)

result = filter_equity_bars(data)

assert len(result) == 0


def test_filter_equity_bars_multiple_tickers() -> None:
data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL", "GOOGL", "GOOGL", "TSLA", "TSLA"],
"close_price": [
15.0,
20.0,
25.0,
5.0,
6.0,
12.0,
18.0,
],
"volume": [
1_500_000.0,
2_000_000.0,
2_500_000.0,
2_000_000.0,
3_000_000.0,
800_000.0,
900_000.0,
],
}
)

result = filter_equity_bars(data)

assert len(result) == 1
assert result["ticker"][0] == "AAPL"
assert result["avg_close_price"][0] == 20.0 # noqa: PLR2004
assert result["avg_volume"][0] == 2_000_000.0 # noqa: PLR2004


def test_filter_equity_bars_data_immutability() -> None:
original_data = pl.DataFrame(
{
"ticker": ["AAPL", "AAPL", "AAPL"],
"close_price": [15.0, 20.0, 25.0],
"volume": [1_500_000.0, 2_000_000.0, 2_500_000.0],
}
)

original_tickers = original_data["ticker"].to_list()
original_close_prices = original_data["close_price"].to_list()
original_volumes = original_data["volume"].to_list()

filter_equity_bars(original_data)

assert original_data["ticker"].to_list() == original_tickers
assert original_data["close_price"].to_list() == original_close_prices
assert original_data["volume"].to_list() == original_volumes
54 changes: 41 additions & 13 deletions libraries/python/src/internal/tft_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import TYPE_CHECKING
from datetime import date

import pandera.polars as pa
import polars as pl
from tinygrad.tensor import Tensor

if TYPE_CHECKING:
from datetime import date


class Scaler:
def __init__(self) -> None:
Expand All @@ -29,6 +27,8 @@ class TFTDataset:
"""Temporal fusion transformer dataset."""

def __init__(self, data: pl.DataFrame) -> None:
data = data.clone()

raw_columns = (
"ticker",
"timestamp",
Expand Down Expand Up @@ -137,29 +137,33 @@ def __init__(self, data: pl.DataFrame) -> None:
pl.col("timestamp").fill_null(
pl.col("date")
.cast(pl.Datetime)
.dt.replace_time_zone("America/New_York")
.dt.replace_time_zone("UTC")
.cast(pl.Int64)
.floordiv(1000)
),
]
)

data = data.with_columns( # compute new columns
pl.col("date").dt.weekday().alias("day_of_week"),
pl.col("date").dt.day().alias("day_of_month"),
pl.col("date").dt.ordinal_day().alias("day_of_year"),
pl.col("date").dt.month().alias("month"),
pl.col("date").dt.year().alias("year"),
# compute new calendar columns
data = data.with_columns(
pl.col("date").dt.weekday().alias("day_of_week").cast(pl.Int64),
pl.col("date").dt.day().alias("day_of_month").cast(pl.Int64),
pl.col("date").dt.ordinal_day().alias("day_of_year").cast(pl.Int64),
pl.col("date").dt.month().alias("month").cast(pl.Int64),
pl.col("date").dt.year().alias("year").cast(pl.Int64),
)

data = data.sort(["ticker", "timestamp"]).with_columns( # add time index column
# add time index column
data = data.sort(["ticker", "timestamp"]).with_columns(
pl.col("timestamp")
.rank("dense")
.over("ticker")
.cast(pl.Int32)
.cast(pl.Int64)
.alias("time_idx")
)

data = equity_bar_schema.validate(data)

self.scaler = Scaler()

self.scaler.fit(data[self.continuous_columns])
Expand Down Expand Up @@ -321,3 +325,27 @@ def get_batches(
batches.append(batch)

return batches


equity_bar_schema = pa.DataFrameSchema(
{
"ticker": pa.Column(str, required=True),
"timestamp": pa.Column(int, required=True),
"open_price": pa.Column(float, required=True),
"high_price": pa.Column(float, required=True),
"low_price": pa.Column(float, required=True),
"close_price": pa.Column(float, required=True),
"volume": pa.Column(float, required=True),
"volume_weighted_average_price": pa.Column(float, required=True),
"sector": pa.Column(str, required=True),
"industry": pa.Column(str, required=True),
"date": pa.Column(date, required=True),
"day_of_week": pa.Column(int, required=True),
"day_of_month": pa.Column(int, required=True),
"day_of_year": pa.Column(int, required=True),
"month": pa.Column(int, required=True),
"year": pa.Column(int, required=True),
"is_holiday": pa.Column(bool, required=True),
"time_idx": pa.Column(int, required=True),
}
)
Comment thread
forstmeier marked this conversation as resolved.
Outdated
Comment thread
forstmeier marked this conversation as resolved.
Outdated
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.