diff --git a/.flox/env/manifest.lock b/.flox/env/manifest.lock index 7106a875a..6bc8e78dc 100644 --- a/.flox/env/manifest.lock +++ b/.flox/env/manifest.lock @@ -27,6 +27,9 @@ "gum": { "pkg-path": "gum" }, + "mask": { + "pkg-path": "mask" + }, "mise": { "pkg-path": "mise" }, @@ -997,6 +1000,122 @@ "group": "toplevel", "priority": 5 }, + { + "attr_path": "mask", + "broken": false, + "derivation": "/nix/store/szjnp3963v7dswjzffggb0yqnrmqg3pa-mask-0.11.6.drv", + "description": "CLI task runner defined by a simple markdown file", + "install_id": "mask", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=c23193b943c6c689d70ee98ce3128239ed9e32d1", + "name": "mask-0.11.6", + "pname": "mask", + "rev": "c23193b943c6c689d70ee98ce3128239ed9e32d1", + "rev_count": 861038, + "rev_date": "2025-09-13T06:43:22Z", + "scrape_date": "2025-09-15T02:01:02.794961Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "0.11.6", + "outputs_to_install": [ + "out" + ], + "outputs": { + "out": "/nix/store/2g2qvzgjq5zw0l146va1p03mc39xy4va-mask-0.11.6" + }, + "system": "aarch64-darwin", + "group": "toplevel", + "priority": 5 + }, + { + "attr_path": "mask", + "broken": false, + "derivation": "/nix/store/a9zcvh6x9342bi0l5jyz1l8a74y0zdr7-mask-0.11.6.drv", + "description": "CLI task runner defined by a simple markdown file", + "install_id": "mask", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=c23193b943c6c689d70ee98ce3128239ed9e32d1", + "name": "mask-0.11.6", + "pname": "mask", + "rev": "c23193b943c6c689d70ee98ce3128239ed9e32d1", + "rev_count": 861038, + "rev_date": "2025-09-13T06:43:22Z", + "scrape_date": "2025-09-15T02:30:13.874798Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "0.11.6", + "outputs_to_install": [ + "out" + ], + "outputs": { + "out": "/nix/store/yr6j5mkqg8y7w1c9v8kv1crjy83vnpby-mask-0.11.6" + }, + "system": "aarch64-linux", + "group": "toplevel", + "priority": 5 + }, + { + "attr_path": 
"mask", + "broken": false, + "derivation": "/nix/store/8lhg3symrg5ggm6d8jpk0hj1c55cnsyy-mask-0.11.6.drv", + "description": "CLI task runner defined by a simple markdown file", + "install_id": "mask", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=c23193b943c6c689d70ee98ce3128239ed9e32d1", + "name": "mask-0.11.6", + "pname": "mask", + "rev": "c23193b943c6c689d70ee98ce3128239ed9e32d1", + "rev_count": 861038, + "rev_date": "2025-09-13T06:43:22Z", + "scrape_date": "2025-09-15T02:56:55.479919Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "0.11.6", + "outputs_to_install": [ + "out" + ], + "outputs": { + "out": "/nix/store/xdassl0ji7vs1zljimqa1n24y8myvw7p-mask-0.11.6" + }, + "system": "x86_64-darwin", + "group": "toplevel", + "priority": 5 + }, + { + "attr_path": "mask", + "broken": false, + "derivation": "/nix/store/72yafxkmwfh8gaan7nkzd3m9pqlv4rry-mask-0.11.6.drv", + "description": "CLI task runner defined by a simple markdown file", + "install_id": "mask", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=c23193b943c6c689d70ee98ce3128239ed9e32d1", + "name": "mask-0.11.6", + "pname": "mask", + "rev": "c23193b943c6c689d70ee98ce3128239ed9e32d1", + "rev_count": 861038, + "rev_date": "2025-09-13T06:43:22Z", + "scrape_date": "2025-09-15T03:27:32.404732Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "0.11.6", + "outputs_to_install": [ + "out" + ], + "outputs": { + "out": "/nix/store/8h9hdwpgzknxadrz6fwinix214dbvm4c-mask-0.11.6" + }, + "system": "x86_64-linux", + "group": "toplevel", + "priority": 5 + }, { "attr_path": "mise", "broken": false, diff --git a/.flox/env/manifest.toml b/.flox/env/manifest.toml index 1f34f03f5..299c21235 100644 --- a/.flox/env/manifest.toml +++ b/.flox/env/manifest.toml @@ -13,12 +13,12 @@ fselect.pkg-path = "fselect" google-cloud-sdk.pkg-path = "google-cloud-sdk" awscli2.pkg-path = "awscli2" gum.pkg-path = "gum" +mask.pkg-path = "mask" bacon.pkg-path 
= "bacon" cargo-watch.pkg-path = "cargo-watch" cargo-nextest.pkg-path = "cargo-nextest" cargo.pkg-path = "cargo" - [hook] on-activate = ''' ''' @@ -26,9 +26,4 @@ on-activate = ''' [build] [options] -systems = [ - "aarch64-darwin", - "aarch64-linux", - "x86_64-darwin", - "x86_64-linux", -] +systems = ["aarch64-darwin", "aarch64-linux", "x86_64-darwin", "x86_64-linux"] diff --git a/applications/datamanager/src/datamanager/alpaca_client.py b/applications/datamanager/src/datamanager/alpaca_client.py deleted file mode 100644 index a0d9b465f..000000000 --- a/applications/datamanager/src/datamanager/alpaca_client.py +++ /dev/null @@ -1,170 +0,0 @@ -import math -import time -from datetime import UTC, date -from typing import TYPE_CHECKING, cast -from zoneinfo import ZoneInfo - -import polars as pl -from alpaca.data.enums import DataFeed -from alpaca.data.historical import StockHistoricalDataClient -from alpaca.data.requests import StockSnapshotRequest -from alpaca.trading import TradingClient -from alpaca.trading.enums import AssetClass, AssetStatus -from alpaca.trading.requests import GetAssetsRequest -from internal.equity_bar import equity_bar_schema -from loguru import logger - -if TYPE_CHECKING: - from alpaca.data.models import Snapshot - from alpaca.trading.models import Asset - - -class AlpacaClient: - def __init__( - self, - api_key: str, - api_secret: str, - is_paper: bool, # noqa: FBT001 - ) -> None: - self.rate_limit_sleep = 0.5 # seconds - - self.historical_client = StockHistoricalDataClient( - api_key=api_key, - secret_key=api_secret, - sandbox=is_paper, - ) - - self.trading_client = TradingClient( - api_key=api_key, - secret_key=api_secret, - paper=is_paper, - ) - - def _get_tickers(self) -> list[str]: - try: - assets: list[Asset] = cast( - "list[Asset]", - self.trading_client.get_all_assets( - GetAssetsRequest( - status=AssetStatus.ACTIVE, - asset_class=AssetClass.US_EQUITY, - attributes="has_options", - ) - ), - ) - - except Exception as e: - 
logger.error(f"Error fetching Alpaca assets: {e}") - raise - - time.sleep(self.rate_limit_sleep) - - return [asset.symbol for asset in assets] - - def fetch_latest_data(self, current_date: date) -> pl.DataFrame: - tickers = self._get_tickers() - - equity_bars: list[dict[str, str | int | float | None]] = [] - chunk_size = 100 - failed_chunks = 0 - maximum_failed_chunks = len(tickers) // chunk_size // 4 - # determine total chunks (ceiling) and allow up to 25% failures, minimum 1 - n_chunks = max(1, (len(tickers) + chunk_size - 1) // chunk_size) - maximum_failed_chunks = max(1, math.ceil(n_chunks * 0.25)) # up to 25% - for i in range(0, len(tickers), chunk_size): - tickers_subset = tickers[i : i + chunk_size] - - try: - snapshots: dict[str, Snapshot] = cast( - "dict[str, Snapshot]", - self.historical_client.get_stock_snapshot( - StockSnapshotRequest( - symbol_or_symbols=tickers_subset, - feed=DataFeed("iex"), - ) - ), - ) - - processed_tickers = set() - missing_tickers = [] - - for snapshot in snapshots.values(): - if snapshot.daily_bar is None: - missing_tickers.append(snapshot.symbol) - continue - - processed_tickers.add(snapshot.symbol) - - daily_equity_bar = snapshot.daily_bar - - daily_equity_bar_timestamp = daily_equity_bar.timestamp - - # normalize naive timestamps to UTC to preserve the moment in time - if daily_equity_bar_timestamp.tzinfo is None: - utc_normalized_timestamp = daily_equity_bar_timestamp.replace( - tzinfo=UTC - ) - else: - utc_normalized_timestamp = daily_equity_bar_timestamp - - # create NY-localized datetime solely for date comparison - ny_localized_timestamp = utc_normalized_timestamp.astimezone( - ZoneInfo("America/New_York") - ) - daily_equity_bar_date = ny_localized_timestamp.date() - - if daily_equity_bar_date != current_date: - logger.info( - f"Skipping equity bar for {snapshot.symbol} on {daily_equity_bar_date}" # noqa: E501 - ) - continue - - equity_bars.append( - { - "ticker": snapshot.symbol, - "timestamp": int( - 
utc_normalized_timestamp.timestamp() * 1000 - ), # convert to milliseconds using UTC-normalized datetime - "open_price": float(daily_equity_bar.open), - "high_price": float(daily_equity_bar.high), - "low_price": float(daily_equity_bar.low), - "close_price": float(daily_equity_bar.close), - "volume": int(daily_equity_bar.volume), - "volume_weighted_average_price": ( - float(daily_equity_bar.vwap) - if daily_equity_bar.vwap is not None - else math.nan - ), - } - ) - - tickers_without_snapshots = ( - set(tickers_subset) - processed_tickers - set(missing_tickers) - ) - if tickers_without_snapshots: - logger.warning( - f"No snapshots available for tickers: {list(tickers_without_snapshots)}" # noqa: E501 - ) - - if missing_tickers: - logger.warning( - f"No daily_bar available for tickers: {missing_tickers}" - ) - - time.sleep(self.rate_limit_sleep) - - except Exception as e: - logger.error( - f"Error fetching Alpaca snapshots for chunk {i // chunk_size + 1}: {e}" # noqa: E501 - ) - failed_chunks += 1 - if failed_chunks > maximum_failed_chunks: - message = f"Too many chunk failures: {failed_chunks} chunks failed" - raise RuntimeError(message) from e - continue # continue with next chunk instead of raising - - logger.info( - f"Collected {len(equity_bars)} equity bar records from {len(tickers)} tickers" # noqa: E501 - ) - - return equity_bar_schema.validate(pl.DataFrame(equity_bars)) diff --git a/applications/datamanager/src/datamanager/s3_client.py b/applications/datamanager/src/datamanager/s3_client.py deleted file mode 100644 index 2ab8043ce..000000000 --- a/applications/datamanager/src/datamanager/s3_client.py +++ /dev/null @@ -1,107 +0,0 @@ -import re -from datetime import date - -import boto3 -import duckdb -import polars as pl -from internal.equity_bar import equity_bar_schema -from loguru import logger - - -class S3Client: - def __init__(self, data_bucket_name: str) -> None: - self.data_bucket_name = data_bucket_name - self.daily_equity_bars_path = 
f"s3://{self.data_bucket_name}/equity/bars/" - self.duckdb_connection = duckdb.connect() - self._setup_s3_access() - - def _setup_s3_access(self) -> None: - region = boto3.Session().region_name or "us-east-1" - - if not re.match(r"^[a-z0-9-]+$", region): - message = f"Invalid S3 region format: {region}" - raise ValueError(message) - - self.duckdb_connection.execute(f""" - INSTALL httpfs; - LOAD httpfs; - SET s3_region='{region}'; - """) - - def close(self) -> None: - if self.duckdb_connection: - self.duckdb_connection.close() - - def __enter__(self): # noqa: ANN204 - return self - - def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001, ANN204 - _ = exc_type, exc_val, exc_tb - - self.close() - - def write_equity_bars_data(self, data: pl.DataFrame) -> None: - data = equity_bar_schema.validate(data) - - count = len(data) - if count > 0: - try: - data.with_columns( - pl.from_epoch(pl.col("timestamp"), time_unit="ms").alias("dt") - ).with_columns( - pl.col("dt").dt.year().alias("year"), - pl.col("dt").dt.month().alias("month"), - pl.col("dt").dt.day().alias("day"), - ).drop("dt").write_parquet( - file=self.daily_equity_bars_path, - partition_by=["year", "month", "day"], - ) - - except Exception as e: - logger.error( - f"Error writing equity bars data to bucket '{self.data_bucket_name}' " # noqa: E501 - f"at path '{self.daily_equity_bars_path}' with {count} rows: {e}" - ) - raise - - def read_equity_bars_data(self, start_date: date, end_date: date) -> pl.DataFrame: - path_pattern = f"s3://{self.data_bucket_name}/equity/bars/**/*.parquet" - - query = """ - SELECT * - FROM read_parquet( - ?, - HIVE_PARTITIONING=1 - ) - WHERE - (year > ? OR - (year = ? AND month > ?) OR - (year = ? AND month = ? - AND day >= ?)) - AND - (year < ? OR - (year = ? AND month < ?) OR - (year = ? - AND month = ? 
- AND day <= ?)) - """ - - params = ( - path_pattern, - start_date.year, - start_date.year, - start_date.month, - start_date.year, - start_date.month, - start_date.day, - end_date.year, - end_date.year, - end_date.month, - end_date.year, - end_date.month, - end_date.day, - ) - - equity_bars = self.duckdb_connection.execute(query, params).pl() - - return equity_bar_schema.validate(equity_bars) diff --git a/applications/datamanager/tests/test_alpaca_client.py b/applications/datamanager/tests/test_alpaca_client.py deleted file mode 100644 index 3b971091e..000000000 --- a/applications/datamanager/tests/test_alpaca_client.py +++ /dev/null @@ -1,159 +0,0 @@ -from datetime import date, datetime -from unittest.mock import MagicMock, patch -from zoneinfo import ZoneInfo - -import polars as pl -import pytest -from datamanager.alpaca_client import AlpacaClient - - -def test_alpaca_client_fetch_latest_data() -> None: - with ( - patch( - "datamanager.alpaca_client.StockHistoricalDataClient" - ) as mock_historical_client, - patch("datamanager.alpaca_client.TradingClient") as mock_trading_client, - patch("time.sleep") as mock_sleep, # noqa: F841 - ): - mock_trading_instance = MagicMock() - mock_trading_client.return_value = mock_trading_instance - - mock_asset1 = MagicMock() - mock_asset1.symbol = "AAPL" - mock_asset2 = MagicMock() - mock_asset2.symbol = "GOOGL" - - mock_trading_instance.get_all_assets.return_value = [mock_asset1, mock_asset2] - - mock_historical_instance = MagicMock() - mock_historical_client.return_value = mock_historical_instance - - open_price = 150.0 - high_price = 155.0 - low_price = 149.0 - close_price = 153.0 - volume = 1000000 - vwap = 152.5 - - timezone = ZoneInfo("America/New_York") - - current_date = date(2024, 1, 15) - mock_daily_bar = MagicMock() - mock_daily_bar.timestamp = datetime(2024, 1, 15, 16, 0, 0, tzinfo=timezone) - mock_daily_bar.open = open_price - mock_daily_bar.high = high_price - mock_daily_bar.low = low_price - mock_daily_bar.close = 
close_price - mock_daily_bar.volume = volume - mock_daily_bar.vwap = vwap - - mock_snapshot = MagicMock() - mock_snapshot.symbol = "AAPL" - mock_snapshot.daily_bar = mock_daily_bar - - mock_snapshots = {"AAPL": mock_snapshot} - mock_historical_instance.get_stock_snapshot.return_value = mock_snapshots - - alpaca_client = AlpacaClient("test-key", "test-secret", True) # noqa: FBT003 - - result = alpaca_client.fetch_latest_data(current_date) - - assert isinstance(result, pl.DataFrame) - assert len(result) == 1 - - row = result.row(0, named=True) - assert row["ticker"] == "AAPL" - assert row["open_price"] == open_price - assert row["high_price"] == high_price - assert row["low_price"] == low_price - assert row["close_price"] == close_price - assert row["volume"] == volume - assert row["volume_weighted_average_price"] == vwap - - expected_timestamp = int( - datetime(2024, 1, 15, 16, 0, 0, tzinfo=timezone).timestamp() * 1000 - ) - assert row["timestamp"] == expected_timestamp - - -def test_alpaca_client_fetch_latest_data_no_daily_bar() -> None: - with ( - patch( - "datamanager.alpaca_client.StockHistoricalDataClient" - ) as mock_historical_client, - patch("datamanager.alpaca_client.TradingClient") as mock_trading_client, - patch("time.sleep") as mock_sleep, # noqa: F841 - ): - mock_trading_instance = MagicMock() - mock_trading_client.return_value = mock_trading_instance - - mock_asset = MagicMock() - mock_asset.symbol = "TEST" - mock_trading_instance.get_all_assets.return_value = [mock_asset] - - mock_historical_instance = MagicMock() - mock_historical_client.return_value = mock_historical_instance - - mock_snapshot = MagicMock() - mock_snapshot.symbol = "TEST" - mock_snapshot.daily_bar = None - - mock_snapshots = {"TEST": mock_snapshot} - mock_historical_instance.get_stock_snapshot.return_value = mock_snapshots - - alpaca_client = AlpacaClient("test-key", "test-secret", True) # noqa: FBT003 - - current_date = date(2024, 1, 15) - - with pytest.raises( - 
pl.exceptions.ColumnNotFoundError, match='unable to find column "ticker"' - ): - alpaca_client.fetch_latest_data(current_date) - - -def test_alpaca_client_fetch_latest_data_wrong_date() -> None: - with ( - patch( - "datamanager.alpaca_client.StockHistoricalDataClient" - ) as mock_historical_client, - patch("datamanager.alpaca_client.TradingClient") as mock_trading_client, - patch("time.sleep") as mock_sleep, # noqa: F841 - ): - mock_trading_instance = MagicMock() - mock_trading_client.return_value = mock_trading_instance - - mock_asset = MagicMock() - mock_asset.symbol = "AAPL" - mock_trading_instance.get_all_assets.return_value = [mock_asset] - - mock_historical_instance = MagicMock() - mock_historical_client.return_value = mock_historical_instance - - timezone = ZoneInfo("America/New_York") - - mock_daily_bar = MagicMock() - mock_daily_bar.timestamp = datetime( - 2024, 1, 14, 16, 0, 0, tzinfo=timezone - ) # Wrong date - mock_daily_bar.open = 150.0 - mock_daily_bar.high = 155.0 - mock_daily_bar.low = 149.0 - mock_daily_bar.close = 153.0 - mock_daily_bar.volume = 1000000 - mock_daily_bar.vwap = 152.5 - - mock_snapshot = MagicMock() - mock_snapshot.symbol = "AAPL" - mock_snapshot.daily_bar = mock_daily_bar - - mock_snapshots = {"AAPL": mock_snapshot} - mock_historical_instance.get_stock_snapshot.return_value = mock_snapshots - - alpaca_client = AlpacaClient("test-key", "test-secret", True) # noqa: FBT003 - - current_date = date(2024, 1, 15) - - with pytest.raises( - pl.exceptions.ColumnNotFoundError, match='unable to find column "ticker"' - ): - alpaca_client.fetch_latest_data(current_date) diff --git a/applications/datamanager/tests/test_s3_client.py b/applications/datamanager/tests/test_s3_client.py deleted file mode 100644 index 926076572..000000000 --- a/applications/datamanager/tests/test_s3_client.py +++ /dev/null @@ -1,100 +0,0 @@ -from datetime import date -from unittest.mock import MagicMock, patch - -import polars as pl -import pytest -from 
datamanager.s3_client import S3Client - - -def test_s3_client_write_equity_bars_data() -> None: - with ( - patch("boto3.Session") as mock_session, - patch("duckdb.connect") as mock_duckdb_connect, - patch("polars.DataFrame.write_parquet") as mock_write_parquet, - ): - mock_session.return_value.region_name = "us-east-1" - mock_duckdb_conn = MagicMock() - mock_duckdb_connect.return_value = mock_duckdb_conn - - s3_client = S3Client("test-bucket") - - test_data = pl.DataFrame( - { - "ticker": ["AAPL", "GOOGL"], - "timestamp": [ - 1640995200000, - 1641081600000, - ], # 2022-01-01, 2022-01-02 in ms - "open_price": [150.0, 2800.0], - "high_price": [155.0, 2850.0], - "low_price": [149.0, 2790.0], - "close_price": [153.0, 2820.0], - "volume": [1000000, 500000], - "volume_weighted_average_price": [152.5, 2815.0], - } - ) - - s3_client.write_equity_bars_data(test_data) - - mock_write_parquet.assert_called_once() - call_args = mock_write_parquet.call_args - assert call_args[1]["file"] == "s3://test-bucket/equity/bars/" - assert call_args[1]["partition_by"] == ["year", "month", "day"] - - -def test_s3_client_read_equity_bars_data() -> None: - with ( - patch("boto3.Session") as mock_session, - patch("duckdb.connect") as mock_duckdb_connect, - ): - mock_session.return_value.region_name = "us-east-1" - mock_duckdb_conn = MagicMock() - mock_duckdb_connect.return_value = mock_duckdb_conn - - mock_result = MagicMock() - mock_result.pl.return_value = pl.DataFrame( - { - "ticker": ["AAPL", "GOOGL"], - "timestamp": [1640995200000, 1641081600000], - "open_price": [150.0, 2800.0], - "high_price": [155.0, 2850.0], - "low_price": [149.0, 2790.0], - "close_price": [153.0, 2820.0], - "volume": [1000000, 500000], - "volume_weighted_average_price": [152.5, 2815.0], - "year": [2022, 2022], - "month": [1, 1], - "day": [1, 2], - } - ) - - mock_duckdb_conn.execute.return_value = mock_result - - s3_client = S3Client("test-bucket") - - start_date = date(2022, 1, 1) - end_date = date(2022, 1, 2) - 
result = s3_client.read_equity_bars_data(start_date, end_date) - - assert mock_duckdb_conn.execute.call_count == 2 # noqa: PLR2004 - - assert isinstance(result, pl.DataFrame) - - -def test_s3_client_write_equity_bars_data_empty_dataframe() -> None: - with ( - patch("boto3.Session") as mock_session, - patch("duckdb.connect") as mock_duckdb_connect, - ): - mock_session.return_value.region_name = "us-east-1" - mock_duckdb_conn = MagicMock() - mock_duckdb_connect.return_value = mock_duckdb_conn - - s3_client = S3Client("test-bucket") - - empty_data = pl.DataFrame() - - with pytest.raises( - pl.exceptions.ColumnNotFoundError, match='unable to find column "ticker"' - ): - s3_client.write_equity_bars_data(empty_data) diff --git a/applications/portfoliomanager/pyproject.toml b/applications/portfoliomanager/pyproject.toml index 88430ae10..3ccf69dde 100644 --- a/applications/portfoliomanager/pyproject.toml +++ b/applications/portfoliomanager/pyproject.toml @@ -4,12 +4,17 @@ version = "0.1.0" description = "Portfolio prediction and construction service" requires-python = "==3.12.10" dependencies = [ + "internal", "fastapi>=0.116.1", "uvicorn>=0.35.0", "httpx>=0.27.0", + "pandera[polars,pandas]>=0.26.0", + "alpaca-py>=0.42.1", ] [tool.uv] package = true src = ["src"] +[tool.uv.sources] +internal = { workspace = true } diff --git a/applications/portfoliomanager/src/portfoliomanager/alpaca_client.py b/applications/portfoliomanager/src/portfoliomanager/alpaca_client.py new file mode 100644 index 000000000..4ab6b134a --- /dev/null +++ b/applications/portfoliomanager/src/portfoliomanager/alpaca_client.py @@ -0,0 +1,111 @@ +import time +from typing import cast + +import pandera.polars as pa +import polars as pl +from alpaca.trading import Position, TradeAccount, TradingClient + + +class AlpacaAccount: + def __init__( + self, + cash_amount: float, + positions: pl.DataFrame, + ) -> None: + self.cash_amount = cash_amount + + position_schema.validate(positions) + + self.positions = 
positions + + +class AlpacaClient: + def __init__( + self, + api_key: str, + api_secret: str, + is_paper: bool, # noqa: FBT001 + ) -> None: + self.rate_limit_sleep = 0.5 # seconds + + self.trading_client = TradingClient( + api_key=api_key, + secret_key=api_secret, + paper=is_paper, + ) + + def get_account(self) -> AlpacaAccount: + account: TradeAccount = cast("TradeAccount", self.trading_client.get_account()) + + time.sleep(self.rate_limit_sleep) + + account_positions: list[Position] = cast( + "list[Position]", self.trading_client.get_all_positions() + ) + if not account_positions: + time.sleep(self.rate_limit_sleep) + empty_positions = pl.DataFrame( + { + "ticker": pl.Series([], dtype=pl.String), + "side": pl.Series([], dtype=pl.String), + "dollar_amount": pl.Series([], dtype=pl.Float64), + "share_amount": pl.Series([], dtype=pl.Float64), + } + ) + + return AlpacaAccount( + cash_amount=float(cast("str", account.cash)), + positions=empty_positions, + ) + + position_data = [ + { + "ticker": account_position.symbol, + "side": str(account_position.side).replace("PositionSide.", "").upper(), + "dollar_amount": float(cast("str", account_position.market_value)), + "share_amount": float(cast("str", account_position.qty)), + } + for account_position in account_positions + ] + + time.sleep(self.rate_limit_sleep) + + positions = pl.DataFrame(position_data) + + position_schema.validate(positions) + + return AlpacaAccount( + cash_amount=float(cast("str", account.cash)), + positions=positions, + ) + + +def is_uppercase(data: pa.PolarsData) -> pl.LazyFrame: + return data.lazyframe.select( + pl.col(data.key).str.to_uppercase() == pl.col(data.key) + ) + + +position_schema = pa.DataFrameSchema( + { + "ticker": pa.Column( + dtype=str, + checks=[pa.Check(is_uppercase)], + ), + "side": pa.Column( + dtype=str, + checks=[ + pa.Check.isin(["LONG", "SHORT"]), + pa.Check(is_uppercase), + ], + ), + "dollar_amount": pa.Column( + dtype=float, + checks=[pa.Check.greater_than(0)], + ), + 
"share_amount": pa.Column( + dtype=float, + checks=[pa.Check.greater_than(0)], + ), + }, +) diff --git a/applications/portfoliomanager/src/portfoliomanager/risk_management.py b/applications/portfoliomanager/src/portfoliomanager/risk_management.py new file mode 100644 index 000000000..438e62694 --- /dev/null +++ b/applications/portfoliomanager/src/portfoliomanager/risk_management.py @@ -0,0 +1,384 @@ +import math +from datetime import UTC, datetime + +import polars as pl + + +def add_positions_action_column( + positions: pl.DataFrame, + current_datetime: datetime, +) -> pl.DataFrame: + positions = positions.clone() + + return positions.with_columns( + pl.when( + pl.col("timestamp") + .cast(pl.Float64) + .map_elements( + lambda ts: datetime.fromtimestamp(ts, tz=UTC).date(), + return_dtype=pl.Date, + ) + == current_datetime.date() + ) + .then(pl.lit("PDT_LOCKED")) + .otherwise(pl.lit("UNSPECIFIED")) + .alias("action") + ) + + +def add_equity_bars_returns_and_realized_volatility_columns( + equity_bars: pl.DataFrame, +) -> pl.DataFrame: + equity_bars = equity_bars.clone() + + minimum_bars_per_ticker_required = 30 + + ticker_counts = equity_bars.group_by("ticker").agg(pl.len().alias("count")) + insufficient_tickers = ticker_counts.filter( + pl.col("count") < minimum_bars_per_ticker_required + ) + + if insufficient_tickers.height > 0: + insufficient_list = insufficient_tickers.select("ticker").to_series().to_list() + message = f"Tickers with insufficient data (< {minimum_bars_per_ticker_required} rows): {insufficient_list}" # noqa: E501 + raise ValueError(message) + + equity_bars = equity_bars.sort(["ticker", "timestamp"]) + daily_returns = pl.col("close_price").pct_change().over("ticker") + return equity_bars.with_columns( + pl.when(pl.col("close_price").is_not_null()) + .then(daily_returns) + .otherwise(None) + .alias("daily_returns"), + pl.when(pl.col("close_price").is_not_null()) + .then( + pl.when((daily_returns + 1) > 0) + .then((daily_returns + 1).log()) + 
.otherwise(None) + ) + .otherwise(None) + .alias("log_daily_returns"), + daily_returns.rolling_std(window_size=minimum_bars_per_ticker_required).alias( + "realized_volatility" + ), + ) + + +def add_positions_performance_columns( + positions: pl.DataFrame, + original_predictions: pl.DataFrame, # per original position ticker and timestamp + original_equity_bars: pl.DataFrame, # per original position ticker and timestamp + current_timestamp: datetime, +) -> pl.DataFrame: + positions = positions.clone() + original_predictions = original_predictions.clone() + original_equity_bars = original_equity_bars.clone() + + position_predictions = positions.join( + other=original_predictions, + on=["ticker", "timestamp"], + how="left", + ).select( + pl.col("ticker"), + pl.col("timestamp"), + pl.col("side"), + pl.col("dollar_amount"), + pl.col("action"), + pl.col("quantile_10").alias("original_lower_threshold"), + pl.col("quantile_90").alias("original_upper_threshold"), + ) + + original_equity_bars_with_returns = original_equity_bars.sort( + ["ticker", "timestamp"] + ) + + position_returns = [] + + for row in position_predictions.iter_rows(named=True): + ticker = row["ticker"] + position_timestamp = row["timestamp"] + + ticker_bars = original_equity_bars_with_returns.filter( + (pl.col("ticker") == ticker) + & (pl.col("timestamp") >= position_timestamp) + & (pl.col("timestamp") <= current_timestamp.timestamp()) + ) + + cumulative_log_return = ( + ticker_bars.select(pl.col("log_daily_returns").sum()).item() or 0 + ) + + cumulative_simple_return = math.exp(cumulative_log_return) - 1 + + position_returns.append( + { + "ticker": ticker, + "timestamp": position_timestamp, + "cumulative_simple_return": cumulative_simple_return, + } + ) + + returns = pl.DataFrame(position_returns) + + positions_with_data = position_predictions.join( + other=returns, + on=["ticker", "timestamp"], + how="left", + ) + + return positions_with_data.with_columns( + pl.when(pl.col("action") == "PDT_LOCKED") + 
.then(pl.lit("PDT_LOCKED")) + .when( + (pl.col("action") != "PDT_LOCKED") + & ( + ( + (pl.col("side") == "LONG") + & ( + pl.col("cumulative_simple_return") + <= pl.col("original_lower_threshold") + ) + ) + | ( + (pl.col("side") == "SHORT") + & ( + pl.col("cumulative_simple_return") + >= pl.col("original_upper_threshold") + ) + ) + ) + ) + .then(pl.lit("CLOSE_POSITION")) + .when( + ( + (pl.col("side") == "LONG") + & ( + pl.col("cumulative_simple_return") + >= pl.col("original_upper_threshold") + ) + ) + | ( + (pl.col("side") == "SHORT") + & ( + pl.col("cumulative_simple_return") + <= pl.col("original_lower_threshold") + ) + ) + ) + .then(pl.lit("MAINTAIN_POSITION")) + .otherwise(pl.lit("UNSPECIFIED")) + .alias("action") + ).drop( + [ + "original_lower_threshold", + "original_upper_threshold", + "cumulative_simple_return", + ] + ) + + +def add_predictions_zscore_ranked_columns(predictions: pl.DataFrame) -> pl.DataFrame: + predictions = predictions.clone() + + quantile_50_mean = predictions.select(pl.col("quantile_50").mean()).item() + quantile_50_standard_deviation = ( + predictions.select(pl.col("quantile_50").std()).item() or 1e-8 + ) + + z_score_return = ( + pl.col("quantile_50") - quantile_50_mean + ) / quantile_50_standard_deviation + + inter_quartile_range = pl.col("quantile_90") - pl.col("quantile_10") + + composite_score = z_score_return / (1 + inter_quartile_range) + + return predictions.with_columns( + z_score_return.alias("z_score_return"), + inter_quartile_range.alias("inter_quartile_range"), + composite_score.alias("composite_score"), + pl.lit("UNSPECIFIED").alias("action"), + ).sort(["composite_score", "inter_quartile_range"], descending=[True, False]) + + +def create_optimal_portfolio( + predictions: pl.DataFrame, + positions: pl.DataFrame, + maximum_capital: float, + current_timestamp: datetime, +) -> pl.DataFrame: + predictions = predictions.clone() + positions = positions.clone() + + minimum_inter_quartile_range = 0.75 + high_uncertainty_tickers = ( 
+ predictions.filter( + pl.col("inter_quartile_range") > minimum_inter_quartile_range + ) + .select("ticker") + .to_series() + .to_list() + ) + + closed_positions, maintained_positions = _filter_positions(positions) + + closed_position_tickers = closed_positions.select("ticker").to_series().to_list() + maintained_position_tickers = ( + maintained_positions.select("ticker").to_series().to_list() + ) + + excluded_tickers = ( + high_uncertainty_tickers + closed_position_tickers + maintained_position_tickers + ) + + available_predictions = predictions.filter( + ~pl.col("ticker").is_in(excluded_tickers) + ) + + maintained_long_capital = _filter_side_capital_amount(maintained_positions, "LONG") + maintained_short_capital = _filter_side_capital_amount( + maintained_positions, "SHORT" + ) + closed_long_capital = _filter_side_capital_amount(closed_positions, "LONG") + closed_short_capital = _filter_side_capital_amount(closed_positions, "SHORT") + + target_side_capital = maximum_capital / 2 + available_long_capital = max( + 0.0, + target_side_capital - maintained_long_capital + closed_long_capital, + ) + available_short_capital = max( + 0.0, + target_side_capital - maintained_short_capital + closed_short_capital, + ) + + maintained_long_count = maintained_positions.filter(pl.col("side") == "LONG").height + maintained_short_count = maintained_positions.filter( + pl.col("side") == "SHORT" + ).height + + new_long_positions_needed = max(0, 10 - maintained_long_count) + new_short_positions_needed = max(0, 10 - maintained_short_count) + + total_available = available_predictions.height + maximum_long_candidates = min(new_long_positions_needed, total_available // 2) + maximum_short_candidates = min( + new_short_positions_needed, total_available - maximum_long_candidates + ) + + long_candidates = available_predictions.head(maximum_long_candidates) + short_candidates = available_predictions.tail(maximum_short_candidates) + + dollar_amount_per_long = ( + available_long_capital / 
maximum_long_candidates + if maximum_long_candidates > 0 + else 0 + ) + dollar_amount_per_short = ( + available_short_capital / maximum_short_candidates + if maximum_short_candidates > 0 + else 0 + ) + + long_positions = long_candidates.select( + pl.col("ticker"), + pl.lit(current_timestamp.timestamp()).cast(pl.Float64).alias("timestamp"), + pl.lit("LONG").alias("side"), + pl.lit(dollar_amount_per_long).alias("dollar_amount"), + pl.lit("UNSPECIFIED").alias("action"), + ) + + short_positions = short_candidates.select( + pl.col("ticker"), + pl.lit(current_timestamp.timestamp()).cast(pl.Float64).alias("timestamp"), + pl.lit("SHORT").alias("side"), + pl.lit(dollar_amount_per_short).alias("dollar_amount"), + pl.lit("UNSPECIFIED").alias("action"), + ) + + return _collect_portfolio_positions( + long_positions, + short_positions, + maintained_positions, + ) + + +def _filter_positions(positions: pl.DataFrame) -> tuple[pl.DataFrame, pl.DataFrame]: + positions = positions.clone() + + if positions.height == 0: + return ( + pl.DataFrame( + { + "ticker": [], + "timestamp": [], + "side": [], + "dollar_amount": [], + "action": [], + } + ), + pl.DataFrame( + { + "ticker": [], + "timestamp": [], + "side": [], + "dollar_amount": [], + "action": [], + } + ), + ) + + closed_positions = positions.filter(pl.col("action") == "CLOSE_POSITION") + maintained_positions = positions.filter(pl.col("action") == "MAINTAIN_POSITION") + + return closed_positions, maintained_positions + + +def _filter_side_capital_amount(positions: pl.DataFrame, side: str) -> float: + positions = positions.clone() + + filtered_positions = positions.filter(pl.col("side") == side.upper()) + + if filtered_positions.height == 0: + return 0.0 + + try: + side_capital_amount = filtered_positions.select(pl.sum("dollar_amount")).item() + return float(side_capital_amount or 0) + + except Exception: # noqa: BLE001 + return 0.0 + + +def _collect_portfolio_positions( + long_positions: pl.DataFrame, + short_positions: 
pl.DataFrame, + maintained_positions: pl.DataFrame, +) -> pl.DataFrame: + long_positions = long_positions.clone() + short_positions = short_positions.clone() + maintained_positions = maintained_positions.clone() + + portfolio_components = [] + + if long_positions.height > 0: + portfolio_components.append(long_positions) + if short_positions.height > 0: + portfolio_components.append(short_positions) + if maintained_positions.height > 0: + portfolio_components.append( + maintained_positions.with_columns(pl.col("timestamp").cast(pl.Float64)) + ) + + if len(portfolio_components) == 0: + message = "No portfolio components to create an optimal portfolio." + raise ValueError(message) + + optimal_portfolio = pl.concat(portfolio_components) + + return optimal_portfolio.select( + "ticker", + pl.col("timestamp").cast(pl.Float64), + "side", + "dollar_amount", + ).sort(["ticker", "side"]) diff --git a/applications/portfoliomanager/tests/test_risk_management.py b/applications/portfoliomanager/tests/test_risk_management.py new file mode 100644 index 000000000..26a102162 --- /dev/null +++ b/applications/portfoliomanager/tests/test_risk_management.py @@ -0,0 +1,731 @@ +from datetime import UTC, datetime + +import polars as pl +import pytest +from portfoliomanager.risk_management import ( + add_equity_bars_returns_and_realized_volatility_columns, + add_positions_action_column, + add_positions_performance_columns, + add_predictions_zscore_ranked_columns, + create_optimal_portfolio, +) + + +def test_add_positions_action_column_same_day_positions_locked() -> None: + current_datetime = datetime(2024, 1, 15, 0, 0, 0, 0, tzinfo=UTC) + positions = pl.DataFrame( + { + "ticker": ["AAPL", "GOOGL"], + "timestamp": [ + datetime(2024, 1, 15, 9, 30, tzinfo=UTC).timestamp(), + datetime(2024, 1, 15, 14, 0, tzinfo=UTC).timestamp(), + ], + "side": ["LONG", "SHORT"], + "dollar_amount": [1000.0, 1000.0], + } + ) + + result = add_positions_action_column(positions, current_datetime) + + assert all(action 
== "PDT_LOCKED" for action in result["action"].to_list()) + assert len(result) == 2 # noqa: PLR2004 + + +def test_add_positions_action_column_previous_day_positions_unlocked() -> None: + current_datetime = datetime(2024, 1, 15, 0, 0, 0, 0, tzinfo=UTC) + positions = pl.DataFrame( + { + "ticker": ["AAPL", "GOOGL"], + "timestamp": [ + datetime(2024, 1, 14, 9, 30, tzinfo=UTC).timestamp(), + datetime(2024, 1, 13, 14, 0, tzinfo=UTC).timestamp(), + ], + "side": ["LONG", "SHORT"], + "dollar_amount": [1000.0, 1000.0], + } + ) + + result = add_positions_action_column(positions, current_datetime) + + assert all(action == "UNSPECIFIED" for action in result["action"].to_list()) + assert len(result) == 2 # noqa: PLR2004 + + +def test_add_positions_action_column_mixed_dates() -> None: + current_datetime = datetime(2024, 1, 15, 0, 0, 0, 0, tzinfo=UTC) + positions = pl.DataFrame( + { + "ticker": ["AAPL", "GOOGL", "TSLA"], + "timestamp": [ + datetime(2024, 1, 15, 9, 30, tzinfo=UTC).timestamp(), # same day + datetime(2024, 1, 14, 14, 0, tzinfo=UTC).timestamp(), # previous day + datetime(2024, 1, 15, 16, 0, tzinfo=UTC).timestamp(), # same day + ], + "side": ["LONG", "SHORT", "LONG"], + "dollar_amount": [1000.0, 1000.0, 1000.0], + } + ) + + result = add_positions_action_column(positions, current_datetime) + + expected_actions = ["PDT_LOCKED", "UNSPECIFIED", "PDT_LOCKED"] + assert result["action"].to_list() == expected_actions + + +def test_add_positions_action_column_empty_dataframe() -> None: + current_datetime = datetime(2024, 1, 15, 0, 0, 0, 0, tzinfo=UTC) + positions = pl.DataFrame( + {"ticker": [], "timestamp": [], "side": [], "dollar_amount": []} + ) + + result = add_positions_action_column(positions, current_datetime) + + assert len(result) == 0 + assert "action" in result.columns + + +def test_add_equity_bars_returns_and_realized_volatility_columns_sufficient_data_success() -> ( # noqa: E501 + None +): + equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 35, + "timestamp": 
[ + datetime(2024, 1, i + 1, tzinfo=UTC).timestamp() for i in range(31) + ] + + [datetime(2024, 2, i + 1, tzinfo=UTC).timestamp() for i in range(4)], + "close_price": list(range(100, 135)), # increasing prices + } + ) + + result = add_equity_bars_returns_and_realized_volatility_columns(equity_bars) + + assert "daily_returns" in result.columns + assert "log_daily_returns" in result.columns + assert "realized_volatility" in result.columns + assert len(result) == 35 # noqa: PLR2004 + + +def test_add_equity_bars_returns_and_realized_volatility_columns_insufficient_data_raises_error() -> ( # noqa: E501 + None +): + equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 25, # only 25 bars, need 30 + "timestamp": [ + datetime(2024, 1, i + 1, tzinfo=UTC).timestamp() for i in range(25) + ], + "close_price": list(range(100, 125)), + } + ) + + with pytest.raises(ValueError, match="Tickers with insufficient data"): + add_equity_bars_returns_and_realized_volatility_columns(equity_bars) + + +def test_add_equity_bars_returns_and_realized_volatility_columns_multiple_tickers_mixed_data() -> ( # noqa: E501 + None +): + equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 35 + ["GOOGL"] * 25, # AAPL has enough, GOOGL does not + "timestamp": [ + datetime(2024, 1, i + 1, tzinfo=UTC).timestamp() for i in range(31) + ] + + [datetime(2024, 2, i + 1, tzinfo=UTC).timestamp() for i in range(4)] + + [datetime(2024, 2, i + 1, tzinfo=UTC).timestamp() for i in range(25)], + "close_price": list(range(100, 135)) + list(range(200, 225)), + } + ) + + with pytest.raises(ValueError, match="GOOGL"): + add_equity_bars_returns_and_realized_volatility_columns(equity_bars) + + +def test_add_equity_bars_returns_grouped_per_ticker() -> None: + base_timestamp = datetime(2024, 1, 1, tzinfo=UTC).timestamp() + + aapl_data = [] + googl_data = [] + + for i in range(30): + timestamp = base_timestamp + (i * 86400) + aapl_price = 100.0 + i # AAPL prices increase + googl_price = 200.0 - i * 0.5 # GOOGL prices 
decrease slightly + + aapl_data.append( + {"ticker": "AAPL", "timestamp": timestamp, "close_price": aapl_price} + ) + googl_data.append( + {"ticker": "GOOGL", "timestamp": timestamp, "close_price": googl_price} + ) + + all_data = [] + for i in range(30): + all_data.append(aapl_data[i]) + all_data.append(googl_data[i]) + + equity_bars = pl.DataFrame(all_data) + + out = add_equity_bars_returns_and_realized_volatility_columns(equity_bars) + aapl = out.filter(pl.col("ticker") == "AAPL").sort("timestamp") + googl = out.filter(pl.col("ticker") == "GOOGL").sort("timestamp") + + aapl_returns = aapl["daily_returns"].to_list() + googl_returns = googl["daily_returns"].to_list() + + assert aapl_returns[0] is None + assert aapl_returns[1] == pytest.approx(0.01, abs=1e-9) + assert googl_returns[0] is None + assert googl_returns[1] == pytest.approx(-0.0025, abs=1e-9) + + aapl_log_returns = aapl["log_daily_returns"].to_list() + googl_log_returns = googl["log_daily_returns"].to_list() + + assert aapl_log_returns[0] is None + assert aapl_log_returns[1] == pytest.approx(0.00995, abs=1e-4) + assert googl_log_returns[0] is None + assert googl_log_returns[1] == pytest.approx(-0.00251, abs=1e-4) + + +def test_add_equity_bars_returns_and_realized_volatility_columns_null_prices_handled() -> ( # noqa: E501 + None +): + equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 35, + "timestamp": [ + datetime(2024, 1, i + 1, tzinfo=UTC).timestamp() for i in range(31) + ] + + [datetime(2024, 2, i + 1, tzinfo=UTC).timestamp() for i in range(4)], + "close_price": [ + 100.0, + None, + 102.0, + *list(range(103, 135)), + ], + } + ) + + result = add_equity_bars_returns_and_realized_volatility_columns(equity_bars) + + daily_returns = result["daily_returns"].to_list() + assert daily_returns[1] is None # second row should be null + + +def test_add_positions_performance_columns_long_position_outperforming() -> None: + base_timestamp = datetime(2024, 1, 10, tzinfo=UTC).timestamp() + + positions = 
pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "side": ["LONG"], + "dollar_amount": [1000.0], + "action": ["UNSPECIFIED"], + } + ) + + original_predictions = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "quantile_10": [-0.05], # -5% lower threshold + "quantile_90": [0.15], # +15% upper threshold + } + ) + + raw_equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 30, + "timestamp": [base_timestamp + (i * 86400) for i in range(30)], + "close_price": [100.0 + (20.0 * i / 29) for i in range(30)], + } + ) + + original_equity_bars = add_equity_bars_returns_and_realized_volatility_columns( + raw_equity_bars + ) + current_timestamp = datetime.fromtimestamp(base_timestamp + (29 * 86400), tz=UTC) + + result = add_positions_performance_columns( + positions, original_predictions, original_equity_bars, current_timestamp + ) + + assert result["action"][0] == "MAINTAIN_POSITION" # 20% > 15% threshold + + +def test_add_positions_performance_columns_long_position_underperforming() -> None: + base_timestamp = datetime(2024, 1, 10, tzinfo=UTC).timestamp() + + positions = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "side": ["LONG"], + "dollar_amount": [1000.0], + "action": ["UNSPECIFIED"], + } + ) + + original_predictions = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "quantile_10": [-0.05], # -5% lower threshold + "quantile_90": [0.15], # +15% upper threshold + } + ) + + raw_equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 30, + "timestamp": [base_timestamp + (i * 86400) for i in range(30)], + "close_price": [100.0 - (10.0 * i / 29) for i in range(30)], + } + ) + + original_equity_bars = add_equity_bars_returns_and_realized_volatility_columns( + raw_equity_bars + ) + current_timestamp = datetime.fromtimestamp(base_timestamp + (29 * 86400), tz=UTC) + + result = add_positions_performance_columns( + positions, original_predictions, original_equity_bars, 
current_timestamp + ) + + assert result["action"][0] == "CLOSE_POSITION" # -10% < -5% threshold + + +def test_add_positions_performance_columns_short_position_outperforming() -> None: + base_timestamp = datetime(2024, 1, 10, tzinfo=UTC).timestamp() + + positions = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "side": ["SHORT"], + "dollar_amount": [1000.0], + "action": ["UNSPECIFIED"], + } + ) + + original_predictions = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "quantile_10": [-0.05], # -5% lower threshold + "quantile_90": [0.15], # +15% upper threshold + } + ) + + raw_equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 30, + "timestamp": [base_timestamp + (i * 86400) for i in range(30)], + "close_price": [100.0 - (10.0 * i / 29) for i in range(30)], + } + ) + + original_equity_bars = add_equity_bars_returns_and_realized_volatility_columns( + raw_equity_bars + ) + current_timestamp = datetime.fromtimestamp(base_timestamp + (29 * 86400), tz=UTC) + + result = add_positions_performance_columns( + positions, original_predictions, original_equity_bars, current_timestamp + ) + + assert ( + result["action"][0] == "MAINTAIN_POSITION" + ) # -10% <= -5% threshold (good for short) + + +def test_add_positions_performance_columns_pdt_locked_position_maintained() -> None: + base_timestamp = datetime(2024, 1, 10, tzinfo=UTC).timestamp() + + positions = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "side": ["LONG"], + "dollar_amount": [1000.0], + "action": ["PDT_LOCKED"], # pdt locked + } + ) + + original_predictions = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [base_timestamp], + "quantile_10": [-0.05], + "quantile_90": [0.15], + } + ) + + raw_equity_bars = pl.DataFrame( + { + "ticker": ["AAPL"] * 30, + "timestamp": [base_timestamp + (i * 86400) for i in range(30)], + "close_price": [100.0 - (20.0 * i / 29) for i in range(30)], + } + ) + + original_equity_bars = 
add_equity_bars_returns_and_realized_volatility_columns( + raw_equity_bars + ) + current_timestamp = datetime.fromtimestamp(base_timestamp + (29 * 86400), tz=UTC) + + result = add_positions_performance_columns( + positions, original_predictions, original_equity_bars, current_timestamp + ) + + assert result["action"][0] == "PDT_LOCKED" # pdt locked overrides performance + + +def test_add_positions_performance_columns_multiple_tickers_independent() -> None: + current_timestamp = datetime(2024, 1, 10, tzinfo=UTC).timestamp() + + positions = pl.DataFrame( + { + "ticker": ["AAPL", "GOOGL"], + "timestamp": [current_timestamp, current_timestamp], + "side": ["LONG", "LONG"], + "dollar_amount": [1000.0, 1000.0], + "action": ["UNSPECIFIED", "UNSPECIFIED"], + } + ) + + predictions = pl.DataFrame( + { + "ticker": ["AAPL", "GOOGL"], + "timestamp": [current_timestamp, current_timestamp], + "quantile_10": [-0.05, -0.05], + "quantile_90": [0.15, 0.15], + } + ) + + aapl_data = [] + googl_data = [] + + for i in range(30): + timestamp = current_timestamp + (i * 86400) + aapl_price = 100.0 + (20.0 * i / 29) + googl_price = 200.0 - (20.0 * i / 29) + + aapl_data.append( + {"ticker": "AAPL", "timestamp": timestamp, "close_price": aapl_price} + ) + googl_data.append( + {"ticker": "GOOGL", "timestamp": timestamp, "close_price": googl_price} + ) + + all_data = [] + for i in range(30): + all_data.append(aapl_data[i]) + all_data.append(googl_data[i]) + + raw_equity_bars = pl.DataFrame(all_data) + + equity_bars = add_equity_bars_returns_and_realized_volatility_columns( + raw_equity_bars + ) + + out = add_positions_performance_columns( + positions, + predictions, + equity_bars, + datetime.fromtimestamp(current_timestamp + (29 * 86400), tz=UTC), # 30th day + ) + + assert out.filter(pl.col("ticker") == "AAPL")["action"][0] == "MAINTAIN_POSITION" + assert out.filter(pl.col("ticker") == "GOOGL")["action"][0] == "CLOSE_POSITION" + + +def 
test_add_predictions_zscore_ranked_columns_zscore_calculation() -> None: + predictions = pl.DataFrame( + { + "ticker": ["A", "B", "C"], + "quantile_10": [0.0, 0.0, 0.0], + "quantile_50": [0.05, 0.10, 0.15], # 5%, 10%, 15% expected returns + "quantile_90": [0.20, 0.20, 0.20], + } + ) + + result = add_predictions_zscore_ranked_columns(predictions) + + assert result["ticker"][0] == "C" # highest expected return + assert result["ticker"][2] == "A" # lowest expected return + + assert "z_score_return" in result.columns + assert "inter_quartile_range" in result.columns + assert "composite_score" in result.columns + + +def test_add_predictions_zscore_ranked_columns_inter_quartile_range_calculation() -> ( + None +): + predictions = pl.DataFrame( + { + "ticker": ["A", "B"], + "quantile_10": [0.05, 0.10], + "quantile_50": [0.10, 0.15], + "quantile_90": [0.15, 0.30], # a has narrow range, b has wide range + } + ) + + result = add_predictions_zscore_ranked_columns(predictions) + + assert result["ticker"][0] == "B" # higher expected return ranks first + assert result["inter_quartile_range"][0] == pytest.approx( + 0.20 + ) # 0.30 - 0.10 (B's range) + assert result["inter_quartile_range"][1] == pytest.approx( + 0.10 + ) # 0.15 - 0.05 (A's range) + + +def test_add_predictions_zscore_ranked_columns_single_prediction() -> None: + predictions = pl.DataFrame( + { + "ticker": ["AAPL"], + "quantile_10": [0.05], + "quantile_50": [0.10], + "quantile_90": [0.15], + } + ) + + result = add_predictions_zscore_ranked_columns(predictions) + + assert len(result) == 1 + assert result["z_score_return"][0] == 0.0 # single value has z-score of 0 + + +def test_create_optimal_portfolio_fresh_start_no_existing_positions() -> None: + predictions = pl.DataFrame( + { + "ticker": [f"STOCK{i}" for i in range(25)], + "composite_score": list(range(25, 0, -1)), # descending scores + "inter_quartile_range": [0.1] * 25, # low uncertainty + } + ) + + positions = pl.DataFrame( + { + "ticker": [], + "timestamp": [], 
+ "side": [], + "dollar_amount": [], + "action": [], + } + ) + + result = create_optimal_portfolio( + predictions, positions, 20000.0, datetime.now(tz=UTC) + ) + + assert len(result) == 20 # 10 long + 10 short # noqa: PLR2004 + assert result.filter(pl.col("side") == "LONG").height == 10 # noqa: PLR2004 + assert result.filter(pl.col("side") == "SHORT").height == 10 # noqa: PLR2004 + + long_total = result.filter(pl.col("side") == "LONG")["dollar_amount"].sum() + short_total = result.filter(pl.col("side") == "SHORT")["dollar_amount"].sum() + assert abs(long_total - short_total) < 0.01 # noqa: PLR2004 + + +def test_create_optimal_portfolio_some_maintained_positions() -> None: + predictions = pl.DataFrame( + { + "ticker": [f"STOCK{i}" for i in range(25)], + "composite_score": list(range(25, 0, -1)), + "inter_quartile_range": [0.1] * 25, + } + ) + + positions = pl.DataFrame( + { + "ticker": ["STOCK1", "STOCK2", "STOCK24"], + "timestamp": [ + datetime(2024, 1, 10, tzinfo=UTC).timestamp(), + datetime(2024, 1, 11, tzinfo=UTC).timestamp(), + datetime(2024, 1, 12, tzinfo=UTC).timestamp(), + ], + "side": ["LONG", "LONG", "SHORT"], + "dollar_amount": [1000.0, 1000.0, 1000.0], + "action": ["MAINTAIN_POSITION", "MAINTAIN_POSITION", "MAINTAIN_POSITION"], + } + ) + + result = create_optimal_portfolio( + predictions, positions, 20000.0, datetime.now(tz=UTC) + ) + + assert len(result) == 20 # noqa: PLR2004 + assert "STOCK1" in result["ticker"].to_list() + assert "STOCK2" in result["ticker"].to_list() + assert "STOCK24" in result["ticker"].to_list() + + +def test_create_optimal_portfolio_high_uncertainty_exclusions() -> None: + predictions = pl.DataFrame( + { + "ticker": ["HIGH_UNCERT", "LOW_UNCERT1", "LOW_UNCERT2"], + "composite_score": [10.0, 5.0, 1.0], + "inter_quartile_range": [0.8, 0.1, 0.1], # first one too uncertain + } + ) + + positions = pl.DataFrame( + { + "ticker": [], + "timestamp": [], + "side": [], + "dollar_amount": [], + "action": [], + } + ) + + result = 
create_optimal_portfolio( + predictions, positions, 20000.0, datetime.now(tz=UTC) + ) + + assert "HIGH_UNCERT" not in result["ticker"].to_list() + assert len(result) == 2 # only 2 available predictions # noqa: PLR2004 + + +def test_create_optimal_portfolio_all_positions_maintained_no_new_needed() -> None: + predictions = pl.DataFrame( + { + "ticker": [f"STOCK{i}" for i in range(25)], + "composite_score": list(range(25, 0, -1)), + "inter_quartile_range": [0.1] * 25, + } + ) + + positions = pl.DataFrame( + { + "ticker": [f"MAINTAINED{i}" for i in range(20)], + "timestamp": [datetime(2024, 1, 10, tzinfo=UTC).timestamp()] * 20, + "side": ["LONG"] * 10 + ["SHORT"] * 10, + "dollar_amount": [500.0] * 20, + "action": ["MAINTAIN_POSITION"] * 20, + } + ) + + result = create_optimal_portfolio( + predictions, positions, 20000.0, datetime.now(tz=UTC) + ) + + assert len(result) == 20 # all maintained positions # noqa: PLR2004 + expected_timestamp = datetime(2024, 1, 10, tzinfo=UTC).timestamp() + assert all(ts == expected_timestamp for ts in result["timestamp"].to_list()) + + +def test_create_optimal_portfolio_capital_rebalancing_with_closed_positions() -> None: + predictions = pl.DataFrame( + { + "ticker": [f"NEW{i}" for i in range(15)], + "composite_score": list(range(15, 0, -1)), + "inter_quartile_range": [0.1] * 15, + } + ) + + positions = pl.DataFrame( + { + "ticker": ["MAINTAINED1", "MAINTAINED2", "CLOSED1", "CLOSED2"], + "timestamp": [ + datetime(2024, 1, 10, tzinfo=UTC).timestamp(), + datetime(2024, 1, 11, tzinfo=UTC).timestamp(), + datetime(2024, 1, 12, tzinfo=UTC).timestamp(), + datetime(2024, 1, 13, tzinfo=UTC).timestamp(), + ], + "side": ["LONG", "SHORT", "LONG", "SHORT"], + "dollar_amount": [800.0, 1200.0, 500.0, 500.0], # uneven amounts + "action": [ + "MAINTAIN_POSITION", + "MAINTAIN_POSITION", + "CLOSE_POSITION", + "CLOSE_POSITION", + ], + } + ) + + result = create_optimal_portfolio( + predictions, positions, 20000.0, datetime.now(tz=UTC) + ) + + # 2 maintained + 
15 new (limited by available predictions) + # even though this isn't a realistic scenario + assert len(result) == 17 # noqa: PLR2004 + + maintained = result.filter( + pl.col("timestamp").is_in( + [ + datetime(2024, 1, 10, tzinfo=UTC).timestamp(), + datetime(2024, 1, 11, tzinfo=UTC).timestamp(), + ] + ) + ) + assert len(maintained) == 2 # noqa: PLR2004 + + long_total = result.filter(pl.col("side") == "LONG")["dollar_amount"].sum() + short_total = result.filter(pl.col("side") == "SHORT")["dollar_amount"].sum() + assert abs(long_total - short_total) < 0.01 # noqa: PLR2004 + + +def test_create_optimal_portfolio_mixed_closed_and_maintained_positions() -> None: + predictions = pl.DataFrame( + { + "ticker": [f"STOCK{i:02d}" for i in range(30)], + "composite_score": list(range(30, 0, -1)), + "inter_quartile_range": [0.2] * 30, # all acceptable uncertainty + } + ) + + positions = pl.DataFrame( + { + "ticker": ["OLD1", "OLD2", "OLD3", "OLD4", "OLD5"], + "timestamp": [ + datetime(2024, 1, 10, tzinfo=UTC).timestamp(), + datetime(2024, 1, 11, tzinfo=UTC).timestamp(), + datetime(2024, 1, 12, tzinfo=UTC).timestamp(), + datetime(2024, 1, 13, tzinfo=UTC).timestamp(), + datetime(2024, 1, 14, tzinfo=UTC).timestamp(), + ], + "side": ["LONG", "LONG", "SHORT", "SHORT", "LONG"], + "dollar_amount": [1000.0, 1000.0, 1000.0, 1000.0, 1000.0], + "action": [ + "CLOSE_POSITION", + "MAINTAIN_POSITION", + "MAINTAIN_POSITION", + "CLOSE_POSITION", + "MAINTAIN_POSITION", + ], + } + ) + + result = create_optimal_portfolio( + predictions, + positions, + 20000.0, + datetime.now(tz=UTC), + ) + + assert len(result) == 20 # noqa: PLR2004 + + maintained_tickers = ["OLD2", "OLD3", "OLD5"] + for ticker in maintained_tickers: + assert ticker in result["ticker"].to_list() + + closed_tickers = ["OLD1", "OLD4"] + for ticker in closed_tickers: + assert ticker not in result["ticker"].to_list() + + assert "ticker" in result.columns + assert "timestamp" in result.columns + assert "side" in result.columns + assert 
"dollar_amount" in result.columns + assert len(result.columns) == 4 # only these 4 columns # noqa: PLR2004 + + sorted_result = result.sort(["ticker", "side"]) + assert sorted_result.equals(result) diff --git a/infrastructure/stack.yml b/infrastructure/stack.yml index 3d1a5bff5..863bae429 100644 --- a/infrastructure/stack.yml +++ b/infrastructure/stack.yml @@ -140,4 +140,4 @@ services: - traefik.http.routers.grafana.rule=Host(`grafana.example.com`) - traefik.http.routers.grafana.entrypoints=web - traefik.http.services.grafana.loadbalancer.server.port=3000 - networks: [public, internal] \ No newline at end of file + networks: [public, internal] diff --git a/libraries/python/src/internal/company_information.py b/libraries/python/src/internal/company_information.py index 4b5f930e4..4da755ad2 100644 --- a/libraries/python/src/internal/company_information.py +++ b/libraries/python/src/internal/company_information.py @@ -1,37 +1,27 @@ import pandera.polars as pa -import polars as pl -from pandera.polars import PolarsData - - -def is_uppercase(data: PolarsData) -> pl.LazyFrame: - return data.lazyframe.select( - pl.col(data.key).str.to_uppercase() == pl.col(data.key) - ) - - -def is_stripped(data: PolarsData) -> pl.LazyFrame: - return data.lazyframe.select(pl.col(data.key).str.strip_chars() == pl.col(data.key)) - company_information_schema = pa.DataFrameSchema( { "sector": pa.Column( dtype=str, - nullable=False, - coerce=True, - checks=[ - pa.Check(is_uppercase), - pa.Check(is_stripped), - ], + default="NOT AVAILABLE", ), "industry": pa.Column( dtype=str, - nullable=False, - coerce=True, - checks=[ - pa.Check(is_uppercase), - pa.Check(is_stripped), - ], + default="NOT AVAILABLE", + ), + }, + coerce=True, + checks=[ + pa.Check( + lambda s: s.upper() == s, + error="Sector and industry must be uppercase", + element_wise=True, + ), + pa.Check( + lambda s: s.strip() == s, + error="Sector and industry must be stripped", + element_wise=True, ), - } + ], ) diff --git 
a/libraries/python/src/internal/equity_bar.py b/libraries/python/src/internal/equity_bar.py index c9d2b18ae..4e66a262a 100644 --- a/libraries/python/src/internal/equity_bar.py +++ b/libraries/python/src/internal/equity_bar.py @@ -1,77 +1,51 @@ import pandera.polars as pa import polars as pl -from pandera.polars import PolarsData - - -def is_uppercase(data: PolarsData) -> pl.LazyFrame: - return data.lazyframe.select( - pl.col(data.key).str.to_uppercase() == pl.col(data.key) - ) - - -def is_positive(data: PolarsData) -> pl.LazyFrame: - return data.lazyframe.select(pl.col(data.key) > 0) - equity_bar_schema = pa.DataFrameSchema( { "ticker": pa.Column( dtype=str, - nullable=False, - coerce=True, - checks=[pa.Check(is_uppercase)], + checks=[ + pa.Check( + lambda s: s.upper() == s, + error="Ticker must be uppercase", + element_wise=True, + ) + ], ), "timestamp": pa.Column( - dtype=int, - nullable=False, - coerce=True, - checks=[pa.Check(is_positive)], + dtype=pl.Float64, + checks=[pa.Check.greater_than(0)], ), "open_price": pa.Column( dtype=float, - nullable=False, - coerce=True, - checks=[pa.Check(is_positive)], + checks=[pa.Check.greater_than(0)], ), "high_price": pa.Column( dtype=float, - nullable=False, - coerce=True, - checks=[pa.Check(is_positive)], + checks=[pa.Check.greater_than(0)], ), "low_price": pa.Column( dtype=float, - nullable=False, - coerce=True, - checks=[pa.Check(is_positive)], + checks=[pa.Check.greater_than(0)], ), "close_price": pa.Column( dtype=float, - nullable=False, - coerce=True, - checks=[pa.Check(is_positive)], + checks=[pa.Check.greater_than(0)], ), "volume": pa.Column( dtype=int, - nullable=False, - coerce=True, - checks=[pa.Check(is_positive)], + checks=[pa.Check.greater_than_or_equal_to(0)], ), "volume_weighted_average_price": pa.Column( dtype=float, nullable=True, - coerce=True, - # allow missing value or enforce > 0 when present - checks=[ - pa.Check( - lambda df: df.lazyframe.select( - pl.col(df.key).is_null() | (pl.col(df.key) > 0) - ) 
- ) - ], + checks=[pa.Check.greater_than_or_equal_to(0)], ), }, + unique=["ticker", "timestamp"], strict="filter", # allows DuckDB partion columns ordered=True, name="equity_bar", + coerce=True, ) diff --git a/libraries/python/src/internal/mhsa_network.py b/libraries/python/src/internal/mhsa_network.py index 488984490..89e50b97b 100644 --- a/libraries/python/src/internal/mhsa_network.py +++ b/libraries/python/src/internal/mhsa_network.py @@ -36,7 +36,7 @@ def forward(self, inputs: Tensor) -> tuple[Tensor, Tensor]: shape = (batch_size, sequence_length, self.heads_count, self.heads_dimension) - # shape: (batch, heads_count, sequence_length, head_dimension) # noqa: ERA001 + # shape: (batch, heads_count, sequence_length, head_dimension) # noqa: ERA001 query_weights = query_weights.view(shape).transpose(1, 2) key_weights = key_weights.view(shape).transpose(1, 2) value_weights = value_weights.view(shape).transpose(1, 2) diff --git a/libraries/python/src/internal/portfolio.py b/libraries/python/src/internal/portfolio.py new file mode 100644 index 000000000..cab3a0448 --- /dev/null +++ b/libraries/python/src/internal/portfolio.py @@ -0,0 +1,112 @@ +import pandera.polars as pa +import polars as pl +from pandera.polars import PolarsData + + +def is_uppercase(data: PolarsData) -> pl.LazyFrame: + return data.lazyframe.select( + pl.col(data.key).str.to_uppercase() == pl.col(data.key) + ) + + +def check_position_side_counts( + data: PolarsData, + total_positions_count: int = 20, # 10 long and 10 short +) -> bool: + counts = data.lazyframe.select( + pl.len().alias("total_count"), + (pl.col("side") == "LONG").sum().alias("long_count"), + (pl.col("side") == "SHORT").sum().alias("short_count"), + ).collect() + total_count = counts.get_column("total_count").item() + long_count = counts.get_column("long_count").item() + short_count = counts.get_column("short_count").item() + side_count = total_positions_count // 2 + if long_count != side_count: + message = f"Expected {side_count} long 
side positions, found: {long_count}" + raise ValueError(message) + + if short_count != side_count: + message = f"Expected {side_count} short side positions, found: {short_count}" + raise ValueError(message) + + if total_count != total_positions_count: + message = ( + f"Expected {total_positions_count} total positions, found: {total_count}" + ) + raise ValueError(message) + + return True + + +def check_position_side_sums( + data: PolarsData, + maximum_imbalance_percentage: float = 0.05, # 5% +) -> bool: + sums = data.lazyframe.select( + pl.when(pl.col("side") == "LONG") + .then(pl.col("dollar_amount")) + .otherwise(0.0) + .sum() + .alias("long_sum"), + pl.when(pl.col("side") == "SHORT") + .then(pl.col("dollar_amount")) + .otherwise(0.0) + .sum() + .alias("short_sum"), + ).collect() + + long_sum = float(sums.get_column("long_sum").fill_null(0.0).item()) + short_sum = float(sums.get_column("short_sum").fill_null(0.0).item()) + total_sum = long_sum + short_sum + + if total_sum <= 0.0: + message = "Total dollar amount must be > 0 to assess imbalance" + raise ValueError(message) + + if abs(long_sum - short_sum) / total_sum > maximum_imbalance_percentage: + message = ( + "Expected long and short dollar amount sums to be within " + f"{maximum_imbalance_percentage * 100}%, " + f"found long: {long_sum}, short: {short_sum}" + ) + raise ValueError(message) + + return True + + +portfolio_schema = pa.DataFrameSchema( + { + "ticker": pa.Column( + dtype=str, + checks=[pa.Check(is_uppercase)], + ), + "timestamp": pa.Column( + dtype=pl.Float64, + checks=[pa.Check.greater_than(0)], + ), + "side": pa.Column( + dtype=str, + checks=[ + pa.Check.isin(["LONG", "SHORT"]), + pa.Check(is_uppercase), + ], + ), + "dollar_amount": pa.Column( + dtype=float, + checks=[pa.Check.greater_than(0)], + ), + }, + unique=["ticker"], + coerce=True, + checks=[ + pa.Check( + check_fn=lambda df: check_position_side_counts(df), + error="Each side must have expected position counts", + ), + pa.Check( + 
check_fn=lambda df: check_position_side_sums(df), + error="Position side sums must be approximately equal", + ), + ], +) diff --git a/libraries/python/src/internal/prediction.py b/libraries/python/src/internal/prediction.py new file mode 100644 index 000000000..71201c0aa --- /dev/null +++ b/libraries/python/src/internal/prediction.py @@ -0,0 +1,93 @@ +import pandera.polars as pa +import polars as pl +from pandera.polars import PolarsData + + +def check_dates_count_per_ticker( + data: PolarsData, + dates_count: int = 7, +) -> bool: + grouped = data.lazyframe.group_by("ticker").agg( + pl.col("timestamp").unique().alias("unique_dates") + ) + + unique_dates_per_ticker = grouped.collect()["unique_dates"].to_list() + + if not all(len(dates) == dates_count for dates in unique_dates_per_ticker): + message = f"Each ticker must have exactly {dates_count} unique dates, found: {unique_dates_per_ticker}" # noqa: E501 + raise ValueError(message) + + return True + + +def check_same_dates_per_ticker(data: PolarsData) -> bool: + grouped = data.lazyframe.group_by("ticker").agg( + pl.col("timestamp").unique().alias("unique_dates") + ) + + unique_dates_per_ticker = grouped.collect()["unique_dates"].to_list() + + if len(unique_dates_per_ticker) > 1: + first_ticker_dates = set(unique_dates_per_ticker[0]) + for dates in unique_dates_per_ticker[1:]: + if set(dates) != first_ticker_dates: + message = f"Expected all tickers to have the same dates, mismatch between: {first_ticker_dates} and: {set(dates)}" # noqa: E501 + raise ValueError(message) + + return True + + +def check_monotonic_quantiles(data: PolarsData) -> bool: + lazy_frame = data.lazyframe.collect() + + if ( + not (lazy_frame["quantile_10"] <= lazy_frame["quantile_50"]).all() + or not (lazy_frame["quantile_50"] <= lazy_frame["quantile_90"]).all() + ): + message = "Quantiles must be monotonic: q10 โ‰ค q50 โ‰ค q90" + raise ValueError(message) + + return True + + +prediction_schema = pa.DataFrameSchema( + columns={ + "ticker": 
pa.Column( + dtype=str, + checks=[ + pa.Check( + lambda s: s.upper() == s, + error="Ticker must be uppercase", + element_wise=True, + ) + ], + ), + "timestamp": pa.Column( + dtype=pl.Float64, + checks=[pa.Check.greater_than(0)], + ), + "quantile_10": pa.Column(dtype=float), + "quantile_50": pa.Column(dtype=float), + "quantile_90": pa.Column(dtype=float), + }, + coerce=True, + checks=[ + pa.Check( + check_fn=lambda df: check_dates_count_per_ticker(df), + name="check_dates_count_per_ticker", + error="Each ticker must have expected date count", + ), + pa.Check( + check_fn=lambda df: check_same_dates_per_ticker(df), + name="check_same_dates_per_ticker", + error="All tickers must have same date values", + ), + pa.Check( + check_fn=lambda df: check_monotonic_quantiles(df), + name="quantile_monotonic", + error="Quantiles must be monotonic: q10 โ‰ค q50 โ‰ค q90", + ), + ], + unique=["ticker", "timestamp"], + name="prediction", +) diff --git a/libraries/python/src/internal/tft_dataset.py b/libraries/python/src/internal/tft_dataset.py index aa075916f..1e2a82574 100644 --- a/libraries/python/src/internal/tft_dataset.py +++ b/libraries/python/src/internal/tft_dataset.py @@ -1,4 +1,4 @@ -from datetime import date +from datetime import date, datetime, timedelta import pandera.polars as pa import polars as pl @@ -24,9 +24,12 @@ def inverse_transform(self, data: pl.DataFrame) -> pl.DataFrame: class TFTDataset: - """Temporal fusion transformer dataset.""" + """Temporal fusion transformer dataset preprocessing and postprocessing.""" - def __init__(self, data: pl.DataFrame) -> None: + def __init__(self) -> None: + pass + + def preprocess_and_set_data(self, data: pl.DataFrame) -> None: data = data.clone() raw_columns = ( @@ -53,6 +56,7 @@ def __init__(self, data: pl.DataFrame) -> None: "close_price", "volume", "volume_weighted_average_price", + "daily_return", ] self.categorical_columns = [ @@ -162,6 +166,12 @@ def __init__(self, data: pl.DataFrame) -> None: .alias("time_idx") ) + data 
= data.with_columns( + pl.col("close_price").pct_change().over("ticker").alias("daily_return") + ) + + data = data.filter(pl.col("daily_return").is_not_null()) + data = dataset_schema.validate(data) self.scaler = Scaler() @@ -320,63 +330,113 @@ def get_batches( } if data_type in {"train", "validate"}: - batch["targets"] = Tensor(decoder_slice[["close_price"]].to_numpy()) + batch["targets"] = Tensor( + decoder_slice[["daily_return"]].to_numpy() + ) batches.append(batch) return batches + def postprocess_predictions( + self, + input_batch: Tensor, # static_categorical_features + predictions: Tensor, # quantiles + current_datetime: datetime, + ) -> pl.DataFrame: + predictions_array = predictions.numpy() + + batch_size, output_length, _, _ = predictions_array.shape + + ticker_reverse_mapping = {v: k for k, v in self.mappings["ticker"].items()} + + rows = [] + for batch_idx in range(batch_size): + ticker_encoded = int( + input_batch["static_categorical_features"][batch_idx, 0, 0].item() + ) + ticker_str = ticker_reverse_mapping[ticker_encoded] + + for time_idx in range(output_length): + timestamp = int( + (current_datetime + timedelta(days=time_idx)) + .replace( + hour=0, + minute=0, + second=0, + microsecond=0, + ) + .timestamp() + ) + + quantile_values = predictions_array[batch_idx, time_idx, 0, :] + + daily_return_mean = self.scaler.means["daily_return"] + daily_return_standard_deviation = self.scaler.standard_deviations[ + "daily_return" + ] + + unscaled_quantiles = ( + quantile_values * daily_return_standard_deviation + ) + daily_return_mean + + row = { + "ticker": ticker_str, + "timestamp": timestamp, + "quantile_10": unscaled_quantiles[0], + "quantile_50": unscaled_quantiles[1], + "quantile_90": unscaled_quantiles[2], + } + rows.append(row) + + return pl.DataFrame(rows) + dataset_schema = pa.DataFrameSchema( { "ticker": pa.Column( - str, + dtype=str, checks=pa.Check.str_matches(r"^[A-Z0-9.\-]+$"), - coerce=True, - required=True, ), "timestamp": pa.Column( - int, - 
checks=pa.Check.gt(0), - coerce=True, - required=True, + dtype=int, + checks=pa.Check.greater_than(0), ), "open_price": pa.Column( - float, - checks=pa.Check.ge(0), - coerce=True, - required=True, + dtype=float, + checks=pa.Check.greater_than(0), ), "high_price": pa.Column( - float, checks=pa.Check.ge(0), coerce=True, required=True + dtype=float, + checks=pa.Check.greater_than(0), ), "low_price": pa.Column( - float, checks=pa.Check.ge(0), coerce=True, required=True + dtype=float, + checks=pa.Check.greater_than(0), ), "close_price": pa.Column( - float, checks=pa.Check.ge(0), coerce=True, required=True + dtype=float, + checks=pa.Check.greater_than(0), ), "volume": pa.Column( - int, - checks=pa.Check.ge(0), - coerce=True, - required=True, + dtype=int, + checks=pa.Check.greater_than_or_equal_to(0), ), "volume_weighted_average_price": pa.Column( - float, - checks=pa.Check.ge(0), - coerce=True, - required=True, + dtype=float, + checks=pa.Check.greater_than_or_equal_to(0), ), - "sector": pa.Column(str, coerce=True, required=True), - "industry": pa.Column(str, coerce=True, required=True), - "date": pa.Column(date, coerce=True, required=True), - "day_of_week": pa.Column(int, coerce=True, required=True), - "day_of_month": pa.Column(int, coerce=True, required=True), - "day_of_year": pa.Column(int, coerce=True, required=True), - "month": pa.Column(int, coerce=True, required=True), - "year": pa.Column(int, coerce=True, required=True), - "is_holiday": pa.Column(bool, coerce=True, required=True), - "time_idx": pa.Column(int, coerce=True, required=True), - } + "sector": pa.Column(dtype=str), + "industry": pa.Column(dtype=str), + "date": pa.Column(dtype=date), + "day_of_week": pa.Column(dtype=int), + "day_of_month": pa.Column(dtype=int), + "day_of_year": pa.Column(dtype=int), + "month": pa.Column(dtype=int), + "year": pa.Column(dtype=int), + "is_holiday": pa.Column(dtype=bool), + "time_idx": pa.Column(dtype=int), + "daily_return": pa.Column(dtype=float), + }, + coerce=True, ) diff 
--git a/libraries/python/src/internal/tft_model.py b/libraries/python/src/internal/tft_model.py index 8e2654237..3fa10d465 100644 --- a/libraries/python/src/internal/tft_model.py +++ b/libraries/python/src/internal/tft_model.py @@ -18,7 +18,7 @@ class Parameters(BaseModel): hidden_size: int = 64 - output_size: int = 1 # closing price + output_size: int = 1 # daily return lstm_layer_count: int = 3 attention_head_size: int = 4 dropout_rate: float = 0.1 diff --git a/libraries/python/tests/test_company_information.py b/libraries/python/tests/test_company_information.py index 1d2df794f..a5fd4b0c9 100644 --- a/libraries/python/tests/test_company_information.py +++ b/libraries/python/tests/test_company_information.py @@ -98,8 +98,9 @@ def test_company_information_schema_null_sector() -> None: } ) - with pytest.raises(SchemaError): - company_information_schema.validate(data) + validated_df = company_information_schema.validate(data) + assert validated_df["sector"][0] == "NOT AVAILABLE" + assert validated_df["industry"][0] == "SOFTWARE" def test_company_information_schema_null_industry() -> None: @@ -110,8 +111,9 @@ def test_company_information_schema_null_industry() -> None: } ) - with pytest.raises(SchemaError): - company_information_schema.validate(data) + validated_df = company_information_schema.validate(data) + assert validated_df["industry"][0] == "NOT AVAILABLE" + assert validated_df["sector"][0] == "TECHNOLOGY" def test_company_information_schema_missing_sector_column() -> None: @@ -137,7 +139,6 @@ def test_company_information_schema_missing_industry_column() -> None: def test_company_information_schema_type_coercion() -> None: - # test that numeric inputs get coerced to strings data = pl.DataFrame( { "sector": [123], # coerced to string diff --git a/libraries/python/tests/test_equity_bar.py b/libraries/python/tests/test_equity_bar.py index 59ef8958a..2c9eb6829 100644 --- a/libraries/python/tests/test_equity_bar.py +++ b/libraries/python/tests/test_equity_bar.py @@ 
-26,7 +26,7 @@ def test_equity_bar_schema_ticker_lowercase_fails() -> None: data = pl.DataFrame( { "ticker": ["aapl"], - "timestamp": [1674000000], + "timestamp": [1674000000.0], "open_price": [150.0], "high_price": [155.0], "low_price": [149.0], @@ -44,7 +44,7 @@ def test_equity_bar_schema_ticker_uppercase_passes() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [1674000000], + "timestamp": [1674000000.0], "open_price": [150.0], "high_price": [155.0], "low_price": [149.0], @@ -62,7 +62,7 @@ def test_equity_bar_schema_negative_timestamp() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [-1674000000], + "timestamp": [-1674000000.0], "open_price": [150.0], "high_price": [155.0], "low_price": [149.0], @@ -80,7 +80,7 @@ def test_equity_bar_schema_zero_timestamp() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [0], + "timestamp": [0.0], "open_price": [150.0], "high_price": [155.0], "low_price": [149.0], @@ -107,7 +107,7 @@ def test_equity_bar_schema_negative_prices() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [1674000000], + "timestamp": [1674000000.0], "open_price": [150.0], "high_price": [155.0], "low_price": [149.0], @@ -127,7 +127,7 @@ def test_equity_bar_schema_zero_prices_not_allowed() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [1674000000], + "timestamp": [1674000000.0], "open_price": [0.0], "high_price": [155.0], "low_price": [149.0], @@ -145,7 +145,7 @@ def test_equity_bar_schema_negative_volume() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [1674000000], + "timestamp": [1674000000.0], "open_price": [150.0], "high_price": [155.0], "low_price": [149.0], @@ -159,29 +159,11 @@ def test_equity_bar_schema_negative_volume() -> None: equity_bar_schema.validate(data) -def test_equity_bar_schema_zero_volume_not_allowed() -> None: - data = pl.DataFrame( - { - "ticker": ["AAPL"], - "timestamp": [1674000000], - "open_price": [150.0], - 
"high_price": [155.0], - "low_price": [149.0], - "close_price": [153.0], - "volume": [0], - "volume_weighted_average_price": [152.5], - } - ) - - with pytest.raises(SchemaError): - equity_bar_schema.validate(data) - - def test_equity_bar_schema_type_coercion() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": ["1674000000"], # string that can be coerced to int + "timestamp": ["1674000000.0"], # string that can be coerced to float "open_price": ["150.0"], # string that can be coerced to float "high_price": [155], # int that can be coerced to float "low_price": [149.0], @@ -192,7 +174,7 @@ def test_equity_bar_schema_type_coercion() -> None: ) validated_df = equity_bar_schema.validate(data) - assert validated_df["timestamp"].dtype == pl.Int64 + assert validated_df["timestamp"].dtype == pl.Float64 assert validated_df["open_price"].dtype == pl.Float64 assert validated_df["high_price"].dtype == pl.Float64 assert validated_df["volume"].dtype == pl.Int64 @@ -202,7 +184,7 @@ def test_equity_bar_schema_missing_required_column() -> None: data = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [1674000000], + "timestamp": [1674000000.0], # Missing open_price "high_price": [155.0], "low_price": [149.0], @@ -220,7 +202,7 @@ def test_equity_bar_schema_null_values() -> None: data = pl.DataFrame( { "ticker": [None], - "timestamp": [1674000000], + "timestamp": [1674000000.0], "open_price": [150.0], "high_price": [155.0], "low_price": [149.0], @@ -238,7 +220,7 @@ def test_equity_bar_schema_multiple_rows() -> None: data = pl.DataFrame( { "ticker": ["AAPL", "GOOGL", "NVDA"], - "timestamp": [1674000000, 1674000060, 1674000120], + "timestamp": [1674000000.0, 1674000060.0, 1674000120.0], "open_price": [150.0, 100.0, 300.0], "high_price": [155.0, 105.0, 305.0], "low_price": [149.0, 99.0, 299.0], diff --git a/libraries/python/tests/test_portfolio.py b/libraries/python/tests/test_portfolio.py new file mode 100644 index 000000000..f2da49f1f --- /dev/null +++ 
b/libraries/python/tests/test_portfolio.py @@ -0,0 +1,185 @@ +from datetime import UTC, datetime + +import polars as pl +import pytest +from internal.portfolio import portfolio_schema +from pandera.errors import SchemaError + + +def test_portfolio_schema_valid_data() -> None: + valid_data = pl.DataFrame( + { + "ticker": [ + "AAPL", + "GOOGL", + "MSFT", + "AMZN", + "TSLA", + "NVDA", + "META", + "NFLX", + "BABA", + "CRM", + "AMD", + "INTC", + "ORCL", + "ADBE", + "PYPL", + "SHOP", + "SPOT", + "ROKU", + "ZM", + "DOCU", + ], + "timestamp": [datetime(2025, 1, 1, 0, 0, 0, 0, tzinfo=UTC).timestamp()] + * 20, + "side": (["LONG"] * 10) + (["SHORT"] * 10), + "dollar_amount": [1000.0] * 20, # Equal amounts for balanced portfolio + } + ) + + validated_df = portfolio_schema.validate(valid_data) + assert validated_df.shape == (20, 4) + + +def test_portfolio_schema_ticker_lowercase_fails() -> None: + data = pl.DataFrame( + { + "ticker": ["aapl"], # lowercase should fail + "timestamp": [datetime(2025, 1, 1, 0, 0, 0, 0, tzinfo=UTC).timestamp()], + "side": ["LONG"], + "dollar_amount": [1000.0], + } + ) + + with pytest.raises(SchemaError): + portfolio_schema.validate(data) + + +def test_portfolio_schema_invalid_side_fails() -> None: + data = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [datetime(2025, 1, 1, 0, 0, 0, 0, tzinfo=UTC).timestamp()], + "side": ["BUY"], # Invalid side value + "dollar_amount": [1000.0], + } + ) + + with pytest.raises(SchemaError): + portfolio_schema.validate(data) + + +def test_portfolio_schema_negative_dollar_amount_fails() -> None: + data = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [datetime(2025, 1, 1, 0, 0, 0, 0, tzinfo=UTC).timestamp()], + "side": ["LONG"], + "dollar_amount": [-1000.0], # Negative amount should fail + } + ) + + with pytest.raises(SchemaError): + portfolio_schema.validate(data) + + +def test_portfolio_schema_unbalanced_sides_fails() -> None: + data = pl.DataFrame( + { + "ticker": [ + "AAPL", + "GOOGL", + "MSFT", + 
"AMZN", + "TSLA", + "NVDA", + "META", + "NFLX", + "BABA", + "CRM", + "AMD", + "INTC", + "ORCL", + "ADBE", + "PYPL", + "SHOP", + "SPOT", + "ROKU", + "ZM", + "DOCU", + ], + "timestamp": [datetime(2025, 1, 1, 0, 0, 0, 0, tzinfo=UTC).timestamp()] + * 20, + "side": ["LONG"] * 15 + ["SHORT"] * 5, # Unbalanced: 15 LONG, 5 SHORT + "dollar_amount": [1000.0] * 20, + } + ) + + with pytest.raises(SchemaError, match="Expected 10 long side positions, found: 15"): + portfolio_schema.validate(data) + + +def test_portfolio_schema_imbalanced_dollar_amounts_fails() -> None: + data = pl.DataFrame( + { + "ticker": [ + "AAPL", + "GOOGL", + "MSFT", + "AMZN", + "TSLA", + "NVDA", + "META", + "NFLX", + "BABA", + "CRM", + "AMD", + "INTC", + "ORCL", + "ADBE", + "PYPL", + "SHOP", + "SPOT", + "ROKU", + "ZM", + "DOCU", + ], + "timestamp": [datetime(2025, 1, 1, 0, 0, 0, 0, tzinfo=UTC).timestamp()] + * 20, + "side": (["LONG"] * 10) + (["SHORT"] * 10), + "dollar_amount": ([2000.0] * 10) + + ([500.0] * 10), # Very imbalanced amounts + } + ) + + with pytest.raises(SchemaError, match="long and short dollar amount sums"): + portfolio_schema.validate(data) + + +def test_portfolio_schema_duplicate_tickers_fails() -> None: + data = pl.DataFrame( + { + "ticker": ["AAPL", "AAPL"], # Duplicate ticker + "timestamp": [datetime(2025, 1, 1, 0, 0, 0, 0, tzinfo=UTC).timestamp()] * 2, + "side": ["LONG", "SHORT"], + "dollar_amount": [1000.0, 1000.0], + } + ) + + with pytest.raises(SchemaError): + portfolio_schema.validate(data) + + +def test_portfolio_schema_zero_timestamp_fails() -> None: + data = pl.DataFrame( + { + "ticker": ["AAPL"], + "timestamp": [0.0], # Zero timestamp should fail + "side": ["LONG"], + "dollar_amount": [1000.0], + } + ) + + with pytest.raises(SchemaError): + portfolio_schema.validate(data) diff --git a/libraries/python/tests/test_prediction.py b/libraries/python/tests/test_prediction.py new file mode 100644 index 000000000..3d7bc6747 --- /dev/null +++ 
b/libraries/python/tests/test_prediction.py @@ -0,0 +1,162 @@ +from datetime import UTC, datetime, timedelta + +import polars as pl +import pytest +from internal.prediction import prediction_schema +from pandera.errors import SchemaError + + +def test_prediction_schema_valid_data() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + valid_data = pl.DataFrame( + { + "ticker": ["AAPL"] * 7, + "timestamp": [ + (base_date + timedelta(days=i)).timestamp() for i in range(7) + ], + "quantile_10": [100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0], + "quantile_50": [150.0, 151.0, 152.0, 153.0, 154.0, 155.0, 156.0], + "quantile_90": [200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0], + } + ) + + validated_df = prediction_schema.validate(valid_data) + assert validated_df.shape == (7, 5) + + +def test_prediction_schema_ticker_lowercase_fails() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + data = pl.DataFrame( + { + "ticker": ["aapl"] * 7, # lowercase should fail + "timestamp": [ + (base_date + timedelta(days=i)).timestamp() for i in range(7) + ], + "quantile_10": [100.0] * 7, + "quantile_50": [150.0] * 7, + "quantile_90": [200.0] * 7, + } + ) + + with pytest.raises(SchemaError): + prediction_schema.validate(data) + + +def test_prediction_schema_negative_timestamp_fails() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + data = pl.DataFrame( + { + "ticker": ["AAPL"] * 7, + "timestamp": [ + (base_date + timedelta(days=i)).timestamp() for i in range(7) + ], + "quantile_10": [100.0] * 7, + "quantile_50": [150.0] * 7, + "quantile_90": [200.0] * 7, + } + ) + + data[1, "timestamp"] = -1.0 # introduce a negative timestamp + + with pytest.raises(SchemaError): + prediction_schema.validate(data) + + +def test_prediction_schema_duplicate_ticker_timestamp_fails() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + data = pl.DataFrame( + { + "ticker": ["AAPL"] * 8, # 8 rows but duplicate timestamp + "timestamp": [(base_date + 
timedelta(days=i)).timestamp() for i in range(7)] + + [(base_date + timedelta(days=0)).timestamp()], # duplicate timestamp + "quantile_10": [100.0] * 8, + "quantile_50": [150.0] * 8, + "quantile_90": [200.0] * 8, + } + ) + + with pytest.raises(SchemaError): + prediction_schema.validate(data) + + +def test_prediction_schema_multiple_tickers_same_dates() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + valid_data = pl.DataFrame( + { + "ticker": (["AAPL"] * 7) + (["GOOGL"] * 7), + "timestamp": [(base_date + timedelta(days=i)).timestamp() for i in range(7)] + * 2, # same timestamps for both tickers + "quantile_10": [100.0] * 14, + "quantile_50": [150.0] * 14, + "quantile_90": [200.0] * 14, + } + ) + + validated_df = prediction_schema.validate(valid_data) + assert validated_df.shape == (14, 5) + + +def test_prediction_schema_multiple_tickers_different_dates_fails() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + data = pl.DataFrame( + { + "ticker": (["AAPL"] * 7) + (["GOOGL"] * 7), + "timestamp": [ + (base_date + timedelta(days=i)).timestamp() for i in range(14) + ], + "quantile_10": [100.0] * 14, + "quantile_50": [150.0] * 14, + "quantile_90": [200.0] * 14, + } + ) + + with pytest.raises( + SchemaError, match="Expected all tickers to have the same dates" + ): + prediction_schema.validate(data) + + +def test_prediction_schema_wrong_date_count_per_ticker_fails() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + data = pl.DataFrame( + { + "ticker": ["AAPL"] * 5, # only 5 dates instead of 7 + "timestamp": [ + (base_date + timedelta(days=i)).timestamp() for i in range(5) + ], + "quantile_10": [100.0] * 5, + "quantile_50": [150.0] * 5, + "quantile_90": [200.0] * 5, + } + ) + + with pytest.raises( + SchemaError, match="Each ticker must have exactly 7 unique dates" + ): + prediction_schema.validate(data) + + +def test_prediction_schema_float_quantile_values() -> None: + base_date = datetime(2024, 1, 1, tzinfo=UTC) + + valid_data = pl.DataFrame( 
+ { + "ticker": ["AAPL"] * 7, + "timestamp": [ + (base_date + timedelta(days=i)).timestamp() for i in range(7) + ], + "quantile_10": [100.5, 101.7, 102.3, 103.8, 104.2, 105.9, 106.1], + "quantile_50": [150.1, 151.4, 152.6, 153.2, 154.8, 155.3, 156.7], + "quantile_90": [200.9, 201.2, 202.5, 203.7, 204.4, 205.6, 206.8], + } + ) + + validated_df = prediction_schema.validate(valid_data) + assert validated_df.shape == (7, 5) diff --git a/libraries/python/tests/test_tft_dataset.py b/libraries/python/tests/test_tft_dataset.py index 4261889f8..6c720d64e 100644 --- a/libraries/python/tests/test_tft_dataset.py +++ b/libraries/python/tests/test_tft_dataset.py @@ -1,3 +1,5 @@ +import math + import polars as pl from internal.tft_dataset import TFTDataset @@ -17,7 +19,7 @@ def test_tft_dataset_load_data() -> None: "high_price": [110.0, 111.0, 112.0, 60.0, 61.0, 62.0], "low_price": [90.0, 91.0, 92.0, 40.0, 41.0, 42.0], "close_price": [105.0, 106.0, 107.0, 55.0, 56.0, 57.0], - "volume": [1000.0, 1100.0, 1200.0, 500.0, 600.0, 700.0], + "volume": [1000, 1100, 1200, 500, 600, 700], "volume_weighted_average_price": [105.0, 106.0, 107.0, 55.0, 56.0, 57.0], "ticker": ["AAPL", "AAPL", "AAPL", "GOOGL", "GOOGL", "GOOGL"], "sector": [ @@ -39,7 +41,9 @@ def test_tft_dataset_load_data() -> None: } ) - dataset = TFTDataset(data=data) + dataset = TFTDataset() + + dataset.preprocess_and_set_data(data) assert hasattr(dataset, "data") assert hasattr(dataset, "mappings") @@ -56,7 +60,7 @@ def test_tft_dataset_get_dimensions() -> None: "high_price": [110.0, 111.0], "low_price": [90.0, 91.0], "close_price": [105.0, 106.0], - "volume": [1000.0, 1100.0], + "volume": [1000, 1100], "volume_weighted_average_price": [105.0, 106.0], "ticker": ["AAPL", "AAPL"], "sector": ["Technology", "Technology"], @@ -64,7 +68,9 @@ def test_tft_dataset_get_dimensions() -> None: } ) - dataset = TFTDataset(data=data) + dataset = TFTDataset() + + dataset.preprocess_and_set_data(data) dimensions = dataset.get_dimensions() @@ 
-83,24 +89,28 @@ def test_tft_dataset_batches() -> None: 1672531200000, # 2023-01-01 1672617600000, # 2023-01-02 1672704000000, # 2023-01-03 + 1672790400000, # 2023-01-04 ], - "open_price": [100.0, 101.0, 102.0], - "high_price": [110.0, 111.0, 112.0], - "low_price": [90.0, 91.0, 92.0], - "close_price": [105.0, 106.0, 107.0], - "volume": [1000.0, 1100.0, 1200.0], - "volume_weighted_average_price": [105.0, 106.0, 107.0], - "ticker": ["AAPL", "AAPL", "AAPL"], - "sector": ["Technology", "Technology", "Technology"], + "open_price": [100.0, 101.0, 102.0, 103.0], + "high_price": [110.0, 111.0, 112.0, 113.0], + "low_price": [90.0, 91.0, 92.0, 93.0], + "close_price": [105.0, 106.0, 107.0, 108.0], + "volume": [1000, 1100, 1200, 1300], + "volume_weighted_average_price": [105.0, 106.0, 107.0, 108.0], + "ticker": ["AAPL", "AAPL", "AAPL", "AAPL"], + "sector": ["Technology", "Technology", "Technology", "Technology"], "industry": [ "Consumer Electronics", "Consumer Electronics", "Consumer Electronics", + "Consumer Electronics", ], } ) - dataset = TFTDataset(data=data) + dataset = TFTDataset() + + dataset.preprocess_and_set_data(data) expected_input_length = 2 expected_output_length = 1 @@ -129,3 +139,106 @@ def test_tft_dataset_batches() -> None: assert encoder_continuous_features.shape[0] == expected_input_length assert decoder_categorical_features.shape[0] == expected_output_length assert static_categorical_features.shape[0] == 1 + + +def test_tft_dataset_daily_return_calculation() -> None: + data = pl.DataFrame( + { + "timestamp": [ + 1672531200000, # 2023-01-01 + 1672617600000, # 2023-01-02 + 1672704000000, # 2023-01-03 + 1672790400000, # 2023-01-04 + ], + "open_price": [100.0, 101.0, 102.0, 103.0], + "high_price": [110.0, 111.0, 112.0, 113.0], + "low_price": [90.0, 91.0, 92.0, 93.0], + "close_price": [100.0, 105.0, 102.0, 108.0], # returns: +5%, -2.86%, +5.88% + "volume": [1000, 1100, 1200, 1300], + "volume_weighted_average_price": [100.0, 105.0, 102.0, 108.0], + "ticker": 
["AAPL", "AAPL", "AAPL", "AAPL"], + "sector": ["Technology", "Technology", "Technology", "Technology"], + "industry": [ + "Consumer Electronics", + "Consumer Electronics", + "Consumer Electronics", + "Consumer Electronics", + ], + } + ) + + dataset = TFTDataset() + + dataset.preprocess_and_set_data(data) + + assert "daily_return" in dataset.data.columns + + expected_rows_after_filter = 3 + assert len(dataset.data) == expected_rows_after_filter + + assert "daily_return" in dataset.continuous_columns + + daily_returns = dataset.data["daily_return"].to_list() + + expected_daily_return_count = 3 + assert len(daily_returns) == expected_daily_return_count + assert all(isinstance(val, float) and math.isfinite(val) for val in daily_returns) + + +def test_tft_dataset_daily_return_targets() -> None: + data = pl.DataFrame( + { + "timestamp": [ + 1672531200000, # 2023-01-01 + 1672617600000, # 2023-01-02 + 1672704000000, # 2023-01-03 + 1672790400000, # 2023-01-04 + 1672876800000, # 2023-01-05 + 1672963200000, # 2023-01-06 + ], + "open_price": [100.0, 101.0, 102.0, 103.0, 104.0, 105.0], + "high_price": [110.0, 111.0, 112.0, 113.0, 114.0, 115.0], + "low_price": [90.0, 91.0, 92.0, 93.0, 94.0, 95.0], + "close_price": [100.0, 105.0, 102.0, 108.0, 106.0, 110.0], + "volume": [1000, 1100, 1200, 1300, 1400, 1500], + "volume_weighted_average_price": [100.0, 105.0, 102.0, 108.0, 106.0, 110.0], + "ticker": ["AAPL", "AAPL", "AAPL", "AAPL", "AAPL", "AAPL"], + "sector": [ + "Technology", + "Technology", + "Technology", + "Technology", + "Technology", + "Technology", + ], + "industry": [ + "Consumer Electronics", + "Consumer Electronics", + "Consumer Electronics", + "Consumer Electronics", + "Consumer Electronics", + "Consumer Electronics", + ], + } + ) + + dataset = TFTDataset() + + dataset.preprocess_and_set_data(data) + + batches = dataset.get_batches( + data_type="train", + input_length=2, + output_length=1, + ) + + assert len(batches) > 0 + + for batch in batches: + if "targets" in batch: 
+ targets = batch["targets"] + assert targets.shape[1] == 1 + expected_tensor_dimensions = 2 + assert len(targets.shape) == expected_tensor_dimensions + assert targets.shape[0] > 0 + assert targets.shape[1] == 1 diff --git a/maskfile.md b/maskfile.md index 68e48d482..b1fab2865 100644 --- a/maskfile.md +++ b/maskfile.md @@ -5,23 +5,23 @@ ```bash set -euo pipefail -echo "๐Ÿš€ Setting up PocketSizeFund development environment" +echo "Setting up PocketSizeFund development environment" echo "==================================================" # Check prerequisites -echo "๐Ÿ” Checking prerequisites..." +echo "Checking prerequisites..." missing_deps=() -if ! command -v docker >/dev/null 2>&1; then +if ! command -v docker >/dev/null 2>&1; then missing_deps+=("Docker") fi -if ! command -v pulumi >/dev/null 2>&1; then +if ! command -v pulumi >/dev/null 2>&1; then missing_deps+=("Pulumi CLI") fi if [[ ${#missing_deps[@]} -gt 0 ]]; then - echo "โŒ Missing prerequisites: ${missing_deps[*]}" + echo " Missing prerequisites: ${missing_deps[*]}" echo "" echo "Please install the following:" for dep in "${missing_deps[@]}"; do @@ -38,40 +38,40 @@ if [[ ${#missing_deps[@]} -gt 0 ]]; then fi # Check Docker login -echo "๐Ÿณ Checking DockerHub authentication..." -if ! docker info >/dev/null 2>&1; then - echo "โŒ Docker daemon not running" +echo " Checking DockerHub authentication..." +if ! docker info >/dev/null 2>&1; then + echo " Docker daemon not running" exit 1 fi if ! docker system info --format '{{.Username}}' 2>/dev/null | grep -q .; then - echo "โš ๏ธ Not logged into DockerHub. Run 'docker login' first" + echo "Not logged into DockerHub. Run 'docker login' first" exit 1 fi # Set up environment variable -echo "๐ŸŒ Setting up environment..." +echo " Setting up environment..." 
if [[ -z "${ACME_EMAIL:-}" ]]; then - echo "โš ๏ธ ACME_EMAIL environment variable not set" + echo "ACME_EMAIL environment variable not set" echo " Add this to your shell profile: export ACME_EMAIL=\"your-email@example.com\"" read -p "Enter your email for ACME certificates: " email if [[ -n "$email" ]]; then export ACME_EMAIL="$email" - echo "โœ… ACME_EMAIL set to: $ACME_EMAIL" + echo " ACME_EMAIL set to: $ACME_EMAIL" echo " Add 'export ACME_EMAIL=\"$ACME_EMAIL\"' to your shell profile" fi else - echo "โœ… ACME_EMAIL: $ACME_EMAIL" + echo " ACME_EMAIL: $ACME_EMAIL" fi -echo "โœ… Prerequisites check completed" +echo " Prerequisites check completed" echo "" -echo "๐Ÿ“‹ Next steps:" +echo " Next steps:" echo " mask secrets create # Set up Docker secrets" echo " mask infrastructure base up # Deploy infrastructure" echo " mask development python install # Install Python dependencies" echo "" -echo "๐Ÿ“– Service URLs (after deployment):" +echo "Service URLs (after deployment):" echo " - DataManager: http://[manager-ip]:8080" echo " - PortfolioManager: http://[manager-ip]:8081" echo " - Grafana: https://grafana.example.com" @@ -83,22 +83,22 @@ echo " - Portainer: https://[manager-ip]:9443" ```bash set -euo pipefail -echo "๐Ÿ”„ Running full development workflow" +echo " Running full development workflow" echo "===================================" -echo "1๏ธโƒฃ Installing dependencies..." +echo "Installing dependencies..." mask development python install echo "" -echo "2๏ธโƒฃ Running quality checks..." +echo "Running quality checks..." mask development quality echo "" -echo "3๏ธโƒฃ Running tests..." +echo "Running tests..." mask development python test echo "" -echo "โœ… Development workflow completed successfully!" +echo " Development workflow completed successfully!" ``` ## ci @@ -106,22 +106,22 @@ echo "โœ… Development workflow completed successfully!" 
```bash set -euo pipefail -echo "๐Ÿค– Running CI workflow" +echo " Running CI workflow" echo "=====================" -echo "1๏ธโƒฃ Quality checks..." +echo "Quality checks..." mask development quality echo "" -echo "2๏ธโƒฃ Testing..." +echo "Testing..." mask development python test echo "" -echo "3๏ธโƒฃ Building applications..." +echo "Building applications..." mask infrastructure applications build echo "" -echo "โœ… CI workflow completed successfully!" +echo " CI workflow completed successfully!" ``` ## infrastructure @@ -133,31 +133,31 @@ echo "โœ… CI workflow completed successfully!" ```bash set -euo pipefail -echo "๐Ÿš€ Starting infrastructure deployment..." +echo " Starting infrastructure deployment..." cd infrastructure # Deploy infrastructure with Pulumi -echo "๐Ÿ“ก Deploying infrastructure with Pulumi..." +echo " Deploying infrastructure with Pulumi..." if ! pulumi up --yes; then - echo "โŒ Pulumi deployment failed" + echo " Pulumi deployment failed" exit 1 fi -echo "๐Ÿ“‹ Getting infrastructure outputs..." +echo " Getting infrastructure outputs..." MANAGER_IP=$(pulumi stack output managerIp | tr -d '\r\n') if [[ -z "$MANAGER_IP" ]]; then - echo "โŒ Failed to get manager IP from Pulumi" + echo " Failed to get manager IP from Pulumi" exit 1 fi -echo "๐Ÿ”‘ Setting up SSH configuration..." +echo " Setting up SSH configuration..." pulumi stack output --show-secrets sshPrivateKeyPem | tr -d '\r' > swarm.pem chmod 600 swarm.pem # Verify SSH key format -if ! ssh-keygen -l -f swarm.pem >/dev/null 2>&1; then - echo "โŒ Invalid SSH key format" +if ! ssh-keygen -l -f swarm.pem >/dev/null 2>&1; then + echo " Invalid SSH key format" exit 1 fi @@ -184,106 +184,106 @@ Host pocketsizefund-production ServerAliveCountMax 3 EOF -echo "๐Ÿ” Updating SSH known hosts..." +echo " Updating SSH known hosts..." 
# Remove old host key and add new one -ssh-keygen -R "$MANAGER_IP" >/dev/null 2>&1 || true -ssh-keyscan -H "$MANAGER_IP" >> "$HOME/.ssh/known_hosts" 2>/dev/null +ssh-keygen -R "$MANAGER_IP" >/dev/null 2>&1 || true +ssh-keyscan -H "$MANAGER_IP" 2>/dev/null >> "$HOME/.ssh/known_hosts" # Test SSH connection and Docker -echo "๐Ÿงช Testing SSH connection..." +echo " Testing SSH connection..." MAX_RETRIES=5 RETRY_COUNT=0 while [[ $RETRY_COUNT -lt $MAX_RETRIES ]]; do - if ssh -o ConnectTimeout=10 pocketsizefund-production 'docker info -f "{{.ServerVersion}} {{.Swarm.LocalNodeState}}"' 2>/dev/null; then - echo "โœ… SSH connection successful" + if ssh -o ConnectTimeout=10 pocketsizefund-production 'docker info -f "{{.ServerVersion}} {{.Swarm.LocalNodeState}}"' 2>/dev/null; then + echo " SSH connection successful" break else ((RETRY_COUNT++)) - echo "โณ SSH connection attempt $RETRY_COUNT/$MAX_RETRIES failed, retrying in 5 seconds..." + echo " SSH connection attempt $RETRY_COUNT/$MAX_RETRIES failed, retrying in 5 seconds..." sleep 5 fi done if [[ $RETRY_COUNT -eq $MAX_RETRIES ]]; then - echo "โŒ Failed to establish SSH connection after $MAX_RETRIES attempts" + echo " Failed to establish SSH connection after $MAX_RETRIES attempts" exit 1 fi -echo "๐Ÿณ Setting up Docker contexts..." +echo " Setting up Docker contexts..." # Remove existing contexts safely for context in pocketsizefund-production pocketsizefund-local; do if docker context ls --format '{{.Name}}' | grep -q "^${context}$"; then - docker context use default >/dev/null 2>&1 || true - docker context rm -f "$context" >/dev/null 2>&1 || true + docker context use default >/dev/null 2>&1 || true + docker context rm -f "$context" >/dev/null 2>&1 || true fi done # Create Docker contexts if ! docker context create pocketsizefund-production --docker "host=ssh://pocketsizefund-production"; then - echo "โŒ Failed to create production Docker context" + echo " Failed to create production Docker context" exit 1 fi if ! 
docker context create pocketsizefund-local --docker "host=unix:///var/run/docker.sock"; then - echo "โŒ Failed to create local Docker context" + echo " Failed to create local Docker context" exit 1 fi # Initialize local swarm if needed -echo "๐Ÿ”„ Ensuring local Docker swarm is initialized..." +echo " Ensuring local Docker swarm is initialized..." docker context use pocketsizefund-local if ! docker info --format '{{.Swarm.LocalNodeState}}' | grep -q active; then docker swarm init --advertise-addr 127.0.0.1 >/dev/null 2>&1 || true fi # Deploy infrastructure stack to production -echo "๐Ÿš€ Deploying infrastructure stack to production..." +echo " Deploying infrastructure stack to production..." docker context use pocketsizefund-production if ! docker stack deploy -c stack.yml infrastructure --with-registry-auth; then - echo "โŒ Failed to deploy infrastructure stack to production" + echo " Failed to deploy infrastructure stack to production" exit 1 fi # Deploy infrastructure stack to local -echo "๐Ÿ  Deploying infrastructure stack to local..." +echo " Deploying infrastructure stack to local..." docker context use pocketsizefund-local if ! docker stack deploy -c stack.yml infrastructure --with-registry-auth; then - echo "โŒ Failed to deploy infrastructure stack to local" + echo " Failed to deploy infrastructure stack to local" exit 1 fi # Deploy application services -echo "๐Ÿ“ฑ Deploying application services..." +echo " Deploying application services..." cd ../applications # Deploy to production -echo "๐ŸŒ Deploying applications to production..." +echo " Deploying applications to production..." docker context use pocketsizefund-production if ! docker stack deploy -c stack.yml applications --with-registry-auth; then - echo "โŒ Failed to deploy applications to production" + echo " Failed to deploy applications to production" exit 1 fi # Deploy to local -echo "๐Ÿ  Deploying applications to local..." +echo " Deploying applications to local..." 
docker context use pocketsizefund-local if ! docker stack deploy -c stack.yml applications --with-registry-auth; then - echo "โŒ Failed to deploy applications to local" + echo " Failed to deploy applications to local" exit 1 fi # Show cluster status -echo "๐Ÿ“Š Cluster status:" +echo " Cluster status:" docker context use pocketsizefund-production echo "Production cluster:" -docker node ls 2>/dev/null || echo " Unable to list production nodes" +docker node ls >/dev/null || echo " Unable to list production nodes" docker context use pocketsizefund-local echo "Local cluster:" -docker node ls 2>/dev/null || echo " Unable to list local nodes" +docker node ls >/dev/null || echo " Unable to list local nodes" -echo "โœ… Infrastructure deployment completed successfully!" +echo " Infrastructure deployment completed successfully!" echo "" echo "Next steps:" echo " mask test all # Test complete deployment" @@ -298,35 +298,35 @@ echo " mask docker context production # Switch to production context" ```bash set -euo pipefail -echo "๐Ÿ›‘ Taking down infrastructure..." +echo " Taking down infrastructure..." cd infrastructure -echo "๐Ÿ“ฑ Removing application stacks..." +echo " Removing application stacks..." for context in pocketsizefund-production pocketsizefund-local; do echo " Removing from $context..." if docker context ls --format '{{.Name}}' | grep -q "^${context}$"; then docker context use "$context" - docker stack rm applications 2>/dev/null || echo " applications not found in $context" - docker stack rm infrastructure 2>/dev/null || echo " infrastructure stack not found in $context" + docker stack rm applications >/dev/null || echo " applications not found in $context" + docker stack rm infrastructure >/dev/null || echo " infrastructure stack not found in $context" else echo " Context $context not found" fi done # Wait for services to be removed -echo "โณ Waiting for services to stop..." +echo " Waiting for services to stop..." 
sleep 10 # Destroy Pulumi infrastructure -echo "โ˜๏ธ Destroying cloud infrastructure..." +echo "๏ธ Destroying cloud infrastructure..." if ! pulumi destroy --yes; then - echo "โŒ Pulumi destroy failed" + echo " Pulumi destroy failed" exit 1 fi # Clean up SSH config -echo "๐Ÿงน Cleaning up SSH configuration..." +echo " Cleaning up SSH configuration..." SSH_CONFIG="$HOME/.ssh/config" if [[ -f "$SSH_CONFIG" ]]; then sed -i.bak '/^Host pocketsizefund-production$/,/^Host /{ /^Host pocketsizefund-production$/d; /^Host /!d; }' "$SSH_CONFIG" || true @@ -337,15 +337,15 @@ fi rm -f swarm.pem # Remove Docker contexts -echo "๐Ÿณ Removing Docker contexts..." -docker context use default >/dev/null 2>&1 || true +echo " Removing Docker contexts..." +docker context use default >/dev/null >&1 || true for context in pocketsizefund-production pocketsizefund-local; do - docker context rm -f "$context" >/dev/null 2>&1 || true + docker context rm -f "$context" >/dev/null >&1 || true done -echo "โœ… Infrastructure taken down successfully!" +echo " Infrastructure taken down successfully!" echo "" -echo "๐Ÿ“ Summary: All cloud resources destroyed, contexts removed, SSH config cleaned" +echo " Summary: All cloud resources destroyed, contexts removed, SSH config cleaned" ``` ### applications @@ -355,12 +355,12 @@ echo "๐Ÿ“ Summary: All cloud resources destroyed, contexts removed, SSH config ```bash set -euo pipefail -echo "๐Ÿ—๏ธ Building and pushing application images..." +echo "๏ธ Building and pushing application images..." for app_dir in applications/*/; do app_name=$(basename "$app_dir") if [[ -f "$app_dir/Dockerfile" ]]; then - echo "๐Ÿ“ฆ Building $app_name..." + echo " Building $app_name..." cd "$app_dir" # Get version from uv @@ -374,13 +374,13 @@ for app_dir in applications/*/; do docker push "pocketsizefund/$app_name:$version" cd .. 
- echo "โœ… $app_name built and pushed" + echo " $app_name built and pushed" else - echo "โš ๏ธ Skipping $app_name (no Dockerfile found)" + echo "๏ธ Skipping $app_name (no Dockerfile found)" fi done -echo "๐ŸŽ‰ All application images built and pushed successfully!" +echo " All application images built and pushed successfully!" ``` #### deploy @@ -388,30 +388,30 @@ echo "๐ŸŽ‰ All application images built and pushed successfully!" ```bash set -euo pipefail -echo "๐Ÿš€ Deploying applications..." +echo " Deploying applications..." # Deploy to production context if docker context ls --format '{{.Name}}' | grep -q "^pocketsizefund-production$"; then - echo "๐ŸŒ Deploying to production..." + echo " Deploying to production..." docker context use pocketsizefund-production docker stack deploy -c applications/stack.yml applications --with-registry-auth else - echo "โš ๏ธ Production context not found, skipping production deployment" + echo "๏ธ Production context not found, skipping production deployment" fi # Deploy to local context if docker context ls --format '{{.Name}}' | grep -q "^pocketsizefund-local$"; then - echo "๐Ÿ  Deploying to local..." + echo " Deploying to local..." docker context use pocketsizefund-local docker stack deploy -c applications/stack.yml applications --with-registry-auth else - echo "โš ๏ธ Local context not found, skipping local deployment" + echo "๏ธ Local context not found, skipping local deployment" fi # Reset to default context -docker context use default >/dev/null 2>&1 || true +docker context use default >/dev/null >&1 || true -echo "โœ… Application deployment completed!" +echo " Application deployment completed!" 
echo "" echo "Next steps:" echo " mask test endpoints # Test service endpoints" @@ -426,25 +426,25 @@ echo " mask docker services ls # Check service status" ```bash set -euo pipefail -echo "๐Ÿงช Testing Application Endpoints" +echo " Testing Application Endpoints" echo "=================================" # Function to test HTTP endpoints test_endpoint() { - local name="$1" - local url="$2" - local context="$3" + local name="1$" + local url="2$" + local context="3$" printf "%-25s %-15s " "$name" "[$context]" if timeout 10 curl -sf "$url" >/dev/null 2>&1; then - echo "โœ… OK" + echo " OK" return 0 elif timeout 10 curl -s "$url" >/dev/null 2>&1; then - echo "๐Ÿ”ถ RESPONDING (non-200)" + echo " RESPONDING (non-200)" return 1 else - echo "โŒ FAILED" + echo " FAILED" return 1 fi } @@ -452,12 +452,12 @@ test_endpoint() { # Get production manager IP cd infrastructure MANAGER_IP="" -if pulumi stack --show-name 2>/dev/null; then +if pulumi stack --show-name >/dev/null 2>&1; then MANAGER_IP=$(pulumi stack output managerIp 2>/dev/null || echo "") fi cd .. 
-echo "๐Ÿ  Local endpoints:" +echo " Local endpoints:" test_endpoint "DataManager" "http://localhost:8080/health" "local" test_endpoint "DataManager (root)" "http://localhost:8080/" "local" test_endpoint "PortfolioManager" "http://localhost:8081/health" "local" @@ -466,18 +466,18 @@ test_endpoint "PortfolioManager (root)" "http://localhost:8081/" "local" # Production tests (if manager IP available) if [[ -n "$MANAGER_IP" ]]; then echo "" - echo "๐ŸŒ Production endpoints (IP: $MANAGER_IP):" + echo " Production endpoints (IP: $MANAGER_IP):" test_endpoint "DataManager" "http://$MANAGER_IP:8080/health" "production" test_endpoint "DataManager (root)" "http://$MANAGER_IP:8080/" "production" - test_endpoint "PortfolioManager" "http://$MANAGER_IP:8081/health" "production" - test_endpoint "PortfolioManager (root)" "http://$MANAGER_IP:8081/" "production" + test_endpoint "PortfolioManager" "http://$MANAGER_IP:808/health" "production" + test_endpoint "PortfolioManager (root)" "http://$MANAGER_IP:808/" "production" else echo "" - echo "โš ๏ธ Production manager IP not available - skipping production tests" + echo "๏ธ Production manager IP not available - skipping production tests" fi -docker context use default >/dev/null 2>&1 || true -echo "โœ… Endpoint testing completed" +docker context use default >/dev/null >&1 || true +echo " Endpoint testing completed" ``` ### health @@ -487,13 +487,13 @@ set -euo pipefail test_service_health() { local context="$1" - echo "๐Ÿ” Docker Services in $context:" + echo " Docker Services in $context:" if docker context ls --format '{{.Name}}' | grep -q "^${context}$"; then - docker context use "$context" >/dev/null 2>&1 + docker context use "$context" >/dev/null >&1 - if docker service ls --format "table {{.Name}}\t{{.Replicas}}\t{{.Image}}" 2>/dev/null; then - echo "โœ… Services listed successfully" + if docker service ls --format "table {{.Name}}\t{{.Replicas}}\t{{.Image}}" >/dev/null; then + echo " Services listed successfully" else echo 
" No services found or connection error" fi @@ -503,12 +503,12 @@ test_service_health() { echo "" } -echo "๐Ÿ“Š Service Health Check" +echo " Service Health Check" echo "======================" test_service_health "pocketsizefund-local" test_service_health "pocketsizefund-production" -docker context use default >/dev/null 2>&1 || true +docker context use default >/dev/null >&1 || true ``` ### all @@ -516,14 +516,14 @@ docker context use default >/dev/null 2>&1 || true ```bash set -euo pipefail -echo "๐Ÿง  Running complete test suite..." +echo " Running complete test suite..." echo "====================================" mask test endpoints echo "" mask test health -echo "๐ŸŽ‰ Complete test suite finished" +echo " Complete test suite finished" ``` ## docker @@ -535,7 +535,7 @@ echo "๐ŸŽ‰ Complete test suite finished" > Switch to local Docker swarm context ```bash docker context use pocketsizefund-local -echo "โœ… Switched to local context" +echo " Switched to local context" docker context ls | grep "\*" ``` @@ -543,7 +543,7 @@ docker context ls | grep "\*" > Switch to production Docker swarm context ```bash docker context use pocketsizefund-production -echo "โœ… Switched to production context" +echo " Switched to production context" docker context ls | grep "\*" ``` @@ -551,7 +551,7 @@ docker context ls | grep "\*" > Switch back to default Docker context ```bash docker context use default -echo "โœ… Switched to default context" +echo " Switched to default context" ``` ### services @@ -559,10 +559,10 @@ echo "โœ… Switched to default context" #### ls > List all Docker services with health status ```bash -echo "๐Ÿ“Š Docker Services Status:" +echo " Docker Services Status:" docker service ls echo "" -echo "๐Ÿ“‹ Service Health Details:" +echo " Service Health Details:" docker service ls --format "table {{.Name}}\t{{.Replicas}}\t{{.Image}}\t{{.Ports}}" | \ while IFS=$'\t' read -r name replicas image ports; do if [[ "$name" != "NAME" ]]; then @@ -574,18 +574,18 @@ docker 
service ls --format "table {{.Name}}\t{{.Replicas}}\t{{.Image}}\t{{.Ports #### logs > View logs for a specific service (interactive selection) ```bash -echo "๐Ÿ“‹ Available services:" +echo " Available services:" services=($(docker service ls --format '{{.Name}}')) if [[ ${#services[@]} -eq 0 ]]; then - echo "โŒ No services found" + echo " No services found" exit 1 fi echo "Select a service:" select service in "${services[@]}"; do if [[ -n "$service" ]]; then - echo "๐Ÿ“œ Showing logs for $service (press Ctrl+C to exit):" + echo " Showing logs for $service (press Ctrl+C to exit):" docker service logs -f "$service" break else @@ -597,18 +597,18 @@ done #### inspect > Inspect service configuration and status ```bash -echo "๐Ÿ“‹ Available services:" +echo " Available services:" services=($(docker service ls --format '{{.Name}}')) if [[ ${#services[@]} -eq 0 ]]; then - echo "โŒ No services found" + echo " No services found" exit 1 fi echo "Select a service to inspect:" select service in "${services[@]}"; do if [[ -n "$service" ]]; then - echo "๐Ÿ” Inspecting $service:" + echo " Inspecting $service:" echo "--- Service Tasks ---" docker service ps "$service" echo "" @@ -633,22 +633,22 @@ docker stack ls #### ps > Show tasks for infrastructure and application stacks ```bash -echo "๐Ÿ—๏ธ Infrastructure Stack:" -docker stack ps infrastructure 2>/dev/null || echo "Infrastructure stack not deployed" +echo "๏ธ Infrastructure Stack:" +docker stack ps infrastructure >/dev/null || echo "Infrastructure stack not deployed" echo "" -echo "๐Ÿ“ฑ Applications Stack:" -docker stack ps applications 2>/dev/null || echo "Applications stack not deployed" +echo " Applications Stack:" +docker stack ps applications >/dev/null || echo "Applications stack not deployed" ``` #### rm > Remove infrastructure and application stacks ```bash -echo "๐Ÿ›‘ Removing Docker stacks..." 
-docker stack rm applications 2>/dev/null && echo "โœ… Applications stack removed" || echo "โš ๏ธ Applications stack not found" -docker stack rm infrastructure 2>/dev/null && echo "โœ… Infrastructure stack removed" || echo "โš ๏ธ Infrastructure stack not found" -echo "๐Ÿงน Waiting for cleanup..." +echo " Removing Docker stacks..." +docker stack rm applications >/dev/null && echo "โœ… Applications stack removed" || echo "โš ๏ธ Applications stack not found" +docker stack rm infrastructure >/dev/null && echo "โœ… Infrastructure stack removed" || echo "โš ๏ธ Infrastructure stack not found" +echo " Waiting for cleanup..." sleep 5 -echo "โœ… Stack removal completed" +echo " Stack removal completed" ``` ## status @@ -656,19 +656,19 @@ echo "โœ… Stack removal completed" ```bash set -euo pipefail -echo "๐Ÿ“Š PocketSizeFund System Status" +echo " PocketSizeFund System Status" echo "===============================" echo "" # Docker contexts -echo "๐Ÿณ Docker Contexts:" +echo " Docker Contexts:" docker context ls echo "" # Pulumi stack info -echo "โ˜๏ธ Infrastructure Status:" +echo "๏ธ Infrastructure Status:" cd infrastructure -if pulumi stack --show-name 2>/dev/null; then +if pulumi stack --show-name >/dev/null; then echo "Current stack: $(pulumi stack --show-name)" if MANAGER_IP=$(pulumi stack output managerIp 2>/dev/null); then echo "Manager IP: $MANAGER_IP" @@ -683,20 +683,20 @@ echo "" # Docker stacks echo "๐Ÿ“š Docker Stacks:" -if docker context use pocketsizefund-local >/dev/null 2>&1; then +if docker context use pocketsizefund-local >/dev/null >&1; then echo "Local stacks:" - docker stack ls 2>/dev/null || echo " No stacks deployed locally" + docker stack ls >/dev/null || echo " No stacks deployed locally" fi echo "" -if docker context use pocketsizefund-production >/dev/null 2>&1; then +if docker context use pocketsizefund-production >/dev/null >&1; then echo "Production stacks:" - docker stack ls 2>/dev/null || echo " No stacks deployed in production" + 
docker stack ls >/dev/null || echo " No stacks deployed in production" fi # Reset context -docker context use default >/dev/null 2>&1 || true +docker context use default >/dev/null >&1 || true echo "" -echo "โœ… Status check completed" +echo " Status check completed" ``` ## secrets @@ -705,11 +705,11 @@ echo "โœ… Status check completed" ### list > List all Docker secrets (requires active swarm context) ```bash -echo "๐Ÿ” Docker Secrets:" +echo " Docker Secrets:" if docker info --format '{{.Swarm.LocalNodeState}}' 2>/dev/null | grep -q active; then docker secret ls else - echo "โŒ Not connected to Docker swarm - switch context first:" + echo " Not connected to Docker swarm - switch context first:" echo " mask docker context local" echo " mask docker context production" fi @@ -721,13 +721,13 @@ fi set -euo pipefail if ! docker info --format '{{.Swarm.LocalNodeState}}' 2>/dev/null | grep -q active; then - echo "โŒ Not connected to Docker swarm - switch context first:" + echo " Not connected to Docker swarm - switch context first:" echo " mask docker context local" echo " mask docker context production" exit 1 fi -echo "๐Ÿ” Creating Docker Secrets" +echo " Creating Docker Secrets" echo "=========================" echo "Leave blank to skip a secret" echo "" @@ -753,18 +753,18 @@ for secret_info in "${secrets[@]}"; do read -r secret_value if [[ -n "$secret_value" ]]; then - if echo "$secret_value" | docker secret create "$secret_name" - 2>/dev/null; then - echo "โœ… Created $secret_name" + if echo "$secret_value" | docker secret create "$secret_name" - >/dev/null; then + echo " Created $secret_name" else - echo "โš ๏ธ $secret_name already exists or creation failed" + echo "๏ธ $secret_name already exists or creation failed" fi else - echo "โญ๏ธ Skipped $secret_name" + echo "๏ธ Skipped $secret_name" fi echo "" done -echo "๐Ÿ” Secret creation completed" +echo " Secret creation completed" ``` ## development @@ -777,11 +777,11 @@ echo "๐Ÿ” Secret creation 
completed" ```bash set -euo pipefail -echo "๐Ÿ“ฆ Installing Python dependencies..." +echo " Installing Python dependencies..." export COMPOSE_BAKE=true uv sync --all-packages --all-groups -echo "โœ… Python dependencies installed successfully" +echo " Python dependencies installed successfully" echo "" echo "Next steps:" echo " mask development python format # Format code" @@ -794,10 +794,10 @@ echo " mask development python test # Run tests" ```bash set -euo pipefail -echo "๐ŸŽจ Formatting Python code..." +echo " Formatting Python code..." ruff format -echo "โœ… Python code formatted successfully" +echo " Python code formatted successfully" ``` #### dead-code @@ -805,16 +805,16 @@ echo "โœ… Python code formatted successfully" ```bash set -euo pipefail -echo "๐Ÿงน Checking for dead Python code..." +echo " Checking for dead Python code..." mask development python format -echo "๐Ÿ” Running vulture dead code analysis..." +echo " Running vulture dead code analysis..." uvx vulture \ --min-confidence 80 \ --exclude '.flox,.venv,target' \ . -echo "โœ… Dead code check completed" +echo " Dead code check completed" ``` #### lint @@ -822,12 +822,12 @@ echo "โœ… Dead code check completed" ```bash set -euo pipefail -echo "๐Ÿ” Running Python code quality checks..." +echo " Running Python code quality checks..." # Run dead code check first (which includes formatting) mask development python dead-code -echo "๐Ÿ“‹ Running ruff linting..." +echo " Running ruff linting..." ruff check \ --output-format=github \ . @@ -835,7 +835,7 @@ ruff check \ # Note: ty check commented out in original # uvx ty check -echo "โœ… Python linting completed successfully" +echo " Python linting completed successfully" ``` #### test @@ -843,29 +843,29 @@ echo "โœ… Python linting completed successfully" ```bash set -euo pipefail -echo "๐Ÿงช Running Python tests..." +echo " Running Python tests..." 
# Create coverage output directory mkdir -p coverage_output # Clean up any existing test containers -echo "๐Ÿงน Cleaning up previous test runs..." +echo " Cleaning up previous test runs..." docker compose --file tests.yaml down --volumes --remove-orphans # Build test containers -echo "๐Ÿ—๏ธ Building test containers..." +echo "๏ธ Building test containers..." docker compose --file tests.yaml build tests # Run tests -echo "๐Ÿš€ Running tests with coverage..." +echo " Running tests with coverage..." docker compose --file tests.yaml run --rm --no-TTY tests # Clean up after tests -echo "๐Ÿงน Cleaning up test containers..." +echo " Cleaning up test containers..." docker compose --file tests.yaml down --volumes --remove-orphans -echo "โœ… Python tests completed successfully" -echo "๐Ÿ“Š Coverage report available in coverage_output/.python_coverage.xml" +echo " Python tests completed successfully" +echo " Coverage report available in coverage_output/.python_coverage.xml" ``` ### quality @@ -873,19 +873,19 @@ echo "๐Ÿ“Š Coverage report available in coverage_output/.python_coverage.xml" ```bash set -euo pipefail -echo "๐Ÿ” Running comprehensive code quality checks..." +echo " Running comprehensive code quality checks..." # Run Python quality checks mask development python lint # Run additional linting tools -echo "๐Ÿ“‹ Running additional linters..." +echo " Running additional linters..." nu linter.nu yamllint -d "{extends: relaxed, rules: {line-length: {max: 110}}}" . -echo "โœ… All quality checks completed successfully" +echo " All quality checks completed successfully" echo "" -echo "๐ŸŽ‰ Code is ready for review!" +echo " Code is ready for review!" ``` ## logs @@ -894,7 +894,7 @@ echo "๐ŸŽ‰ Code is ready for review!" 
### infrastructure > View logs for infrastructure services (Grafana, Prometheus, Traefik) ```bash -echo "๐Ÿ“‹ Infrastructure Service Logs" +echo " Infrastructure Service Logs" echo "Select environment:" select env in "local" "production"; do case $env in @@ -910,7 +910,7 @@ select env in "local" "production"; do echo "Select service:" select service in "${services[@]}"; do if [[ -n "$service" ]]; then - echo "๐Ÿ“œ Logs for $service (press Ctrl+C to exit):" + echo " Logs for $service (press Ctrl+C to exit):" docker service logs -f "$service" break fi @@ -922,13 +922,13 @@ select env in "local" "production"; do ;; esac done -docker context use default >/dev/null 2>&1 || true +docker context use default >/dev/null >&1 || true ``` ### applications > View logs for application services (DataManager, PortfolioManager) ```bash -echo "๐Ÿ“‹ Application Service Logs" +echo " Application Service Logs" echo "Select environment:" select env in "local" "production"; do case $env in @@ -944,7 +944,7 @@ select env in "local" "production"; do echo "Select service:" select service in "${services[@]}"; do if [[ -n "$service" ]]; then - echo "๐Ÿ“œ Logs for $service (press Ctrl+C to exit):" + echo " Logs for $service (press Ctrl+C to exit):" docker service logs -f "$service" break fi @@ -956,5 +956,5 @@ select env in "local" "production"; do ;; esac done -docker context use default >/dev/null 2>&1 || true +docker context use default >/dev/null >&1 || true ``` diff --git a/uv.lock b/uv.lock index 39a43c3c1..63b93ea81 100644 --- a/uv.lock +++ b/uv.lock @@ -117,7 +117,7 @@ wheels = [ [[package]] name = "alpaca-py" -version = "0.42.0" +version = "0.42.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "msgpack" }, @@ -127,9 +127,9 @@ dependencies = [ { name = "sseclient-py" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2d/bf/3c2712ec8b9c4a36d5a7a2b58217512e4e26b3ffa02db554b33f9c9a5ba2/alpaca_py-0.42.0.tar.gz", hash = 
"sha256:3ac4fd3439b8701d678db38bbc797b12bc003190996b19b0aa5f0d22aea65be7", size = 97534, upload-time = "2025-07-04T15:31:54.727Z" } +sdist = { url = "https://files.pythonhosted.org/packages/23/23/f33b24af8dfd3149f490685dda1ee64f88d15b75fbf9057ac5c83e58268c/alpaca_py-0.42.1.tar.gz", hash = "sha256:a44bf45d40d34fbcba1125a3144ddaea62b44806cfb25e3d948029bb5a233a37", size = 97671, upload-time = "2025-08-29T18:40:52.432Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e8/6f/669d3aba1d010be97169bfa25b55eba561545335ae97f9efce86954b20da/alpaca_py-0.42.0-py3-none-any.whl", hash = "sha256:e7e874fc9090c07a7b1698603e34f2e48be4f27beab47d7fc3d38d6f6855513a", size = 121974, upload-time = "2025-07-04T15:31:53.607Z" }, + { url = "https://files.pythonhosted.org/packages/ef/b5/efdf8d3d206632b0dfaad07193dd396413235f9d9a96c1843d83594851d4/alpaca_py-0.42.1-py3-none-any.whl", hash = "sha256:234fea37151d4d5995de6d4f2fff24d584120214eb9f39198c399d1000004efa", size = 122157, upload-time = "2025-08-29T18:40:50.9Z" }, ] [[package]] @@ -473,7 +473,7 @@ wheels = [ [[package]] name = "datamanager" version = "0.1.0" -source = { editable = "applications/datamanager" } +source = { virtual = "applications/datamanager" } dependencies = [ { name = "duckdb" }, { name = "fastapi" }, @@ -1292,7 +1292,7 @@ wheels = [ [[package]] name = "pandas" -version = "2.3.1" +version = "2.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, @@ -1300,15 +1300,15 @@ dependencies = [ { name = "pytz" }, { name = "tzdata" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/6f/75aa71f8a14267117adeeed5d21b204770189c0a0025acbdc03c337b28fc/pandas-2.3.1.tar.gz", hash = "sha256:0a95b9ac964fe83ce317827f80304d37388ea77616b1425f0ae41c9d2d0d7bb2", size = 4487493, upload-time = "2025-07-07T19:20:04.079Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/8e/0e90233ac205ad182bd6b422532695d2b9414944a280488105d598c70023/pandas-2.3.2.tar.gz", hash = 
"sha256:ab7b58f8f82706890924ccdfb5f48002b83d2b5a3845976a9fb705d36c34dcdb", size = 4488684, upload-time = "2025-08-21T10:28:29.257Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/de/b8445e0f5d217a99fe0eeb2f4988070908979bec3587c0633e5428ab596c/pandas-2.3.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:689968e841136f9e542020698ee1c4fbe9caa2ed2213ae2388dc7b81721510d3", size = 11588172, upload-time = "2025-07-07T19:18:52.054Z" }, - { url = "https://files.pythonhosted.org/packages/1e/e0/801cdb3564e65a5ac041ab99ea6f1d802a6c325bb6e58c79c06a3f1cd010/pandas-2.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:025e92411c16cbe5bb2a4abc99732a6b132f439b8aab23a59fa593eb00704232", size = 10717365, upload-time = "2025-07-07T19:18:54.785Z" }, - { url = "https://files.pythonhosted.org/packages/51/a5/c76a8311833c24ae61a376dbf360eb1b1c9247a5d9c1e8b356563b31b80c/pandas-2.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b7ff55f31c4fcb3e316e8f7fa194566b286d6ac430afec0d461163312c5841e", size = 11280411, upload-time = "2025-07-07T19:18:57.045Z" }, - { url = "https://files.pythonhosted.org/packages/da/01/e383018feba0a1ead6cf5fe8728e5d767fee02f06a3d800e82c489e5daaf/pandas-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7dcb79bf373a47d2a40cf7232928eb7540155abbc460925c2c96d2d30b006eb4", size = 11988013, upload-time = "2025-07-07T19:18:59.771Z" }, - { url = "https://files.pythonhosted.org/packages/5b/14/cec7760d7c9507f11c97d64f29022e12a6cc4fc03ac694535e89f88ad2ec/pandas-2.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:56a342b231e8862c96bdb6ab97170e203ce511f4d0429589c8ede1ee8ece48b8", size = 12767210, upload-time = "2025-07-07T19:19:02.944Z" }, - { url = "https://files.pythonhosted.org/packages/50/b9/6e2d2c6728ed29fb3d4d4d302504fb66f1a543e37eb2e43f352a86365cdf/pandas-2.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:ca7ed14832bce68baef331f4d7f294411bed8efd032f8109d690df45e00c4679", size = 13440571, upload-time = "2025-07-07T19:19:06.82Z" }, - { url = "https://files.pythonhosted.org/packages/80/a5/3a92893e7399a691bad7664d977cb5e7c81cf666c81f89ea76ba2bff483d/pandas-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:ac942bfd0aca577bef61f2bc8da8147c4ef6879965ef883d8e8d5d2dc3e744b8", size = 10987601, upload-time = "2025-07-07T19:19:09.589Z" }, + { url = "https://files.pythonhosted.org/packages/ec/db/614c20fb7a85a14828edd23f1c02db58a30abf3ce76f38806155d160313c/pandas-2.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fbb977f802156e7a3f829e9d1d5398f6192375a3e2d1a9ee0803e35fe70a2b9", size = 11587652, upload-time = "2025-08-21T10:27:15.888Z" }, + { url = "https://files.pythonhosted.org/packages/99/b0/756e52f6582cade5e746f19bad0517ff27ba9c73404607c0306585c201b3/pandas-2.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b9b52693123dd234b7c985c68b709b0b009f4521000d0525f2b95c22f15944b", size = 10717686, upload-time = "2025-08-21T10:27:18.486Z" }, + { url = "https://files.pythonhosted.org/packages/37/4c/dd5ccc1e357abfeee8353123282de17997f90ff67855f86154e5a13b81e5/pandas-2.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bd281310d4f412733f319a5bc552f86d62cddc5f51d2e392c8787335c994175", size = 11278722, upload-time = "2025-08-21T10:27:21.149Z" }, + { url = "https://files.pythonhosted.org/packages/d3/a4/f7edcfa47e0a88cda0be8b068a5bae710bf264f867edfdf7b71584ace362/pandas-2.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96d31a6b4354e3b9b8a2c848af75d31da390657e3ac6f30c05c82068b9ed79b9", size = 11987803, upload-time = "2025-08-21T10:27:23.767Z" }, + { url = "https://files.pythonhosted.org/packages/f6/61/1bce4129f93ab66f1c68b7ed1c12bac6a70b1b56c5dab359c6bbcd480b52/pandas-2.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:df4df0b9d02bb873a106971bb85d448378ef14b86ba96f035f50bbd3688456b4", size = 
12766345, upload-time = "2025-08-21T10:27:26.6Z" }, + { url = "https://files.pythonhosted.org/packages/8e/46/80d53de70fee835531da3a1dae827a1e76e77a43ad22a8cd0f8142b61587/pandas-2.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:213a5adf93d020b74327cb2c1b842884dbdd37f895f42dcc2f09d451d949f811", size = 13439314, upload-time = "2025-08-21T10:27:29.213Z" }, + { url = "https://files.pythonhosted.org/packages/28/30/8114832daff7489f179971dbc1d854109b7f4365a546e3ea75b6516cea95/pandas-2.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:8c13b81a9347eb8c7548f53fd9a4f08d4dfe996836543f805c987bafa03317ae", size = 10983326, upload-time = "2025-08-21T10:27:31.901Z" }, ] [[package]] @@ -1328,6 +1328,10 @@ wheels = [ ] [package.optional-dependencies] +pandas = [ + { name = "numpy" }, + { name = "pandas" }, +] polars = [ { name = "polars" }, ] @@ -1430,15 +1434,21 @@ name = "portfoliomanager" version = "0.1.0" source = { editable = "applications/portfoliomanager" } dependencies = [ + { name = "alpaca-py" }, { name = "fastapi" }, { name = "httpx" }, + { name = "internal" }, + { name = "pandera", extra = ["pandas", "polars"] }, { name = "uvicorn" }, ] [package.metadata] requires-dist = [ + { name = "alpaca-py", specifier = ">=0.42.1" }, { name = "fastapi", specifier = ">=0.116.1" }, { name = "httpx", specifier = ">=0.27.0" }, + { name = "internal", editable = "libraries/python" }, + { name = "pandera", extras = ["polars", "pandas"], specifier = ">=0.26.0" }, { name = "uvicorn", specifier = ">=0.35.0" }, ]