Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions applications/portfoliomanager/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@ description = "Portfolio prediction and construction service"
requires-python = "==3.12.10"
dependencies = ["internal"]

[tool.uv]
package = true
src = ["src"]

[tool.uv.sources]
internal = { workspace = true }
21 changes: 21 additions & 0 deletions applications/portfoliomanager/src/portfoliomanager/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import polars as pl


def filter_equity_bars(
    data: pl.DataFrame,
    minimum_average_close_price: float = 10.0,
    minimum_average_volume: float = 1_000_000.0,
) -> pl.DataFrame:
    """Keep only tickers whose mean close price AND mean volume strictly
    exceed the given thresholds.

    Args:
        data: Equity bars with at least ``ticker``, ``close_price`` and
            ``volume`` columns.
        minimum_average_close_price: Exclusive lower bound on the per-ticker
            mean of ``close_price``.
        minimum_average_volume: Exclusive lower bound on the per-ticker mean
            of ``volume``.

    Returns:
        One row per surviving ticker with ``avg_close_price`` and
        ``avg_volume`` aggregate columns. Tickers whose averages sit exactly
        at a threshold are dropped (strict ``>`` comparison).
    """
    # NOTE: polars expressions never mutate their input frame, so the
    # defensive `data = data.clone()` the original carried was dead work
    # and has been removed; the caller's frame is untouched either way.
    return (
        data.group_by("ticker")
        .agg(
            avg_close_price=pl.col("close_price").mean(),
            avg_volume=pl.col("volume").mean(),
        )
        .filter(
            (pl.col("avg_close_price") > minimum_average_close_price)
            & (pl.col("avg_volume") > minimum_average_volume)
        )
    )
214 changes: 214 additions & 0 deletions applications/portfoliomanager/tests/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import polars as pl
import pytest
from portfoliomanager.preprocess import filter_equity_bars


def test_filter_equity_bars_above_thresholds() -> None:
    """A ticker clearing both default cutoffs is kept with its aggregates."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3,
            "close_price": [15.0, 20.0, 25.0],
            "volume": [1_500_000.0, 2_000_000.0, 2_500_000.0],
        }
    )

    filtered = filter_equity_bars(bars)

    expected_avg_price = 20.0
    expected_avg_volume = 2_000_000.0
    assert filtered.height == 1
    assert filtered["ticker"][0] == "AAPL"
    assert filtered["avg_close_price"][0] == expected_avg_price
    assert filtered["avg_volume"][0] == expected_avg_volume


def test_filter_equity_bars_below_price_threshold() -> None:
    """High volume alone is not enough: a sub-threshold mean price drops the ticker."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3,
            "close_price": [5.0, 8.0, 9.0],
            "volume": [1_500_000.0, 2_000_000.0, 2_500_000.0],
        }
    )

    assert filter_equity_bars(bars).height == 0


def test_filter_equity_bars_below_volume_threshold() -> None:
    """High price alone is not enough: a sub-threshold mean volume drops the ticker."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3,
            "close_price": [15.0, 20.0, 25.0],
            "volume": [500_000.0, 600_000.0, 700_000.0],
        }
    )

    assert filter_equity_bars(bars).height == 0


def test_filter_equity_bars_below_both_thresholds() -> None:
    """A ticker failing both cutoffs is dropped."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3,
            "close_price": [5.0, 6.0, 7.0],
            "volume": [500_000.0, 600_000.0, 700_000.0],
        }
    )

    assert filter_equity_bars(bars).height == 0


def test_filter_equity_bars_at_exact_thresholds() -> None:
    """Averages exactly equal to the cutoffs are excluded — the comparison is strict."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3,
            "close_price": [10.0] * 3,
            "volume": [1_000_000.0] * 3,
        }
    )

    assert filter_equity_bars(bars).height == 0


def test_filter_equity_bars_just_above_thresholds() -> None:
    """Averages marginally over the cutoffs are enough to keep the ticker."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3,
            "close_price": [10.01] * 3,
            "volume": [1_000_001.0] * 3,
        }
    )

    filtered = filter_equity_bars(bars)

    assert filtered.height == 1
    assert filtered["ticker"][0] == "AAPL"
    assert filtered["avg_close_price"][0] == pytest.approx(10.01)
    assert filtered["avg_volume"][0] == pytest.approx(1_000_001.0)


def test_filter_equity_bars_empty_dataframe() -> None:
    """An input with zero rows yields zero rows, not an error."""
    no_bars = pl.DataFrame(
        {
            "ticker": [],
            "close_price": [],
            "volume": [],
        }
    )

    assert filter_equity_bars(no_bars).height == 0


def test_filter_equity_bars_single_row() -> None:
    """With one bar, the averages are that bar's own values."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"],
            "close_price": [15.0],
            "volume": [1_500_000.0],
        }
    )

    filtered = filter_equity_bars(bars)

    only_price = 15.0
    only_volume = 1_500_000.0
    assert filtered.height == 1
    assert filtered["ticker"][0] == "AAPL"
    assert filtered["avg_close_price"][0] == only_price
    assert filtered["avg_volume"][0] == only_volume


def test_filter_equity_bars_mixed_values() -> None:
    """The filter acts on per-ticker averages: (5+25)/2 and (0.5M+1.5M)/2 sit exactly at the cutoffs, so the ticker is dropped."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 2,
            "close_price": [5.0, 25.0],
            "volume": [500_000.0, 1_500_000.0],
        }
    )

    assert filter_equity_bars(bars).height == 0


def test_filter_equity_bars_multiple_tickers() -> None:
    """Only the ticker passing both cutoffs survives: GOOGL fails on price, TSLA on volume."""
    bars = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3 + ["GOOGL"] * 2 + ["TSLA"] * 2,
            "close_price": [15.0, 20.0, 25.0, 5.0, 6.0, 12.0, 18.0],
            "volume": [
                1_500_000.0,
                2_000_000.0,
                2_500_000.0,
                2_000_000.0,
                3_000_000.0,
                800_000.0,
                900_000.0,
            ],
        }
    )

    filtered = filter_equity_bars(bars)

    expected_avg_price = 20.0
    expected_avg_volume = 2_000_000.0
    assert filtered.height == 1
    assert filtered["ticker"][0] == "AAPL"
    assert filtered["avg_close_price"][0] == expected_avg_price
    assert filtered["avg_volume"][0] == expected_avg_volume


def test_filter_equity_bars_data_immutability() -> None:
    """Calling the filter must leave the caller's DataFrame untouched."""
    source = pl.DataFrame(
        {
            "ticker": ["AAPL"] * 3,
            "close_price": [15.0, 20.0, 25.0],
            "volume": [1_500_000.0, 2_000_000.0, 2_500_000.0],
        }
    )
    # Snapshot every column's values before the call.
    snapshot = {
        column: source[column].to_list()
        for column in ("ticker", "close_price", "volume")
    }

    filter_equity_bars(source)

    for column, values in snapshot.items():
        assert source[column].to_list() == values
85 changes: 72 additions & 13 deletions libraries/python/src/internal/tft_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import TYPE_CHECKING
from datetime import date

import pandera.polars as pa
import polars as pl
from tinygrad.tensor import Tensor

if TYPE_CHECKING:
from datetime import date


class Scaler:
def __init__(self) -> None:
Expand All @@ -29,6 +27,8 @@ class TFTDataset:
"""Temporal fusion transformer dataset."""

def __init__(self, data: pl.DataFrame) -> None:
data = data.clone()

raw_columns = (
"ticker",
"timestamp",
Expand Down Expand Up @@ -137,29 +137,33 @@ def __init__(self, data: pl.DataFrame) -> None:
pl.col("timestamp").fill_null(
pl.col("date")
.cast(pl.Datetime)
.dt.replace_time_zone("America/New_York")
.dt.replace_time_zone("UTC")
.cast(pl.Int64)
.floordiv(1000)
),
]
)

data = data.with_columns( # compute new columns
pl.col("date").dt.weekday().alias("day_of_week"),
pl.col("date").dt.day().alias("day_of_month"),
pl.col("date").dt.ordinal_day().alias("day_of_year"),
pl.col("date").dt.month().alias("month"),
pl.col("date").dt.year().alias("year"),
# compute new calendar columns
data = data.with_columns(
pl.col("date").dt.weekday().alias("day_of_week").cast(pl.Int64),
pl.col("date").dt.day().alias("day_of_month").cast(pl.Int64),
pl.col("date").dt.ordinal_day().alias("day_of_year").cast(pl.Int64),
pl.col("date").dt.month().alias("month").cast(pl.Int64),
pl.col("date").dt.year().alias("year").cast(pl.Int64),
)

data = data.sort(["ticker", "timestamp"]).with_columns( # add time index column
# add time index column
data = data.sort(["ticker", "timestamp"]).with_columns(
pl.col("timestamp")
.rank("dense")
.over("ticker")
.cast(pl.Int32)
.cast(pl.Int64)
.alias("time_idx")
)

data = dataset_schema.validate(data)

self.scaler = Scaler()

self.scaler.fit(data[self.continuous_columns])
Expand Down Expand Up @@ -321,3 +325,58 @@ def get_batches(
batches.append(batch)

return batches


# Pandera validation schema for the frame consumed by TFTDataset.__init__.
# Every column is required and coerced to the declared dtype; value checks
# reject negative prices/volumes and non-positive timestamps.
dataset_schema = pa.DataFrameSchema(
    {
        # Uppercase ticker symbols; dots and hyphens allowed (e.g. BRK.B).
        "ticker": pa.Column(
            str,
            checks=pa.Check.str_matches(r"^[A-Z0-9.\-]+$"),
            coerce=True,
            required=True,
        ),
        # Integer epoch timestamp; must be strictly positive.
        "timestamp": pa.Column(
            int,
            checks=pa.Check.gt(0),
            coerce=True,
            required=True,
        ),
        # OHLC prices: non-negative floats.
        "open_price": pa.Column(
            float,
            checks=pa.Check.ge(0),
            coerce=True,
            required=True,
        ),
        "high_price": pa.Column(
            float, checks=pa.Check.ge(0), coerce=True, required=True
        ),
        "low_price": pa.Column(
            float, checks=pa.Check.ge(0), coerce=True, required=True
        ),
        "close_price": pa.Column(
            float, checks=pa.Check.ge(0), coerce=True, required=True
        ),
        # Share volume: non-negative integer.
        "volume": pa.Column(
            int,
            checks=pa.Check.ge(0),
            coerce=True,
            required=True,
        ),
        "volume_weighted_average_price": pa.Column(
            float,
            checks=pa.Check.ge(0),
            coerce=True,
            required=True,
        ),
        # Categorical metadata columns.
        "sector": pa.Column(str, coerce=True, required=True),
        "industry": pa.Column(str, coerce=True, required=True),
        # `date` here is datetime.date (imported at module top).
        "date": pa.Column(date, coerce=True, required=True),
        # Derived calendar features computed in TFTDataset.__init__.
        "day_of_week": pa.Column(int, coerce=True, required=True),
        "day_of_month": pa.Column(int, coerce=True, required=True),
        "day_of_year": pa.Column(int, coerce=True, required=True),
        "month": pa.Column(int, coerce=True, required=True),
        "year": pa.Column(int, coerce=True, required=True),
        "is_holiday": pa.Column(bool, coerce=True, required=True),
        # Dense per-ticker rank of timestamp, added before validation.
        "time_idx": pa.Column(int, coerce=True, required=True),
    }
)
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.