Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions maskfile.md
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,14 @@ case "$application_name" in

if [ "$data_type" = "equity-bars" ]; then
if [ -n "${date_range:-}" ]; then
uv run python tools/sync_equity_bars_data.py "$base_url" "$date_range"
uv run python -m tools.sync_equity_bars_data "$base_url" "$date_range"
else
current_date=$(date -u +"%Y-%m-%d")
date_range_json="{\"start_date\": \"$current_date\", \"end_date\": \"$current_date\"}"
uv run python tools/sync_equity_bars_data.py "$base_url" "$date_range_json"
uv run python -m tools.sync_equity_bars_data "$base_url" "$date_range_json"
fi
elif [ "$data_type" = "equity-details" ]; then
uv run python tools/sync_equity_details_data.py "$base_url"
uv run python -m tools.sync_equity_details_data "$base_url"
Comment thread
forstmeier marked this conversation as resolved.
fi
;;

Expand Down Expand Up @@ -591,7 +591,7 @@ echo "Running dead code analysis"
uvx vulture \
--min-confidence 80 \
--exclude '.flox,.venv,target' \
. tools/vulture_whitelist.py
. tools/src/tools/vulture_whitelist.py

echo "Dead code check completed"
```
Expand Down Expand Up @@ -778,7 +778,7 @@ export LOOKBACK_DAYS="${LOOKBACK_DAYS:-365}"

cd ../

uv run python tools/prepare_training_data.py
uv run python -m tools.prepare_training_data
```

### train (application_name) [instance_preset]
Expand Down Expand Up @@ -856,7 +856,7 @@ export AWS_S3_EQUITY_PRICE_MODEL_TRAINING_DATA_PATH="s3://${AWS_S3_MODEL_ARTIFAC

cd ../

uv run python tools/run_training_job.py
uv run python -m tools.run_training_job
```

### artifacts
Expand All @@ -872,7 +872,7 @@ set -euo pipefail

export APPLICATION_NAME="${application_name}"

uv run python tools/download_model_artifacts.py
uv run python -m tools.download_model_artifacts
```

## mcp
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ dependencies = [
"fastapi>=0.121.0",
"uvicorn>=0.35.0",
"structlog>=25.5.0",
"sagemaker>=2.256.0",
"sagemaker>=2.256.0,<3",
"numpy>=1.26.4",
"tinygrad>=0.10.3",
"requests>=2.32.5",
"mypy-boto3-s3>=1.42.37",
"polars",
Comment thread
forstmeier marked this conversation as resolved.
]

[tool.uv.sources]
Expand All @@ -31,7 +32,7 @@ members = [
dev = ["coverage>=7.8.0", "pytest>=8.3.5", "behave>=1.2.6"]

[tool.pytest.ini_options]
testpaths = ["applications/*/tests", "libraries/python/tests"]
testpaths = ["applications/*/tests", "libraries/python/tests", "tools/tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
Expand All @@ -49,7 +50,7 @@ filterwarnings = [

[tool.coverage.run]
parallel = true
omit = ["*/__init__.py", "**/test_*.py"]
omit = ["*/__init__.py", "**/test_*.py", "tools/**"]

[tool.coverage.report]
show_missing = true
Expand Down
6 changes: 5 additions & 1 deletion tools/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ name = "tools"
version = "0.1.0"
description = "Project tools and scripts"
requires-python = "==3.12.10"
dependencies = ["boto3>=1.40.74", "massive>=2.0.2"]
dependencies = ["boto3>=1.40.74"]

[tool.uv]
package = true
src = ["src"]
Comment thread
forstmeier marked this conversation as resolved.
Empty file added tools/src/tools/__init__.py
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def download_model_artifacts( # noqa: C901, PLR0915

file_object_name_parts = file_object_name.split("/")

if len(file_object_name_parts) < 2: # noqa: PLR2004
if len(file_object_name_parts) < 3: # noqa: PLR2004
logger.warning("Skipping malformed path", path=file_object_name)
continue

options.add(file_object_name_parts[1])
options.add(file_object_name_parts[2])
file_objects_with_timestamps.append(
{
"name": file_object_name_parts[1],
"name": file_object_name_parts[2],
Comment thread
forstmeier marked this conversation as resolved.
"last_modified": file_object["LastModified"],
}
)
Expand Down Expand Up @@ -88,7 +88,7 @@ def download_model_artifacts( # noqa: C901, PLR0915

logger.info("Selected artifact", selected_option=selected_option)

target_path = f"artifacts/{selected_option}/output/model.tar.gz"
target_path = f"artifacts/{application_name}/{selected_option}/output/model.tar.gz"
destination_directory = f"applications/{application_name}/src/{application_name}/"
destination_path = os.path.join(destination_directory, "model.tar.gz") # noqa: PTH118

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
"""Prepare consolidated training data from equity bars and categories.

This script:
1. Reads equity bars from S3 (partitioned parquet)
2. Reads categories CSV from S3
3. Joins them on ticker
4. Filters by minimum price/volume thresholds
5. Outputs consolidated parquet to S3 for SageMaker training
"""

import io
import os
import sys
Expand Down
File renamed without changes.
File renamed without changes.
48 changes: 48 additions & 0 deletions tools/tests/test_download_model_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch

from tools.download_model_artifacts import download_model_artifacts


def test_download_model_artifacts_github_actions_selects_latest() -> None:
mock_s3_client = MagicMock()
mock_s3_client.list_objects_v2.return_value = {
"Contents": [
{
"Key": "artifacts/equitypricemodel/run_20250601/output/model.tar.gz",
"LastModified": datetime(2025, 6, 1, tzinfo=UTC),
}
]
}

mock_tar = MagicMock()

with (
patch(
"tools.download_model_artifacts.boto3.client",
return_value=mock_s3_client,
),
patch("tools.download_model_artifacts.os.makedirs"),
patch("tools.download_model_artifacts.tarfile.open") as mock_tarfile_open,
):
mock_tarfile_open.return_value.__enter__.return_value = mock_tar

download_model_artifacts(
application_name="equitypricemodel",
artifacts_bucket="test-artifacts-bucket",
github_actions_check=True,
)

mock_s3_client.list_objects_v2.assert_called_once_with(
Bucket="test-artifacts-bucket",
Prefix="artifacts/equitypricemodel",
Comment thread
forstmeier marked this conversation as resolved.
)
mock_s3_client.download_file.assert_called_once_with(
Bucket="test-artifacts-bucket",
Key="artifacts/equitypricemodel/run_20250601/output/model.tar.gz",
Filename="applications/equitypricemodel/src/equitypricemodel/model.tar.gz",
)
mock_tar.extractall.assert_called_once_with(
path="applications/equitypricemodel/src/equitypricemodel/",
filter="data",
)
180 changes: 180 additions & 0 deletions tools/tests/test_prepare_training_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import io
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch

import polars as pl
from tools.prepare_training_data import (
MINIMUM_CLOSE_PRICE,
MINIMUM_VOLUME,
consolidate_data,
filter_equity_bars,
prepare_training_data,
read_categories_from_s3,
read_equity_bars_from_s3,
write_training_data_to_s3,
)

_TARGET_DATE = datetime(2025, 6, 1, tzinfo=UTC)

_SAMPLE_EQUITY_BARS = pl.DataFrame(
{
"ticker": ["AAPL"],
"timestamp": [_TARGET_DATE],
"open_price": [148.0],
"high_price": [152.0],
"low_price": [147.0],
"close_price": [150.0],
"volume": [1_000_000],
"volume_weighted_average_price": [151.0],
}
)

_SAMPLE_CATEGORIES = pl.DataFrame(
{
"ticker": ["AAPL"],
"sector": ["Technology"],
"industry": ["Consumer Electronics"],
}
)


def _to_parquet_bytes(data: pl.DataFrame) -> bytes:
buffer = io.BytesIO()
data.write_parquet(buffer)
return buffer.getvalue()


def _to_csv_bytes(data: pl.DataFrame) -> bytes:
return data.write_csv().encode()


def test_filter_equity_bars_keeps_rows_above_thresholds() -> None:
data = pl.DataFrame(
{
"close_price": [MINIMUM_CLOSE_PRICE + 1.0, 0.5],
"volume": [MINIMUM_VOLUME + 1, 50_000],
}
)

result = filter_equity_bars(data)

assert len(result) == 1
assert result["close_price"][0] == MINIMUM_CLOSE_PRICE + 1.0


def test_filter_equity_bars_empty_input_returns_empty() -> None:
data = pl.DataFrame({"close_price": [], "volume": []}).cast(
{"close_price": pl.Float64, "volume": pl.Int64}
)

result = filter_equity_bars(data)

assert len(result) == 0


def test_consolidate_data_joins_on_ticker_and_retains_columns() -> None:
result = consolidate_data(_SAMPLE_EQUITY_BARS, _SAMPLE_CATEGORIES)

assert len(result) == 1
assert result["ticker"][0] == "AAPL"
assert "sector" in result.columns
assert "industry" in result.columns


def test_consolidate_data_excludes_unmatched_tickers() -> None:
categories = pl.DataFrame(
{
"ticker": ["MSFT"],
"sector": ["Technology"],
"industry": ["Software"],
}
)

result = consolidate_data(_SAMPLE_EQUITY_BARS, categories)

assert len(result) == 0


def test_read_equity_bars_from_s3_returns_dataframe() -> None:
parquet_bytes = _to_parquet_bytes(_SAMPLE_EQUITY_BARS)

mock_body = MagicMock()
mock_body.read.return_value = parquet_bytes
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = {"Body": mock_body}

result = read_equity_bars_from_s3(
s3_client=mock_s3_client,
bucket_name="test-bucket",
start_date=_TARGET_DATE,
end_date=_TARGET_DATE,
)

assert len(result) == 1
assert result["ticker"][0] == "AAPL"
mock_s3_client.get_object.assert_called_once()
Comment thread
forstmeier marked this conversation as resolved.


def test_read_categories_from_s3_returns_dataframe() -> None:
csv_bytes = _to_csv_bytes(_SAMPLE_CATEGORIES)

mock_body = MagicMock()
mock_body.read.return_value = csv_bytes
mock_s3_client = MagicMock()
mock_s3_client.get_object.return_value = {"Body": mock_body}

result = read_categories_from_s3(
s3_client=mock_s3_client,
bucket_name="test-bucket",
)

assert len(result) == 1
assert result["ticker"][0] == "AAPL"
mock_s3_client.get_object.assert_called_once_with(
Bucket="test-bucket",
Key="equity/details/categories.csv",
)


def test_write_training_data_to_s3_returns_s3_uri() -> None:
mock_s3_client = MagicMock()

result = write_training_data_to_s3(
s3_client=mock_s3_client,
bucket_name="test-bucket",
data=_SAMPLE_EQUITY_BARS,
output_key="training/data.parquet",
)

assert result == "s3://test-bucket/training/data.parquet"
mock_s3_client.put_object.assert_called_once()
call_kwargs = mock_s3_client.put_object.call_args.kwargs
assert call_kwargs["Bucket"] == "test-bucket"
assert call_kwargs["Key"] == "training/data.parquet"


def test_prepare_training_data_returns_s3_uri() -> None:
parquet_bytes = _to_parquet_bytes(_SAMPLE_EQUITY_BARS)
csv_bytes = _to_csv_bytes(_SAMPLE_CATEGORIES)

mock_body_bars = MagicMock()
mock_body_bars.read.return_value = parquet_bytes
mock_body_categories = MagicMock()
mock_body_categories.read.return_value = csv_bytes

mock_s3_client = MagicMock()
mock_s3_client.get_object.side_effect = [
{"Body": mock_body_bars},
{"Body": mock_body_categories},
]

with patch("tools.prepare_training_data.boto3.client", return_value=mock_s3_client):
result = prepare_training_data(
data_bucket_name="test-data-bucket",
model_artifacts_bucket_name="test-artifacts-bucket",
start_date=_TARGET_DATE,
end_date=_TARGET_DATE,
)

assert result.startswith("s3://test-artifacts-bucket/")
mock_s3_client.put_object.assert_called_once()
Loading