diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 664a98426..e86f2e7f8 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -18,5 +18,6 @@ "Bash(gh issue close:*)" ], "deny": [] - } -} + }, + "enableAllProjectMcpServers": false +} \ No newline at end of file diff --git a/.flox/env/manifest.toml b/.flox/env/manifest.toml index 993d21e53..82807107a 100644 --- a/.flox/env/manifest.toml +++ b/.flox/env/manifest.toml @@ -12,14 +12,6 @@ nushell.pkg-path = "nushell" fd.pkg-path = "fd" fselect.pkg-path = "fselect" -[vars] -ALPACA_API_KEY = "${ALPACA_API_KEY}" -ALPACA_API_SECRET = "${ALPACA_API_SECRET}" -POLYGON_API_KEY = "${POLYGON_API_KEY}" -DATA_BUCKET = "${DATA_BUCKET}" -DUCKDB_ACCESS_KEY = "${DUCKDB_ACCESS_KEY}" -DUCKDB_SECRET = "${DUCKDB_SECRET}" - [options] systems = [ "aarch64-darwin", diff --git a/application/datamanager/Dockerfile b/application/datamanager/Dockerfile index ca2175599..2497c2132 100644 --- a/application/datamanager/Dockerfile +++ b/application/datamanager/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.12 +FROM python:3.12.10 COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv diff --git a/application/datamanager/pyproject.toml b/application/datamanager/pyproject.toml index f145daa09..68f3515e8 100644 --- a/application/datamanager/pyproject.toml +++ b/application/datamanager/pyproject.toml @@ -2,7 +2,7 @@ name = "datamanager" version = "0.1.0" description = "Data management service" -requires-python = "==3.13" +requires-python = "==3.12.10" dependencies = [ "fastapi>=0.115.12", "uvicorn>=0.34.2", diff --git a/application/datamanager/src/datamanager/main.py b/application/datamanager/src/datamanager/main.py index 9de041380..11e76ddc3 100644 --- a/application/datamanager/src/datamanager/main.py +++ b/application/datamanager/src/datamanager/main.py @@ -137,8 +137,8 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars polygon = request.app.state.settings.polygon bucket = 
request.app.state.settings.gcp.bucket - summary_date: str = summary_date.date.strftime("%Y-%m-%d") - url = f"{polygon.base_url}{polygon.daily_bars}{summary_date}" + request_summary_date: str = summary_date.date.strftime("%Y-%m-%d") + url = f"{polygon.base_url}{polygon.daily_bars}{request_summary_date}" logger.info(f"polygon_api_endpoint={url}") params = {"adjusted": "true", "apiKey": polygon.api_key} @@ -178,7 +178,7 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to write data", ) from e - return BarsSummary(date=summary_date, count=count) + return BarsSummary(date=request_summary_date, count=count) @application.delete("/equity-bars") diff --git a/application/datamanager/tests/test_datamanager_main.py b/application/datamanager/tests/test_datamanager_main.py new file mode 100644 index 000000000..268d2e3b3 --- /dev/null +++ b/application/datamanager/tests/test_datamanager_main.py @@ -0,0 +1,93 @@ +import unittest +from datetime import date +from unittest.mock import MagicMock, patch + +from fastapi import status +from fastapi.testclient import TestClient + +from application.datamanager.src.datamanager.main import application +from application.datamanager.src.datamanager.models import BarsSummary, SummaryDate + +client = TestClient(application) + + +def test_health_check() -> None: + response = client.get("/health") + assert response.status_code == status.HTTP_200_OK + + +class TestDataManagerModels(unittest.TestCase): + def test_summary_date_default(self) -> None: + summary_date = SummaryDate() + assert isinstance(summary_date.date, date) + + def test_summary_date_with_date(self) -> None: + test_date = date(2023, 1, 1) + summary_date = SummaryDate(date=test_date) + assert summary_date.date == test_date + + def test_summary_date_string_parsing(self) -> None: + summary_date = SummaryDate(date="2023-01-01") # type: ignore + assert summary_date.date == date(2023, 1, 1) 
+ + def test_bars_summary_creation(self) -> None: + bars_summary = BarsSummary(date="2023-01-01", count=100) + assert bars_summary.date == "2023-01-01" + assert bars_summary.count == 100 # noqa: PLR2004 + + +class TestEquityBarsEndpoints(unittest.TestCase): + def test_get_equity_bars_missing_parameters(self) -> None: + response = client.get("/equity-bars") + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_get_equity_bars_invalid_date_format(self) -> None: + response = client.get( + "/equity-bars", + params={"start_date": "invalid-date", "end_date": "2023-01-02"}, + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_post_equity_bars_missing_body(self) -> None: + response = client.post("/equity-bars") + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_post_equity_bars_invalid_date(self) -> None: + response = client.post("/equity-bars", json={"date": "invalid-date"}) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_delete_equity_bars_missing_body(self) -> None: + response = client.request("DELETE", "/equity-bars") + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_delete_equity_bars_invalid_date(self) -> None: + response = client.request( + "DELETE", "/equity-bars", json={"date": "invalid-date"} + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + @patch("application.datamanager.src.datamanager.main.duckdb") + def test_get_equity_bars_database_error(self, mock_duckdb: MagicMock) -> None: + from duckdb import IOException + + mock_connection = MagicMock() + mock_connection.execute.side_effect = IOException("Database error") + mock_duckdb.connect.return_value = mock_connection + + mock_settings = MagicMock() + mock_settings.gcp.bucket.name = "test-bucket" + + with patch.object(application, "state") as mock_app_state: + mock_app_state.connection = mock_connection + 
mock_app_state.settings = mock_settings + + response = client.get( + "/equity-bars", + params={"start_date": "2023-01-01", "end_date": "2023-01-02"}, + ) + + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +if __name__ == "__main__": + unittest.main() diff --git a/application/datamanager/tests/test_datamanager_models.py b/application/datamanager/tests/test_datamanager_models.py new file mode 100644 index 000000000..7ccbbf64e --- /dev/null +++ b/application/datamanager/tests/test_datamanager_models.py @@ -0,0 +1,131 @@ +import unittest +from datetime import date + +import pytest +from pydantic import ValidationError + +from application.datamanager.src.datamanager.models import ( + BarsSummary, + DateRange, + SummaryDate, +) + + +class TestSummaryDate(unittest.TestCase): + def test_summary_date_initialization_default(self) -> None: + summary_date = SummaryDate() + assert isinstance(summary_date.date, date) + + def test_summary_date_initialization_with_date(self) -> None: + test_date = date(2023, 5, 15) + summary_date = SummaryDate(date=test_date) + assert summary_date.date == test_date + + def test_summary_date_string_parsing_iso_format(self) -> None: + summary_date = SummaryDate(date="2023-5-15") # type: ignore + assert summary_date.date == date(2023, 5, 15) + + def test_summary_date_string_parsing_slash_format(self) -> None: + summary_date = SummaryDate(date="2023/05/15") # type: ignore + assert summary_date.date == date(2023, 5, 15) + + def test_summary_date_invalid_format(self) -> None: + with pytest.raises(ValidationError, match="Invalid date format"): + SummaryDate(date="invalid-date") # type: ignore + + def test_summary_date_invalid_date_values(self) -> None: + with pytest.raises(ValidationError): + SummaryDate(date="2023-13-01") # type: ignore + + def test_summary_date_json_encoder(self) -> None: + test_date = date(2023, 5, 15) + summary_date = SummaryDate(date=test_date) + json_data = summary_date.model_dump(mode="json") + assert 
json_data["date"] == "2023/05/15" + + +class TestDateRange(unittest.TestCase): + def test_date_range_valid(self) -> None: + start_date = date(2023, 1, 1) + end_date = date(2023, 12, 31) + date_range = DateRange(start=start_date, end=end_date) + + assert date_range.start == start_date + assert date_range.end == end_date + + def test_date_range_same_dates(self) -> None: + same_date = date(2023, 5, 15) + with pytest.raises(ValidationError, match="End date must be after start date"): + DateRange(start=same_date, end=same_date) + + def test_date_range_end_before_start(self) -> None: + start_date = date(2023, 12, 31) + end_date = date(2023, 1, 1) + with pytest.raises(ValidationError, match="End date must be after start date"): + DateRange(start=start_date, end=end_date) + + def test_date_range_valid_one_day_apart(self) -> None: + start_date = date(2023, 5, 15) + end_date = date(2023, 5, 16) + date_range = DateRange(start=start_date, end=end_date) + + assert date_range.start == start_date + assert date_range.end == end_date + + +class TestBarsSummary(unittest.TestCase): + def test_bars_summary_initialization(self) -> None: + bars_summary = BarsSummary(date="2023-05-15", count=1500) + + assert bars_summary.date == "2023-05-15" + assert bars_summary.count == 1500 # noqa: PLR2004 + + def test_bars_summary_zero_count(self) -> None: + bars_summary = BarsSummary(date="2023-05-15", count=0) + + assert bars_summary.date == "2023-05-15" + assert bars_summary.count == 0 + + def test_bars_summary_negative_count(self) -> None: + bars_summary = BarsSummary(date="2023-05-15", count=-1) + + assert bars_summary.date == "2023-05-15" + assert bars_summary.count == -1 + + def test_bars_summary_json_serialization(self) -> None: + bars_summary = BarsSummary(date="2023-05-15", count=1500) + json_data = bars_summary.model_dump() + + assert json_data == {"date": "2023-05-15", "count": 1500} + + def test_bars_summary_from_dict(self) -> None: + data = {"date": "2023-05-15", "count": 1500} + 
bars_summary = BarsSummary.model_validate(data) + + assert bars_summary.date == "2023-05-15" + assert bars_summary.count == 1500 # noqa: PLR2004 + + +class TestModelIntegration(unittest.TestCase): + def test_summary_date_to_bars_summary(self) -> None: + summary_date = SummaryDate(date="2023-05-15") # type: ignore + bars_summary = BarsSummary( + date=summary_date.date.strftime("%Y-%m-%d"), count=100 + ) + + assert bars_summary.date == "2023-05-15" + assert bars_summary.count == 100 # noqa: PLR2004 + + def test_multiple_model_validation(self) -> None: + summary_date = SummaryDate(date="2023-05-15") # type: ignore + date_range = DateRange(start=date(2023, 1, 1), end=date(2023, 12, 31)) + bars_summary = BarsSummary(date="2023-05-15", count=1000) + + assert summary_date.date == date(2023, 5, 15) + assert date_range.start == date(2023, 1, 1) + assert date_range.end == date(2023, 12, 31) + assert bars_summary.count == 1000 # noqa: PLR2004 + + +if __name__ == "__main__": + unittest.main() diff --git a/application/positionmanager/pyproject.toml b/application/positionmanager/pyproject.toml index 3893b963e..f4ba3d169 100644 --- a/application/positionmanager/pyproject.toml +++ b/application/positionmanager/pyproject.toml @@ -2,7 +2,7 @@ name = "positionmanager" version = "0.1.0" description = "Position management service" -requires-python = "==3.13" +requires-python = "==3.12.10" dependencies = [ "fastapi>=0.115.12", "uvicorn>=0.34.2", @@ -14,6 +14,7 @@ dependencies = [ "pyportfolioopt>=1.5.6", "ecos>=2.0.14", "prometheus-fastapi-instrumentator>=7.1.0", + "pyarrow>=20.0.0", ] [tool.hatch.build.targets.wheel] diff --git a/application/positionmanager/src/positionmanager/clients.py b/application/positionmanager/src/positionmanager/clients.py index 8850d8390..20094faf4 100644 --- a/application/positionmanager/src/positionmanager/clients.py +++ b/application/positionmanager/src/positionmanager/clients.py @@ -1,6 +1,7 @@ from typing import Any import polars as pl +import pyarrow as 
pa import requests from alpaca.trading.client import TradingClient from alpaca.trading.enums import OrderSide, TimeInForce @@ -18,8 +19,8 @@ def __init__( paper: bool = True, ) -> None: if not api_key or not api_secret: - msg = "Alpaca API key and secret are required" - raise ValueError(msg) + message = "Alpaca API key and secret are required" + raise ValueError(message) self.trading_client: TradingClient = TradingClient( api_key, api_secret, paper=paper @@ -30,8 +31,8 @@ def get_cash_balance(self) -> Money: cash_balance = getattr(account, "cash", None) if cash_balance is None: - msg = "Cash balance is not available" - raise ValueError(msg) + message = "Cash balance is not available" + raise ValueError(message) return Money.from_float(float(cash_balance)) @@ -72,32 +73,43 @@ def get_data( date_range: DateRange, ) -> pl.DataFrame: if not self.datamanager_base_url: - msg = "Data manager URL is not configured" - raise ValueError(msg) + message = "Data manager URL is not configured" + raise ValueError(message) endpoint = f"{self.datamanager_base_url}/equity-bars" + params = { + "start_date": date_range.start.date().isoformat(), + "end_date": date_range.end.date().isoformat(), + } + try: - response = requests.post(endpoint, json=date_range.to_payload(), timeout=10) + response = requests.get(endpoint, params=params, timeout=30) except requests.RequestException as err: - msg = f"Data manager service call error: {err}" - raise RuntimeError(msg) from err + message = f"Data manager service call error: {err}" + raise RuntimeError(message) from err - response.raise_for_status() + if response.status_code == requests.codes["no_content"]: + return pl.DataFrame() + if response.status_code != requests.codes["ok"]: + message = f"Data service error: {response.text}, status code: {response.status_code}" # noqa: E501 + raise requests.HTTPError( + message, + response=response, + ) - response_data = response.json() + buffer = pa.py_buffer(response.content) + reader = 
pa.ipc.RecordBatchStreamReader(buffer) + table = reader.read_all() - data = pl.DataFrame(response_data["data"]) + data = pl.DataFrame(pl.from_arrow(table)) data = data.with_columns( - pl.col("timestamp") - .str.slice(0, 10) - .str.strptime(pl.Date, "%Y-%m-%d") - .alias("date"), + pl.col("datetime").cast(pl.Datetime).dt.date().alias("date") ) return ( data.sort("date") - .pivot(on="ticker", index="date", values="close_price") + .pivot(on="T", index="date", values="c") .with_columns(pl.all().exclude("date").cast(pl.Float64)) ) diff --git a/application/positionmanager/src/positionmanager/main.py b/application/positionmanager/src/positionmanager/main.py index 1a7670c08..bc7c683a7 100644 --- a/application/positionmanager/src/positionmanager/main.py +++ b/application/positionmanager/src/positionmanager/main.py @@ -4,7 +4,7 @@ import polars as pl import requests -from alpaca.common.rest import APIError +from alpaca.common.exceptions import APIError from fastapi import FastAPI, HTTPException from prometheus_fastapi_instrumentator import Instrumentator from pydantic import ValidationError diff --git a/application/predictionengine/Dockerfile b/application/predictionengine/Dockerfile index 57325507e..0feabfc02 100644 --- a/application/predictionengine/Dockerfile +++ b/application/predictionengine/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.12 +FROM python:3.12.10 COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv diff --git a/application/predictionengine/pyproject.toml b/application/predictionengine/pyproject.toml index ed10cef5d..fe6b95551 100644 --- a/application/predictionengine/pyproject.toml +++ b/application/predictionengine/pyproject.toml @@ -2,7 +2,7 @@ name = "predictionengine" version = "0.1.0" description = "Prediction engine service" -requires-python = ">=3.12" # possibly 3.10 +requires-python = "==3.12.10" dependencies = [ "fastapi>=0.115.12", "uvicorn>=0.34.2", diff --git a/application/predictionengine/src/predictionengine/dataset.py 
b/application/predictionengine/src/predictionengine/dataset.py index a87fcf4aa..6f44c41fa 100644 --- a/application/predictionengine/src/predictionengine/dataset.py +++ b/application/predictionengine/src/predictionengine/dataset.py @@ -134,8 +134,8 @@ def _scale_data(self, data: pl.DataFrame) -> Tensor: groups.append(combined_group) if not groups: - msg = "No data available after preprocessing" - raise ValueError(msg) + message = "No data available after preprocessing" + raise ValueError(message) output_data = Tensor.empty(groups[0].shape) return output_data.cat(*groups, dim=0) @@ -154,8 +154,8 @@ def load_data(self, data: pl.DataFrame) -> None: def get_preprocessors(self) -> dict[str, Any]: if not self.preprocessors: - msg = "Preprocessors have not been initialized." - raise ValueError(msg) + message = "Preprocessors have not been initialized." + raise ValueError(message) means_by_ticker = { ticker: values["means"] for ticker, values in self.scalers.items() @@ -196,8 +196,8 @@ def batches(self) -> Generator[tuple[Tensor, Tensor, Tensor], None, None]: ] if not batch_tensors: - msg = "Cannot stack empty batch tensors (batch_size must be ≥ 1)" - raise ValueError(msg) + message = "Cannot stack empty batch tensors (batch_size must be ≥ 1)" + raise ValueError(message) if len(batch_tensors) == 1: historical_features = batch_tensors[0].unsqueeze(0) else: diff --git a/application/predictionengine/src/predictionengine/long_short_term_memory.py b/application/predictionengine/src/predictionengine/long_short_term_memory.py index 671d78e7f..c19e6e24f 100644 --- a/application/predictionengine/src/predictionengine/long_short_term_memory.py +++ b/application/predictionengine/src/predictionengine/long_short_term_memory.py @@ -58,8 +58,8 @@ def forward( outputs.append(hidden_state[-1]) if not outputs: - msg = "Cannot stack empty outputs list" - raise ValueError(msg) + message = "Cannot stack empty outputs list" + raise ValueError(message) if len(outputs) == 1: output_tensor = 
outputs[0].unsqueeze(1) diff --git a/application/predictionengine/src/predictionengine/loss_function.py b/application/predictionengine/src/predictionengine/loss_function.py index 0d3603305..fcbf49546 100644 --- a/application/predictionengine/src/predictionengine/loss_function.py +++ b/application/predictionengine/src/predictionengine/loss_function.py @@ -12,12 +12,12 @@ def quantile_loss( quantiles = (0.25, 0.5, 0.75) if y_pred.shape != y_true.shape: - msg = f"Shape mismatch: y_pred {y_pred.shape} vs y_true {y_true.shape}" - raise ValueError(msg) + message = f"Shape mismatch: y_pred {y_pred.shape} vs y_true {y_true.shape}" + raise ValueError(message) if not all(0 <= q <= 1 for q in quantiles): - msg = "All quantiles must be between 0 and 1" - raise ValueError(msg) + message = "All quantiles must be between 0 and 1" + raise ValueError(message) loss: Tensor = Tensor.zeros(1) error = cast("Tensor", y_true - y_pred) diff --git a/application/predictionengine/src/predictionengine/main.py b/application/predictionengine/src/predictionengine/main.py index 06c5a7cd3..9da95c7f8 100644 --- a/application/predictionengine/src/predictionengine/main.py +++ b/application/predictionengine/src/predictionengine/main.py @@ -15,7 +15,7 @@ from .miniature_temporal_fusion_transformer import MiniatureTemporalFusionTransformer from .models import PredictionResponse -LOOKBACK_DAYS = 30 +SEQUENCE_LENGTH = 30 class LoadError(Exception): @@ -49,7 +49,7 @@ def fetch_historical_data( "end_date": end_date.isoformat(), } response = requests.get(url, params=parameters, timeout=30) response.raise_for_status() import pyarrow as pa @@ -64,7 +64,7 @@ def fetch_historical_data( def load_or_initialize_model(data: pl.DataFrame) -> MiniatureTemporalFusionTransformer: dataset = DataSet( batch_size=32, - sequence_length=LOOKBACK_DAYS, + sequence_length=SEQUENCE_LENGTH, sample_count=len(data), ) dataset.load_data(data) @@ -83,14 
+83,14 @@ def load_or_initialize_model(data: pl.DataFrame) -> MiniatureTemporalFusionTrans ticker_encoder=preprocessors["ticker_encoder"], dropout_rate=0.0, ) - model_path = "miniature_temporal_fusion_transformer.safetensor" if Path(model_path).exists(): try: model.load(model_path) logger.info("Loaded existing model weights") - except LoadError as e: - logger.error(f"Failed to load model weights: {e}") + except Exception as e: # noqa: BLE001 + logger.warning(f"Failed to load model weights: {e}") + logger.warning(f"Failed to load model weights: {e}") return model @@ -101,7 +101,7 @@ async def create_predictions( ) -> PredictionResponse: try: end_date = datetime.now(tz=UTC).date() - start_date = end_date - timedelta(days=30) + start_date = end_date - timedelta(days=SEQUENCE_LENGTH) logger.info(f"Fetching data from {start_date} to {end_date}") data = fetch_historical_data( @@ -124,15 +124,15 @@ async def create_predictions( for ticker in unique_tickers: ticker_data = data.filter(pl.col("ticker") == ticker) - if len(ticker_data) < LOOKBACK_DAYS: + if len(ticker_data) < SEQUENCE_LENGTH: logger.warning(f"Insufficient data for ticker {ticker}") continue - recent_data = ticker_data.tail(LOOKBACK_DAYS) + recent_data = ticker_data.tail(SEQUENCE_LENGTH) dataset = DataSet( batch_size=1, - sequence_length=LOOKBACK_DAYS, + sequence_length=SEQUENCE_LENGTH, sample_count=1, ) dataset.load_data(recent_data) diff --git a/application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py b/application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py index 36a3dc58d..b04d40b58 100644 --- a/application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py +++ b/application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py @@ -72,6 +72,10 @@ def __init__( # noqa: PLR0913 self.parameters = get_parameters(self) + def get_parameters(self) -> list[Tensor]: + """Return all trainable 
parameters of the model.""" + return self.parameters + def forward( self, tickers: Tensor, diff --git a/application/predictionengine/src/predictionengine/multi_head_self_attention.py b/application/predictionengine/src/predictionengine/multi_head_self_attention.py index c947def60..108363a04 100644 --- a/application/predictionengine/src/predictionengine/multi_head_self_attention.py +++ b/application/predictionengine/src/predictionengine/multi_head_self_attention.py @@ -12,8 +12,8 @@ def __init__( embedding_size: int, ) -> None: if embedding_size % heads_count != 0: - msg = "Embedding dimension must be divisible by heads count" - raise ValueError(msg) + message = "Embedding dimension must be divisible by heads count" + raise ValueError(message) self.heads_count: int = heads_count self.embedding_size: int = embedding_size diff --git a/application/predictionengine/src/predictionengine/post_processor.py b/application/predictionengine/src/predictionengine/post_processor.py index 7f07c2409..3d832159d 100644 --- a/application/predictionengine/src/predictionengine/post_processor.py +++ b/application/predictionengine/src/predictionengine/post_processor.py @@ -44,8 +44,8 @@ def post_process_predictions( ticker not in self.means_by_ticker or ticker not in self.standard_deviations_by_ticker ): - msg = f"Statistics not found for ticker: {ticker}" - raise ValueError(msg) + message = f"Statistics not found for ticker: {ticker}" + raise ValueError(message) mean = self.means_by_ticker[ticker].numpy() standard_deviation = self.standard_deviations_by_ticker[ticker].numpy() diff --git a/application/predictionengine/tests/test_dataset.py b/application/predictionengine/tests/test_dataset.py index 9d3b804c0..64f8cf20a 100644 --- a/application/predictionengine/tests/test_dataset.py +++ b/application/predictionengine/tests/test_dataset.py @@ -13,18 +13,10 @@ def test_dataset_initialization() -> None: sample_count=3, ) - class Expected(NamedTuple): - batch_size: int = 2 - sequence_length: int 
= 3 - sample_count: int = 3 - observations: int = 2 - - expected = Expected() - - assert dataset.batch_size == expected.batch_size - assert dataset.sequence_length == expected.sequence_length - assert dataset.sample_count == expected.sample_count - assert len(dataset) == expected.observations + assert dataset.batch_size == 2 # noqa: PLR2004 + assert dataset.sequence_length == 3 # noqa: PLR2004 + assert dataset.sample_count == 3 # noqa: PLR2004 + assert len(dataset) == 2 # noqa: PLR2004 def test_dataset_load_data() -> None: diff --git a/application/predictionengine/tests/test_long_short_term_memory.py b/application/predictionengine/tests/test_long_short_term_memory.py index 2e08fbba9..bebabf759 100644 --- a/application/predictionengine/tests/test_long_short_term_memory.py +++ b/application/predictionengine/tests/test_long_short_term_memory.py @@ -48,11 +48,11 @@ def test_lstm_different_sequence_lengths() -> None: input_size=8, hidden_size=16, layer_count=1, dropout_rate=0.0 ) - for seq_len in [5, 10, 20]: - input_tensor = Tensor(rng.standard_normal((2, seq_len, 8))) + for sequence_length in [5, 10, 20]: + input_tensor = Tensor(rng.standard_normal((2, sequence_length, 8))) output, hidden_state = lstm.forward(input_tensor) - assert output.shape == (2, seq_len, 16) + assert output.shape == (2, sequence_length, 16) def test_lstm_multiple_layers() -> None: diff --git a/application/predictionengine/tests/test_multi_head_self_attention.py b/application/predictionengine/tests/test_multi_head_self_attention.py index 692f99b2c..11a257a22 100644 --- a/application/predictionengine/tests/test_multi_head_self_attention.py +++ b/application/predictionengine/tests/test_multi_head_self_attention.py @@ -57,11 +57,11 @@ def test_multi_head_attention_single_sequence() -> None: def test_multi_head_attention_longer_sequences() -> None: attention = MultiHeadSelfAttention(heads_count=4, embedding_size=64) - for seq_len in [10, 20, 50]: - input_tensor = Tensor(rng.standard_normal((1, 
seq_len, 64))) + for sequence_length in [10, 20, 50]: + input_tensor = Tensor(rng.standard_normal((1, sequence_length, 64))) output, _ = attention.forward(input_tensor) - assert output.shape == (1, seq_len, 64) + assert output.shape == (1, sequence_length, 64) def test_multi_head_attention_batch_processing() -> None: diff --git a/infrastructure/monitoring.py b/infrastructure/monitoring.py index 7b0cf12b5..2614b51a3 100644 --- a/infrastructure/monitoring.py +++ b/infrastructure/monitoring.py @@ -1,6 +1,7 @@ +import buckets import project -from pulumi import FileAsset -from pulumi_gcp import cloudrun, secretmanager +from pulumi.config import Config +from pulumi_gcp import cloudrun, secretmanager, storage config = Config() @@ -59,7 +60,7 @@ image="prom/prometheus:v2.51.2", args=[ "--config.file=/etc/prometheus/prometheus.yaml", - f"--storage.tsdb.path=/prometheus", + "--storage.tsdb.path=/prometheus", ], resources=cloudrun.ServiceTemplateSpecContainerResourcesArgs( limits={"cpu": "500m", "memory": "512Mi"} diff --git a/infrastructure/pyproject.toml b/infrastructure/pyproject.toml index 3f854f5ca..675653569 100644 --- a/infrastructure/pyproject.toml +++ b/infrastructure/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "infrastructure" version = "20250602.4" -requires-python = ">=3.13" +requires-python = "==3.12.10" dependencies = [ "pulumi>=3.169.0", "pulumi-gcp>=8.30.1", diff --git a/linter.nu b/linter.nu index 21420f804..d09786d5d 100644 --- a/linter.nu +++ b/linter.nu @@ -1,25 +1,26 @@ #!/usr/bin/env nu use std/assert -fd . 
- | find "Dockerfile" - | path dirname +ls **/* + | where name =~ "Dockerfile$" + | get name | uniq + | path dirname | each {|service| let dockerfile_version = if ($"($service)/Dockerfile" | path exists) { let from_lines = open $"($service)/Dockerfile" | find "FROM python" if ($from_lines | length) > 0 { - $from_lines.0 | str replace --regex '.*python:(\d+\.\d+).*' '$1' + $from_lines.0 | split row ":" | get 1 | str trim } else { error make {msg: $"No 'FROM python' line found in ($service)/Dockerfile"} } } else { - error make {msg: $"Dockerfile not found in ($service)"} + error make {msg: $"missing dockerfile from ($service)"} } let pyproject_version = if ($"($service)/pyproject.toml" | path exists) { let toml_content = open $"($service)/pyproject.toml" if "project" in $toml_content and "requires-python" in $toml_content.project { - $toml_content.project.requires-python + $toml_content.project.requires-python | str trim } else { error make {msg: $"Missing 'project.requires-python' field in ($service)/pyproject.toml"} } @@ -27,10 +28,10 @@ fd . 
error make {msg: $"pyproject.toml not found in ($service)"} } - assert ($pyproject_version starts-with "==") $"pyproject python version must be pinned with \"==\", got: ($pyproject_version)" + assert ($pyproject_version starts-with "==") $"pyproject python version must be pinned with \"==\", got: ($pyproject_version) for service [($service)]" let pyproject_version = $pyproject_version | str replace "==" "" - assert ($dockerfile_version == $pyproject_version) $"dockerfile version [($dockerfile_version)] does not match pyproject.toml version [($pyproject_version)]" + assert equal $dockerfile_version $pyproject_version $"dockerfile version [($dockerfile_version)] does not match pyproject.toml version [($pyproject_version)] in service [($service)]" { service: $service diff --git a/pyproject.toml b/pyproject.toml index 0ee81df90..89530f844 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "pocketsizefund" version = "20250602.4" description = "Open source quantitative hedge fund 🍊" -requires-python = ">=3.12" +requires-python = "==3.12.10" dependencies = [ "flytekit>=1.15.4", "httpx>=0.28.1", diff --git a/workflows/backfill_datamanager.py b/workflows/backfill_datamanager.py index d1bcd1724..4e7a8a3b8 100644 --- a/workflows/backfill_datamanager.py +++ b/workflows/backfill_datamanager.py @@ -13,7 +13,7 @@ def backfill_single_date(base_url: str, day: date) -> int: @workflow def backfill_equity_bars(base_url: str, start_date: date, end_date: date) -> list[int]: - results: list[int] = [] + results = [] current = start_date while current <= end_date: results.append(backfill_single_date(base_url=base_url, day=current)) diff --git a/workflows/prediction_model.py b/workflows/prediction_model.py deleted file mode 100644 index c381e0c05..000000000 --- a/workflows/prediction_model.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import pickle -import statistics -import uuid -from datetime import datetime -from pathlib import Path -from typing import Any - 
-import requests -from flytekit import task, workflow - - -@task -def fetch_data(start_date: datetime, end_date: datetime) -> list[dict[str, Any]]: - base_url = os.getenv("DATAMANAGER_BASE_URL", "http://localhost:8080") - response = requests.get( - f"{base_url}/equity-bars", - params={"start_date": start_date.isoformat(), "end_date": end_date.isoformat()}, - timeout=10, - ) - response.raise_for_status() - return response.json().get("data", []) - - -@task -def train_dummy_model(data: list[dict[str, Any]]) -> bytes: - """Train a trivial model that stores the average close price.""" - close_prices = [row.get("close_price", 0.0) for row in data] - mean_close = statistics.mean(close_prices) if close_prices else 0.0 - model = {"average_close_price": mean_close} - return pickle.dumps(model) - - -@task -def store_model(model_bytes: bytes) -> str: - """Store the serialized model in blob storage.""" - bucket_path = os.getenv("MODEL_BUCKET") - filename = f"model-{uuid.uuid4().hex}.pkl" - path = Path(bucket_path) / filename - path.write_bytes(model_bytes) - return str(path) - - -@workflow -def training_workflow(start_date: datetime, end_date: datetime) -> None: - data = fetch_data(start_date=start_date, end_date=end_date) - model_bytes = train_dummy_model(data=data) - store_model(model_bytes=model_bytes) diff --git a/workflows/pyproject.toml b/workflows/pyproject.toml index 4302983ef..2b0e6010f 100644 --- a/workflows/pyproject.toml +++ b/workflows/pyproject.toml @@ -2,10 +2,11 @@ name = "workflows" description = "Data and model workflows" version = "0.1.0" -requires-python = ">=3.13" +requires-python = "==3.12.10" dependencies = [ "flytekit>=1.10.0", "httpx>=0.28.1", + "loguru>=0.7.3", ] [tool.hatch.build.targets.wheel] diff --git a/workflows/train_predictionengine.py b/workflows/train_predictionengine.py new file mode 100644 index 000000000..ec6e62856 --- /dev/null +++ b/workflows/train_predictionengine.py @@ -0,0 +1,138 @@ +import os +import tempfile +import uuid +from datetime 
import datetime +from pathlib import Path +from typing import Any, cast + +import polars as pl +import pyarrow as pa +import requests +from flytekit import task, workflow +from loguru import logger +from tinygrad.nn.state import get_state_dict, safe_save + +from application.predictionengine.src.predictionengine.dataset import DataSet +from application.predictionengine.src.predictionengine.miniature_temporal_fusion_transformer import ( # noqa: E501 + MiniatureTemporalFusionTransformer, +) + + +@task +def fetch_data( + start_date: datetime, + end_date: datetime, +) -> list[dict[str, Any]]: + base_url = os.getenv("DATAMANAGER_BASE_URL", "http://localhost:8080") + response = requests.get( + f"{base_url}/equity-bars", + params={ + "start_date": start_date.date().isoformat(), + "end_date": end_date.date().isoformat(), + }, + timeout=30, + ) + response.raise_for_status() + + buffer = pa.py_buffer(response.content) + reader = pa.ipc.RecordBatchStreamReader(buffer) + table = reader.read_all() + + data = pl.DataFrame(pl.from_arrow(table)) + + data = data.with_columns( + [ + pl.col("t").cast(pl.Datetime).alias("timestamp"), + pl.col("o").alias("open_price"), + pl.col("h").alias("high_price"), + pl.col("l").alias("low_price"), + pl.col("c").alias("close_price"), + pl.col("v").alias("volume"), + pl.col("vw").alias("volume_weighted_average_price"), + pl.col("T").alias("ticker"), + ] + ).select( + [ + "timestamp", + "open_price", + "high_price", + "low_price", + "close_price", + "volume", + "volume_weighted_average_price", + "ticker", + ] + ) + + return data.to_dicts() + + +@task +def train_model( + data: list[dict[str, Any]], + epochs: int = 100, +) -> bytes: + if not data: + msg = "No data provided for training" + raise ValueError(msg) + + training_data = pl.DataFrame(data) + + dataset = DataSet( + batch_size=32, + sequence_length=30, + sample_count=len(training_data), + ) + dataset.load_data(training_data) + preprocessors = dataset.get_preprocessors() + + model = 
MiniatureTemporalFusionTransformer( + input_size=6, + hidden_size=128, + output_size=3, + layer_count=2, + ticker_count=len(training_data["ticker"].unique()), + embedding_size=32, + attention_head_count=4, + means_by_ticker=preprocessors["means_by_ticker"], + standard_deviations_by_ticker=preprocessors["standard_deviations_by_ticker"], + ticker_encoder=preprocessors["ticker_encoder"], + dropout_rate=0.1, + ) + + losses = model.train(dataset, epochs, learning_rate=0.001) + + for epoch, loss in enumerate(losses): + if epoch % 10 == 0: + logger.info(f"Epoch {epoch}, Loss: {loss}") + + with tempfile.NamedTemporaryFile( + suffix=".safetensor", + delete=False, + ) as temporary_file: + safe_save(get_state_dict(model), temporary_file.name) + temporary_file.seek(0) + model_bytes = temporary_file.read() + + return model_bytes # noqa: RET504 + + +@task +def store_model(model_bytes: bytes) -> str: + bucket_path = os.getenv("MODEL_BUCKET", "/tmp") # noqa: S108 + filename = f"miniature_temporal_fusion_transformer-{uuid.uuid4().hex}.safetensor" + path = Path(bucket_path) / filename + path.write_bytes(model_bytes) + return str(path) + + +@workflow +def training_workflow( + start_date: datetime, + end_date: datetime, + epochs: int = 100, +) -> str: + data = fetch_data(start_date=start_date, end_date=end_date) + model_bytes = train_model(data=cast("list[dict[str, Any]]", data), epochs=epochs) + artifact_path = store_model(model_bytes=cast("bytes", model_bytes)) + return cast("str", artifact_path)