diff --git a/.claude/skills/autotrain/SKILL.md b/.claude/skills/autotrain/SKILL.md index 6adb6adf2..cfb8a3a47 100644 --- a/.claude/skills/autotrain/SKILL.md +++ b/.claude/skills/autotrain/SKILL.md @@ -112,7 +112,7 @@ For code changes (model architecture, loss function, layer structure), run **one experiment at a time**: 1. **Hypothesize** a single architecture change -2. **Modify** the model code (tide_model.py or tide_data.py) +2. **Modify** the model code (model.py or data.py) 3. **Commit**: `git add -A && git commit -m "autoresearch: "` 4. **Run training**: ```bash @@ -153,16 +153,16 @@ Priority order: architecture > hyperparameters > epoch depth > loss function > d ## Files Modified During Loop Primarily: -- `models/tide/src/tide/tide_model.py` (model architecture) -- `models/tide/src/tide/tide_data.py` (data processing) +- `models/tide/src/tide/model.py` (model architecture) +- `models/tide/src/tide/data.py` (data processing) - `models/tide/src/tide/trainer.py` (DEFAULT_CONFIGURATION, training hyperparameters) ## Key Files - **Trainer**: `models/tide/src/tide/trainer.py` (DEFAULT_CONFIGURATION lives here) - **Workflow**: `models/tide/src/tide/workflow.py` -- **Model**: `models/tide/src/tide/tide_model.py` -- **Data**: `models/tide/src/tide/tide_data.py` +- **Model**: `models/tide/src/tide/model.py` +- **Data**: `models/tide/src/tide/data.py` - **Experiment log**: `autotrain//experiments.jsonl` - **Prefect config**: `prefect.yaml` - **devenv tasks**: `devenv.nix` diff --git a/.github/workflows/launch_infrastructure.yaml b/.github/workflows/launch_infrastructure.yaml index 44a03762d..4e9a07058 100644 --- a/.github/workflows/launch_infrastructure.yaml +++ b/.github/workflows/launch_infrastructure.yaml @@ -58,16 +58,11 @@ jobs: - name: Install Flox if: steps.changes.outputs.service == 'true' uses: flox/install-flox-action@v2 - - name: Build ${{ matrix.application }}-${{ matrix.stage }} image + - name: Build and push ${{ matrix.application }}-${{ matrix.stage }} image if: steps.changes.outputs.service == 'true' uses: flox/activate-action@v1 with: - command: mask infrastructure image build ${{ matrix.application }} ${{ matrix.stage }} - - name: Push ${{ matrix.application }}-${{ matrix.stage }} image - if: steps.changes.outputs.service == 'true' - uses: flox/activate-action@v1 - with: - command: mask infrastructure image push ${{ matrix.application }} ${{ matrix.stage }} + command: mask infrastructure image build-and-push ${{ matrix.application }} ${{ matrix.stage }} launch_infrastructure: name: Deploy with Pulumi needs: build_and_push_images @@ -142,4 +137,4 @@ jobs: env: PULUMI_ACCESS_TOKEN: ${{ secrets.PULUMI_ACCESS_TOKEN }} with: - command: mask infrastructure image deploy ${{ matrix.application }} ${{ matrix.stage }} + command: mask infrastructure service update ${{ matrix.application }} diff --git a/applications/data_manager/src/storage.rs b/applications/data_manager/src/storage.rs index c9c37b945..0ca25979b 100644 --- a/applications/data_manager/src/storage.rs +++ b/applications/data_manager/src/storage.rs @@ -14,6 +14,7 @@ use std::io::Cursor; use tracing::{debug, error, info, warn}; const EQUITY_DETAILS_KEY: &str = "equity/details/details.csv"; +pub const DUCKDB_CONFIG_VALUE_MAX_LENGTH: usize = 4096; pub async fn write_equity_bars_dataframe_to_s3( state: &State, @@ -130,8 +131,8 @@ pub fn sanitize_duckdb_config_value(value: &str) -> Result { return Err(Error::Other(message)); } - // Reasonable length limit - if value.len() > 512 { + // Reasonable length limit (4096 accommodates AWS session tokens which can exceed 1000 characters) + if value.len() > DUCKDB_CONFIG_VALUE_MAX_LENGTH { let message = format!("Configuration value too long: {} characters", value.len()); error!("{}", message); return Err(Error::Other(message)); diff --git a/applications/data_manager/tests/test_storage.rs b/applications/data_manager/tests/test_storage.rs index 373780703..e0864b197 100644 --- a/applications/data_manager/tests/test_storage.rs +++ b/applications/data_manager/tests/test_storage.rs @@ -13,7 +13,7 @@ use data_manager::{ query_predictions_dataframe_from_s3, read_equity_details_dataframe_from_s3, sanitize_duckdb_config_value, write_equity_bars_dataframe_to_s3, write_equity_details_dataframe_to_s3, write_portfolio_dataframe_to_s3, - write_predictions_dataframe_to_s3, PredictionQuery, + write_predictions_dataframe_to_s3, PredictionQuery, DUCKDB_CONFIG_VALUE_MAX_LENGTH, }, }; use polars::prelude::*; @@ -466,6 +466,8 @@ fn test_sanitize_duckdb_config_value_valid() { assert!(sanitize_duckdb_config_value("true").is_ok()); assert!(sanitize_duckdb_config_value("false").is_ok()); assert!(sanitize_duckdb_config_value("http://127.0.0.1:9000").is_ok()); + // AWS ECS session tokens can exceed 1000 characters + assert!(sanitize_duckdb_config_value(&"a".repeat(1052)).is_ok()); } #[test] @@ -474,7 +476,7 @@ fn test_sanitize_duckdb_config_value_rejects_injection() { assert!(sanitize_duckdb_config_value("localhost'; --").is_err()); assert!(sanitize_duckdb_config_value("\"malicious\"").is_err()); assert!(sanitize_duckdb_config_value("").is_err()); - assert!(sanitize_duckdb_config_value(&"a".repeat(513)).is_err()); + assert!(sanitize_duckdb_config_value(&"a".repeat(DUCKDB_CONFIG_VALUE_MAX_LENGTH + 1)).is_err()); assert!(sanitize_duckdb_config_value("value;another").is_err()); } diff --git a/applications/ensemble_manager/pyproject.toml b/applications/ensemble_manager/pyproject.toml index 77a191455..987f93e05 100644 --- a/applications/ensemble_manager/pyproject.toml +++ b/applications/ensemble_manager/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "requests>=2.32.5", "sentry-sdk[fastapi]>=2.0.0", "structlog>=25.5.0", + "prometheus-client>=0.21.0", ] [dependency-groups] diff --git a/applications/ensemble_manager/src/ensemble_manager/server.py b/applications/ensemble_manager/src/ensemble_manager/server.py index 9d07fbba4..6e948732b 100644 --- a/applications/ensemble_manager/src/ensemble_manager/server.py +++ b/applications/ensemble_manager/src/ensemble_manager/server.py @@ -25,8 +25,8 @@ from mypy_boto3_s3 import S3Client from internal.equity_details_schema import equity_details_schema -from tide.tide_data import Data -from tide.tide_model import Model +from tide.data import Data +from tide.model import Model from .metrics import ( get_metrics, @@ -342,30 +342,31 @@ def create_predictions(request: Request) -> Response: # noqa: PLR0915 tide_data.preprocess_and_set_data(data=data) - batches = tide_data.get_batches(data_type="predict") + from tinygrad.tensor import Tensor # noqa: PLC0415 - if not batches: + dataset = tide_data.get_dataset(data_type="predict") + + if len(dataset) == 0: prediction_errors_total.labels(stage="batch_creation").inc() - logger.error("No data batches available for prediction") + logger.error("No data samples available for prediction") observe_duration(timer_start) return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) - logger.info( - "Processing prediction batches", - num_batches=len(batches), - ) + logger.info("Processing prediction dataset", num_samples=len(dataset)) - all_predictions = [] - for batch in batches: - raw_predictions = request.app.state.tide_model.predict(inputs=batch) - batch_predictions = tide_data.postprocess_predictions( - input_batch=batch, - predictions=raw_predictions, - current_datetime=current_timestamp, - ) - all_predictions.append(batch_predictions) + batch = { + "past_continuous_features": Tensor(dataset.past_continuous), + "past_categorical_features": Tensor(dataset.past_categorical), + "future_categorical_features": Tensor(dataset.future_categorical), + "static_categorical_features": Tensor(dataset.static_categorical), + } - predictions = pl.concat(all_predictions) + raw_predictions = request.app.state.tide_model.predict(inputs=batch) + predictions = tide_data.postprocess_predictions( + input_batch=batch, + predictions=raw_predictions, + current_datetime=current_timestamp, + ) logger.info( "Combined predictions from all batches", total_predictions=predictions.height, @@ -417,7 +418,7 @@ def create_predictions(request: Request) -> Response: # noqa: PLR0915 observe_duration(timer_start) raise - prediction_batch_count.set(len(batches)) + prediction_batch_count.set(1) prediction_row_count.set(validated_predictions.height) observe_duration(timer_start) diff --git a/applications/portfolio_manager/pyproject.toml b/applications/portfolio_manager/pyproject.toml index 546f76f2d..facd5b13a 100644 --- a/applications/portfolio_manager/pyproject.toml +++ b/applications/portfolio_manager/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "sentry-sdk[fastapi]>=2.0.0", "structlog>=25.5.0", "scipy>=1.17.1", + "prometheus-client>=0.21.0", ] [tool.uv] diff --git a/infrastructure/__main__.py b/infrastructure/__main__.py index bfd7d1452..3134fb8e7 100644 --- a/infrastructure/__main__.py +++ b/infrastructure/__main__.py @@ -2,7 +2,7 @@ from compute import acm_certificate_arn, alb, cluster, service_discovery_namespace from config import account_id from iam import github_actions_infrastructure_role, github_actions_oidc_provider -from networking import vpc +from networking import ecs_security_group, private_subnet_1, private_subnet_2, vpc from storage import ( data_bucket, data_manager_image_uri, @@ -15,7 +15,7 @@ tide_model_runner_image_uri, tide_model_runner_repository, ) -from training import models_cluster +from training import models_cluster, tide_trainer_task_definition protocol = "https://" if acm_certificate_arn else "http://" @@ -33,8 +33,14 @@ pulumi.export("aws_account_id", account_id) pulumi.export("aws_vpc_id", vpc.id) +pulumi.export("aws_ecs_private_subnet_1_id", private_subnet_1.id) +pulumi.export("aws_ecs_private_subnet_2_id", private_subnet_2.id) +pulumi.export("aws_ecs_security_group_id", ecs_security_group.id) pulumi.export("aws_ecs_cluster_name", cluster.name) pulumi.export("aws_ecs_models_cluster_name", models_cluster.name) +pulumi.export( + "aws_ecs_tide_trainer_task_definition_arn", tide_trainer_task_definition.arn +) pulumi.export("aws_alb_dns_name", alb.dns_name) pulumi.export("aws_alb_url", pulumi.Output.concat(protocol, alb.dns_name)) pulumi.export("aws_service_discovery_namespace", service_discovery_namespace.name) diff --git a/infrastructure/training.py b/infrastructure/training.py index ba539737a..11997cae2 100644 --- a/infrastructure/training.py +++ b/infrastructure/training.py @@ -5,6 +5,7 @@ from config import tags from iam import execution_role, task_role from networking import ecs_security_group, private_subnet_1, private_subnet_2 +from storage import tide_model_runner_image_uri models_cluster = aws.ecs.Cluster( "models_cluster", @@ -154,9 +155,42 @@ tags=tags, ) +tide_trainer_task_definition = aws.ecs.TaskDefinition( + "tide_trainer_task_definition", + family="fund-tide-trainer", + requires_compatibilities=["EC2"], + network_mode="awsvpc", + cpu="4096", + memory="14336", + execution_role_arn=execution_role.arn, + task_role_arn=task_role.arn, + container_definitions=tide_model_runner_image_uri.apply( + lambda image_uri: json.dumps( + [ + { + "name": "prefect", + "image": image_uri, + "essential": True, + "resourceRequirements": [{"type": "GPU", "value": "1"}], + "logConfiguration": { + "logDriver": "awslogs", + "options": { + "awslogs-group": "/ecs/fund/models", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "tide", + }, + }, + } + ] + ) + ), + tags=tags, +) + __all__ = [ "execution_role", "models_cluster", "models_log_group", "task_role", + "tide_trainer_task_definition", ] diff --git a/maskfile.md b/maskfile.md index da25da7af..a4c69cb0f 100644 --- a/maskfile.md +++ b/maskfile.md @@ -52,14 +52,14 @@ echo "Development environment setup completed successfully" > Manage Docker images for applications -#### build (package_name) (stage_name) +#### build-and-push (package_name) (stage_name) -> Build Docker images with optional cache pull (e.g. `portfolio-manager server`, `tide model-runner`) +> Build and push Docker image directly to ECR (e.g. `portfolio-manager server`, `tide model-runner`) ```bash set -euo pipefail -echo "Building image" +echo "Building and pushing image" aws_account_id=$(aws sts get-caller-identity --query Account --output text) aws_region="${AWS_REGION:-}" @@ -117,105 +117,19 @@ else docker buildx create --use --name fund-builder 2>/dev/null || docker buildx use fund-builder || (echo "Using default buildx builder" && docker buildx use default) fi -echo "Building with caching (will continue if cache doesn't exist)" +echo "Building and pushing with caching (will continue if cache doesn't exist)" docker buildx build \ --platform linux/amd64 \ --target ${build_target} \ --file ${dockerfile} \ --tag ${image_reference}:latest \ + --tag ${image_reference}:git-${commit_hash} \ ${cache_from_arguments} \ ${cache_to_arguments} \ - --load \ + --push \ . -echo "Image built: ${package_name} ${stage_name}" -``` - -#### push (package_name) (stage_name) - -> Push Docker image to ECR (e.g. `portfolio-manager server`, `tide model-runner`) - -```bash -set -euo pipefail - -echo "Pushing image to ECR" - -aws_account_id=$(aws sts get-caller-identity --query Account --output text) -aws_region="${AWS_REGION:-}" -if [ -z "$aws_region" ]; then - echo "AWS_REGION environment variable is not set" - exit 1 -fi - -repository_name="fund/${package_name}-${stage_name}" -image_reference="${aws_account_id}.dkr.ecr.${aws_region}.amazonaws.com/${repository_name}" -commit_hash=$(git rev-parse --short HEAD) - -echo "Logging into ECR" -aws ecr get-login-password --region ${aws_region} | docker login \ - --username AWS \ - --password-stdin ${aws_account_id}.dkr.ecr.${aws_region}.amazonaws.com > /dev/null - -echo "Checking if image for commit ${commit_hash} already exists in ECR" -existing_tag="NONE" -if image_digest=$(aws ecr describe-images \ - --repository-name "${repository_name}" \ - --image-ids "imageTag=git-${commit_hash}" \ - --query 'imageDetails[0].imageDigest' \ - --output text 2>/dev/null); then - existing_tag="${image_digest}" -fi - -if [ "$existing_tag" != "NONE" ] && [ "$existing_tag" != "None" ] && [ -n "$existing_tag" ]; then - echo "Image for commit ${commit_hash} already exists in ECR, skipping push" - echo "Image pushed: ${package_name} ${stage_name} (cached)" - exit 0 -fi - -echo "Pushing image" -docker tag "${image_reference}:latest" "${image_reference}:git-${commit_hash}" -docker push "${image_reference}:latest" -docker push "${image_reference}:git-${commit_hash}" - -echo "Image pushed: ${package_name} ${stage_name} (commit: ${commit_hash})" -``` - -#### deploy (package_name) (stage_name) - -> Deploy ECS service with latest image (e.g. `portfolio-manager server`, `data-manager server`) - -```bash -set -euo pipefail - -echo "Deploying ${package_name} ${stage_name}" - -case "${package_name}-${stage_name}" in - data-manager-server) service="fund-data-manager-server" ;; - portfolio-manager-server) service="fund-portfolio-manager-server" ;; - ensemble-manager-server) service="fund-ensemble-manager-server" ;; - tide-model-runner) echo "tide-model-runner is used for Prefect training jobs, not an ECS service" && exit 0 ;; - *) echo "Unknown service: ${package_name}-${stage_name}" && exit 1 ;; -esac - -cd infrastructure/ - -if ! organization_name=$(pulumi org get-default 2>/dev/null) || [ -z "${organization_name}" ]; then - echo "Error: Pulumi default organization not set. Run: pulumi org set-default " - exit 1 -fi -pulumi stack select "${organization_name}/fund/production" -cluster=$(pulumi stack output aws_ecs_cluster_name) - -cd "${MASKFILE_DIR}" - -aws ecs update-service --cluster "$cluster" --service "$service" --force-new-deployment --no-cli-pager > /dev/null -echo "Deployment started: ${service}" - -echo "Waiting for ${service} to stabilize" - -aws ecs wait services-stable --cluster "$cluster" --services "$service" - -echo "Deployment complete: ${service} (${package_name} ${stage_name})" +echo "Image built and pushed: ${package_name} ${stage_name} (commit: ${commit_hash})" ``` ### stack @@ -454,9 +368,9 @@ aws ecs wait services-stable --cluster "$cluster" --services "$service" echo "Deployment complete: ${service}" ``` -### models +### trainer -> Manage Prefect Cloud model resources +> Manage Prefect Cloud model training resources #### initialize (environment) @@ -469,12 +383,31 @@ case "${environment}" in remote) unset PREFECT_API_URL - echo "Creating fund-models-remote work pool on Prefect Cloud" - uv run prefect work-pool create "fund-models-remote" --type ecs 2>/dev/null \ - || echo " already exists" + cd infrastructure/ + pulumi stack select "$(pulumi org get-default)/fund/production" + models_cluster=$(pulumi stack output aws_ecs_models_cluster_name) + tide_trainer_task_definition_arn=$(pulumi stack output aws_ecs_tide_trainer_task_definition_arn) + vpc_id=$(pulumi stack output aws_vpc_id) + private_subnet_1_id=$(pulumi stack output aws_ecs_private_subnet_1_id) + private_subnet_2_id=$(pulumi stack output aws_ecs_private_subnet_2_id) + ecs_security_group_id=$(pulumi stack output aws_ecs_security_group_id) + cd "${MASKFILE_DIR}" - echo "Registering remote training deployment" - uv run prefect --no-prompt deploy --name tide-trainer-remote + echo "Creating fund-models-remote work pool on Prefect Cloud" + aws_credentials_block_id=$(uv run prefect block inspect "aws-credentials/fund-aws" | grep "Block id" | grep -oE '[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}') + base_job_template=$(uv run prefect work-pool get-default-base-job-template --type ecs \ + | uv run python -m tools.build_work_pool_template \ + "${models_cluster}" \ + "${aws_credentials_block_id}" \ + "${tide_trainer_task_definition_arn}" \ + "${vpc_id}" \ + "${private_subnet_1_id}" \ + "${private_subnet_2_id}" \ + "${ecs_security_group_id}") + uv run prefect work-pool create "fund-models-remote" --type ecs \ + --base-job-template <(echo "${base_job_template}") 2>/dev/null \ + || uv run prefect work-pool update "fund-models-remote" \ + --base-job-template <(echo "${base_job_template}") echo "" echo "Done. Visit Prefect Cloud dashboard to view deployments." @@ -850,36 +783,37 @@ echo "YAML development checks completed successfully" ### train (model_name) -> Train model via Prefect training pipeline +> Trigger model training run via Prefect Cloud deployment ```bash set -euo pipefail -cd infrastructure - -if ! organization_name=$(pulumi org get-default 2>/dev/null) || [ -z "${organization_name}" ]; then - echo "Unable to determine Pulumi organization name - ensure you are logged in" - exit 1 -fi - -pulumi stack select ${organization_name}/fund/production - -export FUND_DATA_MANAGER_BASE_URL="$(pulumi stack output aws_alb_url)" -export AWS_S3_DATA_BUCKET_NAME="$(pulumi stack output aws_s3_data_bucket_name)" -export AWS_S3_MODEL_ARTIFACTS_BUCKET_NAME="$(pulumi stack output aws_s3_model_artifacts_bucket_name)" -export FUND_LOOKBACK_DAYS="${FUND_LOOKBACK_DAYS:-365}" +unset PREFECT_API_URL -cd ../ +lookback_days="${FUND_LOOKBACK_DAYS:-365}" case "${model_name}" in tide) - uv run python -m tide.run + deployment="tide-training-pipeline/tide-trainer-remote" + log_group="/ecs/fund/models" + log_stream_prefix="tide/prefect" ;; *) echo "Unknown model: ${model_name}" + echo "Valid options: tide" exit 1 ;; esac + +echo "Triggering training run for ${model_name} (lookback_days=${lookback_days})" +uv run prefect deployment run "${deployment}" --param "lookback_days=${lookback_days}" + +echo "" +echo "To find logs once the run starts (GPU provisioning takes ~3-5 minutes):" +echo " 1. Open the flow run in Prefect Cloud and note the ECS task ARN under Infrastructure" +echo " 2. The task ID is the last segment of the ARN (after the final '/')" +echo " 3. Find the log stream in CloudWatch log group '${log_group}':" +echo " ${log_stream_prefix}/" ``` ### deploy (model_name) @@ -889,11 +823,26 @@ esac ```bash set -euo pipefail +echo "Deploying ${model_name} model" + unset PREFECT_API_URL export FUND_LOOKBACK_DAYS="${FUND_LOOKBACK_DAYS:-365}" +cd infrastructure + +if ! organization_name=$(pulumi org get-default 2>/dev/null) || [ -z "${organization_name}" ]; then + echo "Unable to determine Pulumi organization name - ensure you are logged in" + exit 1 +fi + +pulumi stack select "${organization_name}/fund/production" +tide_image_uri=$(pulumi stack output aws_ecr_tide_model_runner_image) + +cd "${MASKFILE_DIR}" + case "${model_name}" in tide) + export FUND_TIDE_IMAGE_URI="${tide_image_uri}" uv run python -m tide.deploy ;; *) @@ -901,6 +850,8 @@ case "${model_name}" in exit 1 ;; esac + +echo "Deployment complete: ${model_name}" ``` ### download (model_name) diff --git a/models/tide/Dockerfile b/models/tide/Dockerfile index fe979a98b..8fb2a53b0 100644 --- a/models/tide/Dockerfile +++ b/models/tide/Dockerfile @@ -39,6 +39,7 @@ COPY --from=ghcr.io/astral-sh/uv:0.7.2 /uv /bin/uv COPY --from=builder /app /app ENV PYTHONPATH=/app/models/tide/src +ENV PATH="/app/.venv/bin:$PATH" RUN useradd --system --uid 10001 --create-home appuser && \ chown -R appuser:appuser /app diff --git a/models/tide/src/tide/tide_data.py b/models/tide/src/tide/data.py similarity index 74% rename from models/tide/src/tide/tide_data.py rename to models/tide/src/tide/data.py index 67707aa99..b9fbe9574 100644 --- a/models/tide/src/tide/tide_data.py +++ b/models/tide/src/tide/data.py @@ -1,15 +1,18 @@ import json import os from abc import ABC, abstractmethod +from dataclasses import dataclass from datetime import date, datetime, timedelta -from typing import cast +from typing import TYPE_CHECKING, cast import numpy as np import pandera.polars as pa import polars as pl import structlog from internal.timestamps import to_timestamp_milliseconds -from tinygrad.tensor import Tensor + +if TYPE_CHECKING: + from tinygrad.tensor import Tensor logger = structlog.get_logger() @@ -43,7 +46,6 @@ "day_of_year", "month", "year", - "is_holiday", ] STATIC_CATEGORICAL_COLUMNS = [ @@ -57,6 +59,18 @@ FeatureMappings = dict[str, CategoryMapping] +@dataclass +class TrainingDataset: + past_continuous: np.ndarray # [N, input_length, n_continuous_features] + past_categorical: np.ndarray # [N, input_length, n_categorical_features] + future_categorical: np.ndarray # [N, output_length, n_categorical_features] + static_categorical: np.ndarray # [N, 1, n_static_features] + targets: np.ndarray | None # [N, output_length, 1] — None for predict + + def __len__(self) -> int: + return len(self.past_continuous) + + class Scaler: def __init__(self) -> None: pass @@ -96,10 +110,10 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: return data -class ExpandDateRange(Stage): +class EngineerFeatures(Stage): @property def name(self) -> str: - return "expand_date_range" + return "engineer_features" def run(self, data: pl.DataFrame) -> pl.DataFrame: data = data.with_columns( @@ -107,85 +121,8 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: .cast(pl.Datetime(time_unit="ms")) .dt.date() .alias("date"), - pl.lit(False).alias("is_holiday"), # noqa: FBT003 - ) - - tickers = data.select(pl.col("ticker").unique()) - - minimum_date: date = data.select(pl.col("date").min()).item() - maximum_date: date = data.select(pl.col("date").max()).item() - - dates = pl.DataFrame( - { - "date": pl.date_range( - minimum_date, - maximum_date, - "1d", - eager=True, - ) - } - ) - - dates_and_tickers = tickers.join(dates, how="cross") - - return dates_and_tickers.join(data, on=["ticker", "date"], how="left") - - -class FillNulls(Stage): - @property - def name(self) -> str: - return "fill_nulls" - - def run(self, data: pl.DataFrame) -> pl.DataFrame: - friday_number = 4 - - data = ( - data.with_columns(pl.col("date").dt.weekday().alias("temporary_weekday")) - .with_columns( - pl.when( - pl.col("is_holiday").is_null() - & (pl.col("temporary_weekday") <= friday_number) - ) - .then(True) # noqa: FBT003 - .when( - pl.col("is_holiday").is_null() - & (pl.col("temporary_weekday") > friday_number) - ) - .then(False) # noqa: FBT003 - .otherwise(pl.col("is_holiday")) # keep existing values - .alias("is_holiday") - ) - .drop("temporary_weekday") - ) - - return data.with_columns( - [ - pl.col("open_price").fill_null(0.0), - pl.col("high_price").fill_null(0.0), - pl.col("low_price").fill_null(0.0), - pl.col("close_price").fill_null(0.0), - pl.col("volume").fill_null(0), - pl.col("volume_weighted_average_price").fill_null(0.0), - pl.col("sector").fill_null("NOT AVAILABLE"), - pl.col("industry").fill_null("NOT AVAILABLE"), - pl.col("ticker").fill_null("UNKNOWN"), - pl.col("timestamp").fill_null( - pl.col("date") - .cast(pl.Datetime) - .dt.replace_time_zone("UTC") - .cast(pl.Int64) - .floordiv(1000) - ), - ] ) - -class EngineerFeatures(Stage): - @property - def name(self) -> str: - return "engineer_features" - - def run(self, data: pl.DataFrame) -> pl.DataFrame: data = data.with_columns( pl.col("date").dt.weekday().alias("day_of_week").cast(pl.Int64), pl.col("date").dt.day().alias("day_of_month").cast(pl.Int64), @@ -231,23 +168,39 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: data = data.unique(subset=["ticker", "timestamp"]) + # Collect stats across all columns before filtering so counts reflect the + # original data and rows with multiple bad columns are reported accurately + column_stats: dict[str, dict[str, int]] = {} for col in CONTINUOUS_COLUMNS: - nan_count = data.filter(pl.col(col).is_nan()).height - null_count = data.filter(pl.col(col).is_null()).height - inf_count = data.filter(~pl.col(col).is_finite()).height + series = data[col] + nan_count = int(series.is_nan().sum() or 0) + null_count = int(series.is_null().sum() or 0) + inf_count = int(series.is_infinite().sum() or 0) if nan_count > 0 or null_count > 0 or inf_count > 0: - logger.warning( - "Invalid values in continuous column before scaling", - column=col, - nan_count=nan_count, - null_count=null_count, - inf_count=inf_count, - ) - data = data.filter( - pl.col(col).is_not_nan() - & pl.col(col).is_not_null() - & pl.col(col).is_finite() + column_stats[col] = { + "nan_count": nan_count, + "null_count": null_count, + "inf_count": inf_count, + } + + for col, stats in column_stats.items(): + logger.warning( + "Invalid values in continuous column before scaling", + column=col, + **stats, + ) + + if column_stats: + data = data.filter( + pl.all_horizontal( + [ + pl.col(col).is_not_nan() + & pl.col(col).is_not_null() + & pl.col(col).is_finite() + for col in CONTINUOUS_COLUMNS + ] ) + ) data_validated = data_schema.validate(data) return cast( @@ -302,7 +255,6 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: "ticker", "sector", "industry", - "is_holiday", ] mappings: FeatureMappings = {} @@ -331,8 +283,6 @@ def _create_mapping_and_encoding( def default_stages() -> list[Stage]: return [ ValidateColumns(), - ExpandDateRange(), - FillNulls(), EngineerFeatures(), CleanData(), ScaleAndEncode(), @@ -439,14 +389,13 @@ def get_dimensions(self) -> dict[str, int]: "static_continuous_features": 0, # not using static_continuous_features for now # noqa: E501 } - def get_batches( # noqa: C901 + def get_dataset( # noqa: C901 self, - data_type: str = "train", # "train", "validate", or "predict" + data_type: str = "train", validation_split: float = 0.8, input_length: int = 35, - output_length: int = 7, - batch_size: int = 32, - ) -> list[dict[str, Tensor]]: + output_length: int = 5, + ) -> TrainingDataset: if data_type not in {"train", "validate", "predict"}: message = f"Invalid data type: {data_type}. Must be 'train', 'validate', or 'predict'." # noqa: E501 raise ValueError(message) @@ -455,12 +404,10 @@ def get_batches( # noqa: C901 self.batch_data, _ = self._get_training_and_validation_data( validation_split ) - elif data_type == "validate": _, self.batch_data = self._get_training_and_validation_data( validation_split ) - elif data_type == "predict": self.batch_data = self._get_prediction_data(input_length + output_length) @@ -480,26 +427,26 @@ def get_batches( # noqa: C901 ) raise ValueError(message) - # collect all samples first - samples = [] + all_past_continuous: list[np.ndarray] = [] + all_past_categorical: list[np.ndarray] = [] + all_future_categorical: list[np.ndarray] = [] + all_static_categorical: list[np.ndarray] = [] + all_targets: list[np.ndarray] = [] has_targets = data_type in {"train", "validate"} - # Partition by ticker once upfront (much faster than filtering per ticker) logger.info("Partitioning data by ticker") ticker_groups = self.batch_data.sort("time_idx").partition_by( "ticker", as_dict=True ) total_tickers = len(ticker_groups) - logger.info("Batch creation started", total_tickers=total_tickers) + logger.info("Dataset creation started", total_tickers=total_tickers) for idx, ticker_df in enumerate(ticker_groups.values()): if idx % 25 == 0: logger.info( - "Batch progress", ticker_idx=idx, total_tickers=total_tickers + "Dataset progress", ticker_idx=idx, total_tickers=total_tickers ) - # Convert to numpy once per ticker (avoid repeated DataFrame operations) - # Use float32 for GPU compatibility (Metal doesn't support float64) cat_array = ticker_df[self.categorical_columns].to_numpy().astype(np.int32) cont_array = ( ticker_df[self.continuous_columns].to_numpy().astype(np.float32) @@ -522,62 +469,53 @@ def get_batches( # noqa: C901 if num_windows <= 0: continue - # For prediction, only use the last window (most recent data) - # For training/validation, use all windows window_indices = ( [num_windows - 1] if data_type == "predict" else range(num_windows) ) - # Use numpy slicing (much faster than DataFrame slicing) for i in window_indices: - sample = { - "past_categorical": cat_array[i : i + input_length].copy(), - "past_continuous": cont_array[i : i + input_length].copy(), - "future_categorical": cat_array[ - i + input_length : i + input_length + output_length - ].copy(), - "static_categorical": static_array.copy(), - } - - if has_targets and target_array is not None: - sample["targets"] = target_array[ + all_past_continuous.append(cont_array[i : i + input_length].copy()) + all_past_categorical.append(cat_array[i : i + input_length].copy()) + all_future_categorical.append( + cat_array[ i + input_length : i + input_length + output_length ].copy() + ) + all_static_categorical.append(static_array.copy()) - samples.append(sample) - - logger.info("Sample collection complete", total_samples=len(samples)) - - # now batch the samples - logger.info("Batching samples", batch_size=batch_size) - batches = [] - for i in range(0, len(samples), batch_size): - batch_samples = samples[i : i + batch_size] - - batch = { - "past_categorical_features": Tensor( - np.stack([s["past_categorical"] for s in batch_samples]) - ), - "past_continuous_features": Tensor( - np.stack([s["past_continuous"] for s in batch_samples]) - ), - "future_categorical_features": Tensor( - np.stack([s["future_categorical"] for s in batch_samples]) - ), - "static_categorical_features": Tensor( - np.stack([s["static_categorical"] for s in batch_samples]) - ), - } + if has_targets and target_array is not None: + all_targets.append( + target_array[ + i + input_length : i + input_length + output_length + ].copy() + ) - if data_type in {"train", "validate"}: - batch["targets"] = Tensor( - np.stack([s["targets"] for s in batch_samples]) - ) + logger.info( + "Sample collection complete", total_samples=len(all_past_continuous) + ) - batches.append(batch) + n_cont = len(self.continuous_columns) + n_cat = len(self.categorical_columns) + n_static = len(self.static_categorical_columns) + + if not all_past_continuous: + return TrainingDataset( + past_continuous=np.zeros((0, input_length, n_cont), dtype=np.float32), + past_categorical=np.zeros((0, input_length, n_cat), dtype=np.int32), + future_categorical=np.zeros((0, output_length, n_cat), dtype=np.int32), + static_categorical=np.zeros((0, 1, n_static), dtype=np.int32), + targets=None + if not has_targets + else np.zeros((0, output_length, 1), dtype=np.float32), + ) - logger.info("Batch creation complete", total_batches=len(batches)) - return batches + return TrainingDataset( + past_continuous=np.stack(all_past_continuous), + past_categorical=np.stack(all_past_categorical), + future_categorical=np.stack(all_future_categorical), + static_categorical=np.stack(all_static_categorical), + targets=np.stack(all_targets) if all_targets else None, + ) def save(self, directory_path: str) -> None: os.makedirs(directory_path, exist_ok=True) # noqa: PTH103 @@ -629,10 +567,8 @@ def load(cls, directory_path: str) -> "Data": def postprocess_predictions( self, - input_batch: dict[ - str, Tensor - ], # batch dictionary with static_categorical_features - predictions: Tensor, # quantiles + input_batch: dict[str, "Tensor"], + predictions: "Tensor", current_datetime: datetime, ) -> pl.DataFrame: predictions_array = predictions.numpy() @@ -695,9 +631,7 @@ def postprocess_predictions( ), "open_price": pa.Column( dtype=float, - checks=pa.Check.greater_than_or_equal_to( - 0 - ), # zeros are allowed for missing days + checks=pa.Check.greater_than_or_equal_to(0), ), "high_price": pa.Column( dtype=float, @@ -728,7 +662,6 @@ def postprocess_predictions( "day_of_year": pa.Column(dtype=int), "month": pa.Column(dtype=int), "year": pa.Column(dtype=int), - "is_holiday": pa.Column(dtype=bool), "time_idx": pa.Column(dtype=int), "daily_return": pa.Column(dtype=float), }, diff --git a/models/tide/src/tide/deploy.py b/models/tide/src/tide/deploy.py index 6581982f1..412cfdcb2 100644 --- a/models/tide/src/tide/deploy.py +++ b/models/tide/src/tide/deploy.py @@ -2,7 +2,7 @@ import sys import structlog -from prefect.flows import EntrypointType +from prefect.schedules import Schedule from tide.workflow import training_pipeline @@ -11,6 +11,7 @@ def deploy_training_flow( lookback_days: int = 365, + image: str | None = None, ) -> None: """Register the training pipeline deployment with the Prefect server.""" logger.info( @@ -21,13 +22,17 @@ def deploy_training_flow( training_pipeline.deploy( name="tide-trainer-remote", work_pool_name="fund-models-remote", - cron="0 22 * * 1-5", - timezone="America/New_York", + image=image, + schedule=Schedule(cron="0 22 * * 1-5", timezone="America/New_York"), parameters={ "lookback_days": lookback_days, }, - tags=["training", "daily"], - entrypoint_type=EntrypointType.MODULE_PATH, + job_variables={ + "cpu": 4096, + "memory": 14336, + }, + concurrency_limit=1, + tags=["training", "tide"], build=False, push=False, ) @@ -46,6 +51,12 @@ def deploy_training_flow( logger.error("FUND_LOOKBACK_DAYS must be positive", lookback_days=lookback_days) sys.exit(1) + image = os.getenv("FUND_TIDE_IMAGE_URI") + if not image: + logger.error("FUND_TIDE_IMAGE_URI environment variable is required") + sys.exit(1) + deploy_training_flow( lookback_days=lookback_days, + image=image, ) diff --git a/models/tide/src/tide/tide_model.py b/models/tide/src/tide/model.py similarity index 72% rename from models/tide/src/tide/tide_model.py rename to models/tide/src/tide/model.py index db6016a73..bbffedd58 100644 --- a/models/tide/src/tide/tide_model.py +++ b/models/tide/src/tide/model.py @@ -1,10 +1,11 @@ import json +import math import os -import random from typing import cast import numpy as np import structlog +from tinygrad import Device from tinygrad.nn import Linear from tinygrad.nn.optim import Adam from tinygrad.nn.state import ( @@ -16,8 +17,12 @@ ) from tinygrad.tensor import Tensor +from tide.data import TrainingDataset + logger = structlog.get_logger() +_rng = np.random.default_rng() + def quantile_loss( predictions: Tensor, @@ -126,7 +131,7 @@ def __init__( # noqa: PLR0913 hidden_size: int = 128, num_encoder_layers: int = 2, num_decoder_layers: int = 2, - output_length: int = 7, # number of days to forecast + output_length: int = 5, # number of trading days to forecast dropout_rate: float = 0.1, quantiles: list[float] | None = None, huber_delta: float = 0.0, @@ -218,56 +223,54 @@ def forward(self, x: Tensor) -> Tensor: # reshape to (batch_size, output_length, num_quantiles) return x.reshape(batch_size, self.output_length, len(self.quantiles)) - def _validate_batch(self, batch: dict[str, Tensor], _batch_idx: int) -> dict: - """Check a batch for NaN/Inf values and return statistics.""" - issues = {} - for key, tensor in batch.items(): - data = tensor.numpy() - nan_count = int(np.isnan(data).sum()) - inf_count = int(np.isinf(data).sum()) - if nan_count > 0 or inf_count > 0: - issues[key] = { - "nan_count": nan_count, - "inf_count": inf_count, - "total_elements": data.size, - "nan_pct": f"{(nan_count / data.size) * 100:.2f}%", - } - return issues - def validate_training_data( self, - train_batches: list, + dataset: TrainingDataset, sample_size: int = 10, ) -> bool: """Validate training data for NaN/Inf values.""" - total_batches = len(train_batches) - actual_sample_size = min(sample_size, total_batches) + total_samples = len(dataset) + actual_sample_size = min(sample_size, total_samples) logger.info( "Validating training data", - total_batches=total_batches, + total_samples=total_samples, sample_size=actual_sample_size, ) all_issues: dict[str, dict] = {} - # Sample batches to validate for efficiency with large datasets - sampled_indices = random.sample(range(total_batches), actual_sample_size) - for idx in sampled_indices: - batch = train_batches[idx] - batch_issues = self._validate_batch(batch, idx) - if batch_issues: - all_issues[f"batch_{idx}"] = batch_issues + sampled_indices = _rng.choice(total_samples, actual_sample_size, replace=False) + + arrays: dict[str, np.ndarray] = { + "past_continuous": dataset.past_continuous[sampled_indices], + "past_categorical": dataset.past_categorical[sampled_indices], + "future_categorical": dataset.future_categorical[sampled_indices], + "static_categorical": dataset.static_categorical[sampled_indices], + } + if dataset.targets is not None: + arrays["targets"] = dataset.targets[sampled_indices] + + for name, array in arrays.items(): + if array.dtype.kind != "f": + continue + nan_count = int(np.isnan(array).sum()) + inf_count = int(np.isinf(array).sum()) + if nan_count > 0 or inf_count > 0: + all_issues[name] = { + "nan_count": nan_count, + "inf_count": inf_count, + "total_elements": array.size, + "nan_pct": f"{(nan_count / array.size) * 100:.2f}%", + } if all_issues: - for batch_key, features in all_issues.items(): - for feature_key, stats in features.items(): - logger.error( - "Invalid values in training data", - batch=batch_key, - feature=feature_key, - **stats, - ) + for feature_key, stats in all_issues.items(): + logger.error( + "Invalid values in training data", + feature=feature_key, + **stats, + ) return False logger.info("Training data validation passed") @@ -275,53 +278,33 @@ def validate_training_data( def train( # noqa: PLR0913, PLR0912, PLR0915, C901 self, - train_batches: list, + dataset: TrainingDataset, + batch_size: int = 256, epochs: int = 10, learning_rate: float = 0.001, log_interval: int = 100, validate_data: bool = True, # noqa: FBT001, FBT002 validation_sample_size: int = 10, + validation_dataset: TrainingDataset | None = None, early_stopping_patience: int | None = 3, early_stopping_min_delta: float = 0.001, checkpoint_directory: str | None = None, - ) -> list: - """Train the TiDE model using quantile loss. - - Args: - train_batches: List of training batch dictionaries - epochs: Maximum number of epochs to train - learning_rate: Learning rate for optimizer - log_interval: Log progress every N steps - validate_data: Whether to validate data before training. - Note: Data validation samples batches to check for NaN/Inf values. - For large datasets (>1000 batches), validation may take several seconds. - Set to False to skip validation if data quality is already guaranteed. - validation_sample_size: Number of batches to sample during validation. - Larger values provide more thorough validation but take longer. - The first and last batches are always checked. Default is 10. - early_stopping_patience: Stop if no improvement for N epochs - (None to disable) - early_stopping_min_delta: Minimum improvement to reset patience counter - - Performance Notes: - - Data validation runs once before training starts - - Validation time scales with validation_sample_size, not total batches - - For datasets with 10,000 batches, validation with - validation_sample_size=10 typically completes in under a second - (hardware dependent) - - Increase validation_sample_size for more thorough checking or decrease - for faster training startup - """ - if not train_batches: + ) -> list[float]: + """Train the TiDE model using quantile loss.""" + if len(dataset) == 0: return [] + if dataset.targets is None: + message = "Targets are required for training" + raise ValueError(message) + if validation_sample_size <= 0: message = "validation_sample_size must be positive" raise ValueError(message) if validate_data: is_valid = self.validate_training_data( - train_batches, + dataset, sample_size=validation_sample_size, ) if not is_valid: @@ -334,7 +317,8 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 parameters = get_parameters(self) optimizer = Adam(params=parameters, lr=learning_rate) losses = [] - total_batches = len(train_batches) + num_samples = len(dataset) + total_batches = (num_samples + batch_size - 1) // batch_size best_loss = float("inf") best_saved_loss = float("inf") @@ -348,6 +332,15 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 checkpoint_directory, "best_checkpoint.safetensor" ) + logger.info("Training device", device=Device.DEFAULT) + + # Pre-load dataset onto compute device once to eliminate per-step transfers + gpu_past_continuous = Tensor(dataset.past_continuous) + gpu_past_categorical = Tensor(dataset.past_categorical) + gpu_future_categorical = Tensor(dataset.future_categorical) + gpu_static_categorical = Tensor(dataset.static_categorical) + gpu_targets = Tensor(dataset.targets) + try: for epoch in range(epochs): logger.info( @@ -357,17 +350,25 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 total_batches=total_batches, ) epoch_losses = [] - - for step, batch in enumerate(train_batches): - if batch["past_continuous_features"].shape[0] == 0: - logger.warning( - "Skipping empty batch", - step=step + 1, - epoch=epoch + 1, - ) - continue - - combined_input_features, targets, batch_size = ( + indices = _rng.permutation(num_samples) + + for step in range(total_batches): + batch_idx = indices[ + step * batch_size : (step + 1) * batch_size + ].tolist() + batch = { + "past_continuous_features": gpu_past_continuous[batch_idx], + "past_categorical_features": gpu_past_categorical[batch_idx], + "future_categorical_features": gpu_future_categorical[ + batch_idx + ], + "static_categorical_features": gpu_static_categorical[ + batch_idx + ], + "targets": gpu_targets[batch_idx], + } + + combined_input_features, targets, batch_size_actual = ( self._combine_input_features(batch) ) @@ -379,7 +380,9 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 predictions = self.forward(combined_input_features) # reshape targets to (batch_size, output_length) - targets_reshaped = targets.reshape(batch_size, self.output_length) + targets_reshaped = targets.reshape( + batch_size_actual, self.output_length + ) loss = quantile_loss( predictions, @@ -391,6 +394,7 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 optimizer.zero_grad() loss.backward() optimizer.step() + Tensor.realize(*get_parameters(self)) step_loss = loss.numpy().item() epoch_losses.append(step_loss) @@ -414,6 +418,14 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 epoch_loss = sum(epoch_losses) / len(epoch_losses) + # Use validation loss for early stopping if validation_dataset provided + if validation_dataset is not None: + stopping_loss = self.validate_model(validation_dataset, batch_size) + if math.isnan(stopping_loss): + stopping_loss = epoch_loss + else: + stopping_loss = epoch_loss + logger.info( "Completed training epoch", epoch=epoch + 1, @@ -424,19 +436,22 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 losses.append(epoch_loss) - if checkpoint_path and epoch_loss < best_saved_loss: - best_saved_loss = epoch_loss + checkpoint_metric = ( + stopping_loss if validation_dataset is not None else epoch_loss + ) + if checkpoint_path and checkpoint_metric < best_saved_loss: + best_saved_loss = checkpoint_metric safe_save(get_state_dict(self), checkpoint_path) checkpoint_saved = True logger.info( "Saved best checkpoint", checkpoint_path=checkpoint_path, - loss=f"{epoch_loss:.4f}", + loss=f"{checkpoint_metric:.4f}", ) if early_stopping_patience is not None: - if epoch_loss < best_loss - early_stopping_min_delta: - best_loss = epoch_loss + if stopping_loss < best_loss - early_stopping_min_delta: + best_loss = stopping_loss epochs_without_improvement = 0 logger.info( "New best loss", @@ -473,16 +488,45 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901 return losses - def validate(self, validation_batches: list) -> float: - """Validate the model using quantile loss""" + def validate_model(self, dataset: TrainingDataset, batch_size: int = 256) -> float: + """Validate the model using quantile loss.""" prev_training = Tensor.training Tensor.training = False try: + if len(dataset) == 0: + logger.warning("No validation samples provided; returning NaN loss") + return float("nan") + + if dataset.targets is None: + message = "Targets are required for validation" + raise ValueError(message) + validation_losses = [] + num_samples = len(dataset) + total_batches = (num_samples + batch_size - 1) // batch_size + + for step in range(total_batches): + batch_idx = np.arange( + step * batch_size, min((step + 1) * batch_size, num_samples) + ) + batch = { + "past_continuous_features": Tensor( + dataset.past_continuous[batch_idx] + ), + "past_categorical_features": Tensor( + dataset.past_categorical[batch_idx] + ), + "future_categorical_features": Tensor( + dataset.future_categorical[batch_idx] + ), + "static_categorical_features": Tensor( + dataset.static_categorical[batch_idx] + ), + "targets": Tensor(dataset.targets[batch_idx]), + } - for batch in validation_batches: - combined_input, targets, batch_size = self._combine_input_features( - batch + combined_input, targets, batch_size_actual = ( + self._combine_input_features(batch) ) if targets is None: @@ -490,14 +534,19 @@ def validate(self, validation_batches: list) -> float: raise ValueError(message) predictions = self.forward(combined_input) - - targets_reshaped = targets.reshape(batch_size, self.output_length) - - loss = quantile_loss(predictions, targets_reshaped, self.quantiles) + targets_reshaped = targets.reshape( + batch_size_actual, self.output_length + ) + loss = quantile_loss( + predictions, + targets_reshaped, + self.quantiles, + huber_delta=self.huber_delta, + ) validation_losses.append(loss.numpy().item()) if not validation_losses: - logger.warning("No validation batches provided; returning NaN loss") + logger.warning("No validation batches processed; returning NaN loss") return float("nan") return sum(validation_losses) / len(validation_losses) diff --git a/models/tide/src/tide/run.py b/models/tide/src/tide/run.py deleted file mode 100644 index 620149c79..000000000 --- a/models/tide/src/tide/run.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -import sys - -import structlog - -from tide.workflow import training_pipeline - -logger = structlog.get_logger() - - -def run_training_job( - base_url: str, - data_bucket: str, - artifacts_bucket: str, - lookback_days: int = 365, -) -> str: - """Run the TiDE training pipeline via Prefect.""" - if lookback_days <= 0: - message = "lookback_days must be positive" - raise ValueError(message) - - logger.info( - "Starting training pipeline", - base_url=base_url, - data_bucket=data_bucket, - artifacts_bucket=artifacts_bucket, - lookback_days=lookback_days, - ) - - artifact_path = training_pipeline( - base_url=base_url, - data_bucket=data_bucket, - artifacts_bucket=artifacts_bucket, - lookback_days=lookback_days, - ) - - logger.info("Training pipeline complete", artifact_path=artifact_path) - - return artifact_path - - -if __name__ == "__main__": - base_url = os.getenv("FUND_DATA_MANAGER_BASE_URL", "") - data_bucket = os.getenv("AWS_S3_DATA_BUCKET_NAME", "") - artifacts_bucket = os.getenv("AWS_S3_MODEL_ARTIFACTS_BUCKET_NAME", "") - - required_vars = { - "FUND_DATA_MANAGER_BASE_URL": base_url, - "AWS_S3_DATA_BUCKET_NAME": data_bucket, - "AWS_S3_MODEL_ARTIFACTS_BUCKET_NAME": artifacts_bucket, - } - - missing = [key for key, value in required_vars.items() if not value] - if missing: - logger.error("Missing required environment variables", missing=missing) - sys.exit(1) - - try: - lookback_days = int(os.getenv("FUND_LOOKBACK_DAYS", "365")) - except ValueError: - logger.exception("FUND_LOOKBACK_DAYS must be a valid integer") - sys.exit(1) - - if lookback_days <= 0: - logger.error("FUND_LOOKBACK_DAYS must be positive", lookback_days=lookback_days) - sys.exit(1) - - try: - run_training_job( - base_url=base_url, - data_bucket=data_bucket, - artifacts_bucket=artifacts_bucket, - lookback_days=lookback_days, - ) - except Exception as e: - logger.exception("Training pipeline failed", error=str(e)) - sys.exit(1) diff --git a/models/tide/src/tide/tasks.py b/models/tide/src/tide/tasks.py index 75398c976..3630ff692 100644 --- a/models/tide/src/tide/tasks.py +++ b/models/tide/src/tide/tasks.py @@ -6,6 +6,7 @@ import polars as pl import structlog from botocore.exceptions import ClientError +from internal.equity_bars_schema import equity_bars_schema if TYPE_CHECKING: from mypy_boto3_s3 import S3Client @@ -15,6 +16,16 @@ MINIMUM_CLOSE_PRICE = 1.0 MINIMUM_VOLUME = 100_000 +_COLUMN_TYPES: dict[str, type[pl.DataType]] = { + "open_price": pl.Float64, + "high_price": pl.Float64, + "low_price": pl.Float64, + "close_price": pl.Float64, + "volume_weighted_average_price": pl.Float64, + "volume": pl.Int64, + "transactions": pl.Int64, +} + def read_equity_bars_from_s3( s3_client: "S3Client", @@ -47,6 +58,13 @@ def read_equity_bars_from_s3( response = s3_client.get_object(Bucket=bucket_name, Key=key) parquet_bytes = response["Body"].read() dataframe = pl.read_parquet(parquet_bytes) + dataframe = dataframe.with_columns( + [ + pl.col(col).cast(dtype) + for col, dtype in _COLUMN_TYPES.items() + if col in dataframe.columns + ] + ) batch_dataframes.append(dataframe) logger.debug("Read parquet file", key=key, rows=dataframe.height) except s3_client.exceptions.NoSuchKey: @@ -110,6 +128,7 @@ def filter_equity_bars( filtered = data.filter( (pl.col("close_price") >= minimum_close_price) & (pl.col("volume") >= minimum_volume) + & ~pl.col("ticker").str.contains("[a-z]") ) logger.info("Filtered equity bars", output_rows=filtered.height) @@ -151,7 +170,12 @@ def consolidate_data( if missing_columns: logger.warning("Missing columns in consolidated data", missing=missing_columns) - result = consolidated.select(available_columns) + if "sector" in available_columns and "industry" in available_columns: + result = consolidated.select(available_columns).filter( + pl.col("sector").is_not_null() & pl.col("industry").is_not_null() + ) + else: + result = consolidated.select(available_columns) logger.info( "Consolidated data", output_rows=result.height, columns=available_columns @@ -225,6 +249,8 @@ def prepare_training_data( # noqa: PLR0913 filtered_bars = filter_equity_bars(equity_bars) + equity_bars_schema.validate(filtered_bars) + consolidated = consolidate_data( equity_bars=filtered_bars, categories=categories, diff --git a/models/tide/src/tide/trainer.py b/models/tide/src/tide/trainer.py index 71c273f59..d91f64144 100644 --- a/models/tide/src/tide/trainer.py +++ b/models/tide/src/tide/trainer.py @@ -3,25 +3,25 @@ import polars as pl import structlog -from tide.tide_data import Data -from tide.tide_model import Model +from tide.data import Data +from tide.model import Model from tide.tracking import log_epoch_loss, log_training_result logger = structlog.get_logger() DEFAULT_CONFIGURATION: dict[str, Any] = { "architecture": "TiDE", - "learning_rate": 0.0005, - "epoch_count": 200, - "early_stopping_patience": 25, + "learning_rate": 0.001, + "epoch_count": 20, + "early_stopping_patience": 3, "validation_split": 0.8, "input_length": 35, - "output_length": 7, + "output_length": 5, "hidden_size": 64, "num_encoder_layers": 3, "num_decoder_layers": 2, "dropout_rate": 0.1, - "batch_size": 32, + "batch_size": 512, "huber_delta": 0.5, "quantiles": [0.1, 0.5, 0.9], } @@ -50,51 +50,63 @@ def train_model( dimensions = tide_data.get_dimensions() logger.info("Data dimensions", **dimensions) - logger.info("Creating training batches") - train_batches = tide_data.get_batches( + logger.info("Creating training dataset") + train_dataset = tide_data.get_dataset( data_type="train", validation_split=float(configuration["validation_split"]), input_length=int(configuration["input_length"]), output_length=int(configuration["output_length"]), - batch_size=int(configuration["batch_size"]), ) - logger.info("Training batches created", batch_count=len(train_batches)) + logger.info("Creating validation dataset") + try: + validation_dataset = tide_data.get_dataset( + data_type="validate", + validation_split=float(configuration["validation_split"]), + input_length=int(configuration["input_length"]), + output_length=int(configuration["output_length"]), + ) + except ValueError as e: + if "Total days available" not in str(e): + raise + logger.warning( + "Validation set too small for windowing; disabling validation early stopping" # noqa: E501 + ) + validation_dataset = None + + logger.info("Training dataset created", sample_count=len(train_dataset)) - if not train_batches: + if len(train_dataset) == 0: logger.error( - "No training batches created", + "No training samples created", validation_split=configuration["validation_split"], input_length=configuration["input_length"], output_length=configuration["output_length"], - batch_size=configuration["batch_size"], training_data_rows=training_data.height, ) message = ( - "No training batches created - check input data and configuration. " + "No training samples created - check input data and configuration. " f"Training data has {training_data.height} rows, " f"input_length={configuration['input_length']}, " - f"output_length={configuration['output_length']}, " - f"batch_size={configuration['batch_size']}" + f"output_length={configuration['output_length']}" ) raise ValueError(message) - sample_batch = train_batches[0] - - batch_size = sample_batch["past_continuous_features"].shape[0] - logger.info("Batch size determined", batch_size=batch_size) - + # Compute input_size from numpy array shapes: [N, time_steps, features] past_continuous_size = ( - sample_batch["past_continuous_features"].reshape(batch_size, -1).shape[1] + train_dataset.past_continuous.shape[1] * train_dataset.past_continuous.shape[2] ) past_categorical_size = ( - sample_batch["past_categorical_features"].reshape(batch_size, -1).shape[1] + train_dataset.past_categorical.shape[1] + * train_dataset.past_categorical.shape[2] ) future_categorical_size = ( - sample_batch["future_categorical_features"].reshape(batch_size, -1).shape[1] + train_dataset.future_categorical.shape[1] + * train_dataset.future_categorical.shape[2] ) static_categorical_size = ( - sample_batch["static_categorical_features"].reshape(batch_size, -1).shape[1] + train_dataset.static_categorical.shape[1] + * train_dataset.static_categorical.shape[2] ) input_size = cast( @@ -123,9 +135,11 @@ def train_model( early_stopping_patience = configuration.get("early_stopping_patience", 25) losses = tide_model.train( - train_batches=train_batches, + dataset=train_dataset, + batch_size=int(configuration["batch_size"]), epochs=int(configuration["epoch_count"]), learning_rate=float(configuration["learning_rate"]), + validation_dataset=validation_dataset, checkpoint_directory=checkpoint_directory, early_stopping_patience=( int(early_stopping_patience) diff --git a/models/tide/src/tide/workflow.py b/models/tide/src/tide/workflow.py index 7f50d9e7d..e26d8ee38 100644 --- a/models/tide/src/tide/workflow.py +++ b/models/tide/src/tide/workflow.py @@ -83,7 +83,7 @@ def prepare_data( return output_key -@task(name="train-tide-model", timeout_seconds=3600) +@task(name="train-tide-model", timeout_seconds=14400) def train_tide_model( training_data_key: str = "training/filtered_tide_training_data.parquet", ) -> str: diff --git a/models/tide/tests/test_tide_data.py b/models/tide/tests/test_data.py similarity index 72% rename from models/tide/tests/test_tide_data.py rename to models/tide/tests/test_data.py index 9411fbee5..34c717544 100644 --- a/models/tide/tests/test_tide_data.py +++ b/models/tide/tests/test_data.py @@ -2,27 +2,23 @@ import tempfile from collections.abc import Callable from pathlib import Path -from typing import TYPE_CHECKING import polars as pl import pytest -from tide.tide_data import ( +from tide.data import ( CleanData, Data, EngineerFeatures, - ExpandDateRange, - FillNulls, Pipeline, ScaleAndEncode, + TrainingDataset, ValidateColumns, ) -if TYPE_CHECKING: - from datetime import date - -FRIDAY_WEEKDAY = 4 EXPECTED_PAST_CONTINUOUS_FEATURES = 7 -EXPECTED_PAST_CATEGORICAL_FEATURES = 6 +EXPECTED_INPUT_LENGTH = 35 +EXPECTED_OUTPUT_LENGTH = 5 +EXPECTED_PAST_CATEGORICAL_FEATURES = 5 EXPECTED_STATIC_CATEGORICAL_FEATURES = 3 @@ -50,47 +46,11 @@ def test_validate_columns_rejects_extra_column( ValidateColumns().run(data) -def test_expand_date_range_fills_gaps( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - data = make_raw_data(tickers=["AAPL"], days=10) - expanded = ExpandDateRange().run(data) - unique_dates = expanded.select("date").unique().height - min_date: date = expanded.select(pl.col("date").min()).item() - max_date: date = expanded.select(pl.col("date").max()).item() - expected_dates = (max_date - min_date).days + 1 - assert unique_dates == expected_dates - - -def test_fill_nulls_replaces_null_prices( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - data = make_raw_data(tickers=["AAPL"], days=10) - expanded = ExpandDateRange().run(data) - filled = FillNulls().run(expanded) - null_count = filled.select(pl.col("open_price").is_null().sum()).item() - assert null_count == 0 - - -def test_fill_nulls_sets_holidays_for_missing_weekdays( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - data = make_raw_data(tickers=["AAPL"], days=10) - expanded = ExpandDateRange().run(data) - filled = FillNulls().run(expanded) - weekday_nulls = filled.filter( - (pl.col("date").dt.weekday() <= FRIDAY_WEEKDAY) & pl.col("is_holiday").is_null() - ) - assert weekday_nulls.height == 0 - - def test_engineer_features_adds_calendar_columns( make_raw_data: Callable[..., pl.DataFrame], ) -> None: data = make_raw_data(tickers=["AAPL"], days=10) - expanded = ExpandDateRange().run(data) - filled = FillNulls().run(expanded) - featured = EngineerFeatures().run(filled) + featured = EngineerFeatures().run(data) expected_columns = { "day_of_week", "day_of_month", @@ -99,6 +59,7 @@ def test_engineer_features_adds_calendar_columns( "year", "time_idx", "daily_return", + "date", } assert expected_columns.issubset(set(featured.columns)) @@ -107,9 +68,7 @@ def test_engineer_features_time_idx_is_dense_rank( make_raw_data: Callable[..., pl.DataFrame], ) -> None: data = make_raw_data(tickers=["AAPL"], days=10) - expanded = ExpandDateRange().run(data) - filled = FillNulls().run(expanded) - featured = EngineerFeatures().run(filled) + featured = EngineerFeatures().run(data) time_indices = featured.sort("timestamp").select("time_idx").to_series().to_list() assert time_indices == sorted(time_indices) assert time_indices[0] == 1 @@ -119,10 +78,7 @@ def test_clean_data_removes_unknown_tickers( make_raw_data: Callable[..., pl.DataFrame], ) -> None: data = make_raw_data(tickers=["AAPL"], days=10) - expanded = ExpandDateRange().run(data) - filled = FillNulls().run(expanded) - featured = EngineerFeatures().run(filled) - # Add an UNKNOWN ticker row + featured = EngineerFeatures().run(data) unknown_row = featured.head(1).with_columns(pl.lit("UNKNOWN").alias("ticker")) featured_with_unknown = pl.concat([featured, unknown_row]) cleaned = CleanData().run(featured_with_unknown) @@ -133,9 +89,7 @@ def test_clean_data_removes_nan_daily_return( make_raw_data: Callable[..., pl.DataFrame], ) -> None: data = make_raw_data(tickers=["AAPL"], days=10) - expanded = ExpandDateRange().run(data) - filled = FillNulls().run(expanded) - featured = EngineerFeatures().run(filled) + featured = EngineerFeatures().run(data) cleaned = CleanData().run(featured) nan_count = cleaned.filter(pl.col("daily_return").is_nan()).height null_count = cleaned.filter(pl.col("daily_return").is_null()).height @@ -143,6 +97,30 @@ def test_clean_data_removes_nan_daily_return( assert null_count == 0 +def test_clean_data_removes_rows_with_invalid_values_in_multiple_columns( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + data = make_raw_data(tickers=["AAPL"], days=10) + featured = EngineerFeatures().run(data) + # Inject bad values into two different continuous columns on the same row + bad_row = featured.head(1).with_columns( + pl.lit(float("nan")).alias("open_price"), + pl.lit(float("inf")).alias("close_price"), + ) + # Inject a bad value into only a later continuous column on a different row + bad_row_later = featured.slice(1, 1).with_columns( + pl.lit(float("nan")).alias("volume_weighted_average_price"), + ) + combined = pl.concat([featured.slice(2), bad_row, bad_row_later]) + + cleaned = CleanData().run(combined) + + assert cleaned.filter(pl.col("open_price").is_nan()).height == 0 + assert cleaned.filter(pl.col("close_price").is_infinite()).height == 0 + assert cleaned.filter(pl.col("volume_weighted_average_price").is_nan()).height == 0 + assert cleaned.height == featured.height - 2 + + def test_scale_and_encode_produces_scaler_and_mappings( make_raw_data: Callable[..., pl.DataFrame], ) -> None: @@ -176,7 +154,7 @@ def test_pipeline_run_to_stops_at_stage( ) -> None: data = make_raw_data() pipeline = Pipeline() - result = pipeline.run_to("fill_nulls", data) + result = pipeline.run_to("validate_columns", data) assert "day_of_week" not in result.columns @@ -191,7 +169,7 @@ def test_pipeline_snapshot_roundtrip( ) -> None: data = make_raw_data(tickers=["AAPL"], days=10) pipeline = Pipeline() - result = pipeline.run_to("fill_nulls", data) + result = pipeline.run_to("engineer_features", data) snapshot_path = str(tmp_path / "snapshot.parquet") Pipeline.snapshot(result, snapshot_path) loaded = Pipeline.load_snapshot(snapshot_path) @@ -228,7 +206,6 @@ def test_data_preprocess_output_columns_match_schema( "sector", "industry", "date", - "is_holiday", "day_of_week", "day_of_month", "day_of_year", @@ -273,26 +250,67 @@ def test_data_get_dimensions( ) -def test_data_get_batches_train( +def test_data_get_dataset_train( make_raw_data: Callable[..., pl.DataFrame], ) -> None: raw = make_raw_data(days=90) data = Data() data.preprocess_and_set_data(raw) - batches = data.get_batches( + dataset = data.get_dataset( data_type="train", validation_split=0.8, input_length=35, - output_length=7, - batch_size=32, + output_length=5, + ) + assert isinstance(dataset, TrainingDataset) + assert len(dataset) > 0 + assert dataset.past_continuous.shape[1] == EXPECTED_INPUT_LENGTH + assert dataset.past_categorical.shape[1] == EXPECTED_INPUT_LENGTH + assert dataset.future_categorical.shape[1] == EXPECTED_OUTPUT_LENGTH + assert dataset.static_categorical.shape[2] == EXPECTED_STATIC_CATEGORICAL_FEATURES + assert dataset.targets is not None + + +def test_data_get_dataset_validate( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + # Need enough days so the 20% validation split has > input_length+output_length days + raw = make_raw_data(days=300) + data = Data() + data.preprocess_and_set_data(raw) + dataset = data.get_dataset( + data_type="validate", + validation_split=0.8, + input_length=35, + output_length=5, + ) + assert isinstance(dataset, TrainingDataset) + assert dataset.targets is not None + + +def test_data_get_dataset_predict_has_no_targets( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + raw = make_raw_data(days=90) + data = Data() + data.preprocess_and_set_data(raw) + dataset = data.get_dataset( + data_type="predict", + input_length=35, + output_length=5, ) - assert len(batches) > 0 - batch = batches[0] - assert "past_continuous_features" in batch - assert "past_categorical_features" in batch - assert "future_categorical_features" in batch - assert "static_categorical_features" in batch - assert "targets" in batch + assert isinstance(dataset, TrainingDataset) + assert dataset.targets is None + + +def test_data_get_dataset_invalid_type( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + raw = make_raw_data(days=90) + data = Data() + data.preprocess_and_set_data(raw) + with pytest.raises(ValueError, match="Invalid data type"): + data.get_dataset(data_type="invalid") def test_scale_and_encode_raises_on_nan_scaler( diff --git a/models/tide/tests/test_deploy.py b/models/tide/tests/test_deploy.py index 2cedb7513..06155ff48 100644 --- a/models/tide/tests/test_deploy.py +++ b/models/tide/tests/test_deploy.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock, patch -from prefect.flows import EntrypointType +from prefect.schedules import Schedule from tide.deploy import deploy_training_flow LOOKBACK_DAYS = 30 @@ -17,21 +17,24 @@ def test_deploy_training_flow_calls_deploy(mock_pipeline: MagicMock) -> None: call_kwargs = mock_deploy.call_args.kwargs assert call_kwargs["name"] == "tide-trainer-remote" assert call_kwargs["work_pool_name"] == "fund-models-remote" - assert call_kwargs["cron"] == "0 22 * * 1-5" - assert call_kwargs["timezone"] == "America/New_York" + schedule = call_kwargs["schedule"] + assert isinstance(schedule, Schedule) + assert schedule.cron == "0 22 * * 1-5" + assert schedule.timezone == "America/New_York" assert call_kwargs["parameters"]["lookback_days"] == LOOKBACK_DAYS @patch("tide.deploy.training_pipeline") -def test_deploy_training_flow_sets_module_path_entrypoint( +def test_deploy_training_flow_sets_build_options( mock_pipeline: MagicMock, ) -> None: mock_deploy = MagicMock() mock_pipeline.deploy = mock_deploy - deploy_training_flow() + image = "123456789.dkr.ecr.us-east-1.amazonaws.com/fund/tide-model-runner:latest" + deploy_training_flow(image=image) call_kwargs = mock_deploy.call_args.kwargs - assert call_kwargs["entrypoint_type"] == EntrypointType.MODULE_PATH + assert call_kwargs["image"] == image assert call_kwargs["build"] is False assert call_kwargs["push"] is False diff --git a/models/tide/tests/test_tide_model.py b/models/tide/tests/test_model.py similarity index 70% rename from models/tide/tests/test_tide_model.py rename to models/tide/tests/test_model.py index fcc1429e1..e22445b92 100644 --- a/models/tide/tests/test_tide_model.py +++ b/models/tide/tests/test_model.py @@ -1,16 +1,16 @@ import tempfile -from dataclasses import dataclass import numpy as np import pytest -from tide.tide_model import Model, quantile_loss +from tide.data import TrainingDataset +from tide.model import Model, quantile_loss from tinygrad.tensor import Tensor BATCH_SIZE = 4 INPUT_LENGTH = 35 -OUTPUT_LENGTH = 7 +OUTPUT_LENGTH = 5 CONTINUOUS_FEATURES = 7 -CATEGORICAL_FEATURES = 6 +CATEGORICAL_FEATURES = 5 STATIC_FEATURES = 3 HIDDEN_SIZE = 32 NUM_QUANTILES = 3 @@ -24,56 +24,71 @@ rng = np.random.default_rng(42) -@dataclass -class BatchConfig: - batch_size: int = BATCH_SIZE - input_length: int = INPUT_LENGTH - output_length: int = OUTPUT_LENGTH - continuous_features: int = CONTINUOUS_FEATURES - categorical_features: int = CATEGORICAL_FEATURES - static_features: int = STATIC_FEATURES - - -def _make_batch( - config: BatchConfig | None = None, +def _make_dataset( + num_samples: int = BATCH_SIZE, *, include_targets: bool = True, -) -> dict[str, Tensor]: - if config is None: - config = BatchConfig() +) -> TrainingDataset: + return TrainingDataset( + past_continuous=rng.standard_normal( + (num_samples, INPUT_LENGTH, CONTINUOUS_FEATURES) + ).astype(np.float32), + past_categorical=rng.integers( + 0, + CATEGORICAL_UPPER_BOUND, + (num_samples, INPUT_LENGTH, CATEGORICAL_FEATURES), + ).astype(np.int32), + future_categorical=rng.integers( + 0, + CATEGORICAL_UPPER_BOUND, + (num_samples, OUTPUT_LENGTH, CATEGORICAL_FEATURES), + ).astype(np.int32), + static_categorical=rng.integers( + 0, + CATEGORICAL_UPPER_BOUND, + (num_samples, 1, STATIC_FEATURES), + ).astype(np.int32), + targets=( + rng.standard_normal((num_samples, OUTPUT_LENGTH, 1)).astype(np.float32) + if include_targets + else None + ), + ) + + +def _make_batch(*, include_targets: bool = True) -> dict[str, Tensor]: + """Create a single dict[str, Tensor] batch for forward/predict tests.""" batch: dict[str, Tensor] = { "past_continuous_features": Tensor( - rng.standard_normal( - (config.batch_size, config.input_length, config.continuous_features) - ).astype(np.float32) + rng.standard_normal((BATCH_SIZE, INPUT_LENGTH, CONTINUOUS_FEATURES)).astype( + np.float32 + ) ), "past_categorical_features": Tensor( rng.integers( 0, CATEGORICAL_UPPER_BOUND, - (config.batch_size, config.input_length, config.categorical_features), + (BATCH_SIZE, INPUT_LENGTH, CATEGORICAL_FEATURES), ).astype(np.int32) ), "future_categorical_features": Tensor( rng.integers( 0, CATEGORICAL_UPPER_BOUND, - (config.batch_size, config.output_length, config.categorical_features), + (BATCH_SIZE, OUTPUT_LENGTH, CATEGORICAL_FEATURES), ).astype(np.int32) ), "static_categorical_features": Tensor( rng.integers( 0, CATEGORICAL_UPPER_BOUND, - (config.batch_size, 1, config.static_features), + (BATCH_SIZE, 1, STATIC_FEATURES), ).astype(np.int32) ), } if include_targets: batch["targets"] = Tensor( - rng.standard_normal((config.batch_size, config.output_length, 1)).astype( - np.float32 - ) + rng.standard_normal((BATCH_SIZE, OUTPUT_LENGTH, 1)).astype(np.float32) ) return batch @@ -167,9 +182,9 @@ def test_model_train_with_huber_delta() -> None: output_length=OUTPUT_LENGTH, huber_delta=0.5, ) - batches = [_make_batch()] + dataset = _make_dataset() losses = model.train( - train_batches=batches, + dataset=dataset, epochs=EPOCHS_SHORT, learning_rate=LEARNING_RATE, validate_data=False, @@ -194,9 +209,9 @@ def test_model_train_returns_losses() -> None: model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - batches = [_make_batch()] + dataset = _make_dataset() losses = model.train( - train_batches=batches, + dataset=dataset, epochs=EPOCHS_SHORT, learning_rate=LEARNING_RATE, validate_data=False, @@ -205,48 +220,31 @@ def test_model_train_returns_losses() -> None: assert all(isinstance(loss, float) for loss in losses) -def test_model_train_empty_batch_list() -> None: +def test_model_train_empty_dataset() -> None: input_size = _compute_input_size() model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - losses = model.train( - train_batches=[], - epochs=EPOCHS_SHORT, - learning_rate=LEARNING_RATE, - validate_data=False, - ) - assert losses == [] - - -def test_model_train_skips_zero_size_batch() -> None: - input_size = _compute_input_size() - model = Model( - input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH - ) - empty_batch = { - "past_continuous_features": Tensor( - np.zeros((0, INPUT_LENGTH, CONTINUOUS_FEATURES), dtype=np.float32) + empty_dataset = TrainingDataset( + past_continuous=np.zeros( + (0, INPUT_LENGTH, CONTINUOUS_FEATURES), dtype=np.float32 ), - "past_categorical_features": Tensor( - np.zeros((0, INPUT_LENGTH, CATEGORICAL_FEATURES), dtype=np.int32) + past_categorical=np.zeros( + (0, INPUT_LENGTH, CATEGORICAL_FEATURES), dtype=np.int32 ), - "future_categorical_features": Tensor( - np.zeros((0, OUTPUT_LENGTH, CATEGORICAL_FEATURES), dtype=np.int32) - ), - "static_categorical_features": Tensor( - np.zeros((0, 1, STATIC_FEATURES), dtype=np.int32) + future_categorical=np.zeros( + (0, OUTPUT_LENGTH, CATEGORICAL_FEATURES), dtype=np.int32 ), - "targets": Tensor(np.zeros((0, OUTPUT_LENGTH, 1), dtype=np.float32)), - } - normal_batch = _make_batch(BatchConfig(batch_size=BATCH_SIZE)) + static_categorical=np.zeros((0, 1, STATIC_FEATURES), dtype=np.int32), + targets=np.zeros((0, OUTPUT_LENGTH, 1), dtype=np.float32), + ) losses = model.train( - train_batches=[empty_batch, normal_batch], - epochs=EPOCH_SINGLE, + dataset=empty_dataset, + epochs=EPOCHS_SHORT, learning_rate=LEARNING_RATE, validate_data=False, ) - assert len(losses) == EPOCH_SINGLE + assert losses == [] def test_model_train_missing_targets_raises() -> None: @@ -254,44 +252,57 @@ def test_model_train_missing_targets_raises() -> None: model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - batch = _make_batch(include_targets=False) + dataset = _make_dataset(include_targets=False) with pytest.raises(ValueError, match="Targets are required"): model.train( - train_batches=[batch], + dataset=dataset, epochs=EPOCH_SINGLE, learning_rate=LEARNING_RATE, validate_data=False, ) -def test_model_validate_returns_loss() -> None: +def test_model_validate_model_returns_loss() -> None: input_size = _compute_input_size() model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - batches = [_make_batch()] - loss = model.validate(batches) + dataset = _make_dataset() + loss = model.validate_model(dataset) assert isinstance(loss, float) assert loss >= 0 -def test_model_validate_empty_batches_returns_nan() -> None: +def test_model_validate_model_empty_dataset_returns_nan() -> None: input_size = _compute_input_size() model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - loss = model.validate([]) + empty_dataset = TrainingDataset( + past_continuous=np.zeros( + (0, INPUT_LENGTH, CONTINUOUS_FEATURES), dtype=np.float32 + ), + past_categorical=np.zeros( + (0, INPUT_LENGTH, CATEGORICAL_FEATURES), dtype=np.int32 + ), + future_categorical=np.zeros( + (0, OUTPUT_LENGTH, CATEGORICAL_FEATURES), dtype=np.int32 + ), + static_categorical=np.zeros((0, 1, STATIC_FEATURES), dtype=np.int32), + targets=np.zeros((0, OUTPUT_LENGTH, 1), dtype=np.float32), + ) + loss = model.validate_model(empty_dataset) assert np.isnan(loss) -def test_model_validate_missing_targets_raises() -> None: +def test_model_validate_model_missing_targets_raises() -> None: input_size = _compute_input_size() model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - batch = _make_batch(include_targets=False) + dataset = _make_dataset(include_targets=False) with pytest.raises(ValueError, match="Targets are required"): - model.validate([batch]) + model.validate_model(dataset) def test_model_predict_output_shape() -> None: @@ -323,9 +334,9 @@ def test_model_early_stopping() -> None: model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - batches = [_make_batch()] + dataset = _make_dataset() losses = model.train( - train_batches=batches, + dataset=dataset, epochs=EPOCHS_LONG, learning_rate=LEARNING_RATE, validate_data=False, @@ -341,23 +352,23 @@ def test_model_validation_sample_size_must_be_positive() -> None: ) with pytest.raises(ValueError, match="positive"): model.train( - train_batches=[_make_batch()], + dataset=_make_dataset(), epochs=EPOCH_SINGLE, validate_data=True, validation_sample_size=0, ) -def test_model_validate_restores_training_state() -> None: +def test_model_validate_model_restores_training_state() -> None: input_size = _compute_input_size() model = Model( input_size=input_size, hidden_size=HIDDEN_SIZE, output_length=OUTPUT_LENGTH ) - batches = [_make_batch()] + dataset = _make_dataset() Tensor.training = True - model.validate(batches) + model.validate_model(dataset) assert Tensor.training is True Tensor.training = False - model.validate(batches) + model.validate_model(dataset) assert Tensor.training is False diff --git a/models/tide/tests/test_run.py b/models/tide/tests/test_run.py deleted file mode 100644 index 2a285dccb..000000000 --- a/models/tide/tests/test_run.py +++ /dev/null @@ -1,67 +0,0 @@ -from unittest.mock import patch - -import pytest -from tide.run import run_training_job - - -def test_run_training_job_calls_training_pipeline() -> None: - with patch( - "tide.run.training_pipeline", - return_value="s3://bucket/artifacts/model.tar.gz", - ) as mock_pipeline: - result = run_training_job( - base_url="http://datamanager:8080", - data_bucket="fund-data-bucket", - artifacts_bucket="fund-artifacts-bucket", - lookback_days=365, - ) - - mock_pipeline.assert_called_once_with( - base_url="http://datamanager:8080", - data_bucket="fund-data-bucket", - artifacts_bucket="fund-artifacts-bucket", - lookback_days=365, - ) - assert result == "s3://bucket/artifacts/model.tar.gz" - - -def test_run_training_job_returns_artifact_path() -> None: - expected_path = ( - "s3://my-bucket/artifacts/tide-trainer-2024-01-01/output/model.tar.gz" - ) - with patch( - "tide.run.training_pipeline", - return_value=expected_path, - ): - result = run_training_job( - base_url="http://datamanager:8080", - data_bucket="fund-data-bucket", - artifacts_bucket="fund-artifacts-bucket", - ) - - assert result == expected_path - - -def test_run_training_job_propagates_errors() -> None: - with ( - patch( - "tide.run.training_pipeline", - side_effect=RuntimeError("Training failed"), - ), - pytest.raises(RuntimeError, match="Training failed"), - ): - run_training_job( - base_url="http://datamanager:8080", - data_bucket="fund-data-bucket", - artifacts_bucket="fund-artifacts-bucket", - ) - - -def test_run_training_job_rejects_non_positive_lookback() -> None: - with pytest.raises(ValueError, match="lookback_days must be positive"): - run_training_job( - base_url="http://datamanager:8080", - data_bucket="fund-data-bucket", - artifacts_bucket="fund-artifacts-bucket", - lookback_days=0, - ) diff --git a/models/tide/tests/test_tasks.py b/models/tide/tests/test_tasks.py index 9a9f7b828..b86b25e90 100644 --- a/models/tide/tests/test_tasks.py +++ b/models/tide/tests/test_tasks.py @@ -1,5 +1,5 @@ import io -from datetime import UTC, datetime +from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, patch import polars as pl @@ -19,21 +19,22 @@ _SAMPLE_EQUITY_BARS = pl.DataFrame( { "ticker": ["AAPL"], - "timestamp": [_TARGET_DATE], + "timestamp": [int(_TARGET_DATE.timestamp()) * 1000], "open_price": [148.0], "high_price": [152.0], "low_price": [147.0], "close_price": [150.0], "volume": [1_000_000], "volume_weighted_average_price": [151.0], + "transactions": [5_000], } ) _SAMPLE_CATEGORIES = pl.DataFrame( { - "ticker": ["AAPL"], - "sector": ["Technology"], - "industry": ["Consumer Electronics"], + "ticker": ["AAPL", "DLNGpB"], + "sector": ["Technology", "Energy"], + "industry": ["Consumer Electronics", "Oil & Gas"], } ) @@ -51,6 +52,7 @@ def _to_csv_bytes(data: pl.DataFrame) -> bytes: def test_filter_equity_bars_keeps_rows_above_thresholds() -> None: data = pl.DataFrame( { + "ticker": ["AAPL", "LOW"], "close_price": [MINIMUM_CLOSE_PRICE + 1.0, 0.5], "volume": [MINIMUM_VOLUME + 1, 50_000], } @@ -62,9 +64,53 @@ def test_filter_equity_bars_keeps_rows_above_thresholds() -> None: assert result["close_price"][0] == MINIMUM_CLOSE_PRICE + 1.0 +def test_filter_equity_bars_excludes_preferred_stocks() -> None: + data = pl.DataFrame( + { + "ticker": ["AAPL", "JPMpC", "NEEpR"], + "close_price": [ + MINIMUM_CLOSE_PRICE + 1.0, + MINIMUM_CLOSE_PRICE + 1.0, + MINIMUM_CLOSE_PRICE + 1.0, + ], + "volume": [MINIMUM_VOLUME + 1, MINIMUM_VOLUME + 1, MINIMUM_VOLUME + 1], + } + ) + + result = filter_equity_bars(data) + + assert len(result) == 1 + assert result["ticker"][0] == "AAPL" + + +def test_filter_equity_bars_excludes_warrants() -> None: + data = pl.DataFrame( + { + "ticker": ["AAPL", "RALw", "FTVw", "DDw"], + "close_price": [ + MINIMUM_CLOSE_PRICE + 1.0, + MINIMUM_CLOSE_PRICE + 1.0, + MINIMUM_CLOSE_PRICE + 1.0, + MINIMUM_CLOSE_PRICE + 1.0, + ], + "volume": [ + MINIMUM_VOLUME + 1, + MINIMUM_VOLUME + 1, + MINIMUM_VOLUME + 1, + MINIMUM_VOLUME + 1, + ], + } + ) + + result = filter_equity_bars(data) + + assert len(result) == 1 + assert result["ticker"][0] == "AAPL" + + def test_filter_equity_bars_empty_input_returns_empty() -> None: - data = pl.DataFrame({"close_price": [], "volume": []}).cast( - {"close_price": pl.Float64, "volume": pl.Int64} + data = pl.DataFrame({"ticker": [], "close_price": [], "volume": []}).cast( + {"ticker": pl.String, "close_price": pl.Float64, "volume": pl.Int64} ) result = filter_equity_bars(data) @@ -81,6 +127,38 @@ def test_consolidate_data_joins_on_ticker_and_retains_columns() -> None: assert "industry" in result.columns +def test_consolidate_data_excludes_tickers_with_null_sector_or_industry() -> None: + categories = pl.DataFrame( + { + "ticker": ["AAPL", "MSFT", "GOOG"], + "sector": ["Technology", None, "Technology"], + "industry": ["Consumer Electronics", "Software", None], + } + ) + equity_bars = pl.DataFrame( + { + "ticker": ["AAPL", "MSFT", "GOOG"], + "timestamp": [ + int(_TARGET_DATE.timestamp()) * 1000, + int(_TARGET_DATE.timestamp()) * 1000, + int(_TARGET_DATE.timestamp()) * 1000, + ], + "open_price": [148.0, 200.0, 100.0], + "high_price": [152.0, 205.0, 105.0], + "low_price": [147.0, 198.0, 99.0], + "close_price": [150.0, 202.0, 102.0], + "volume": [1_000_000, 500_000, 750_000], + "volume_weighted_average_price": [151.0, 201.0, 101.0], + "transactions": [5_000, 2_000, 3_000], + } + ) + + result = consolidate_data(equity_bars, categories) + + assert len(result) == 1 + assert result["ticker"][0] == "AAPL" + + def test_consolidate_data_excludes_unmatched_tickers() -> None: categories = pl.DataFrame( { @@ -95,6 +173,33 @@ def test_consolidate_data_excludes_unmatched_tickers() -> None: assert len(result) == 0 +def test_read_equity_bars_from_s3_normalizes_column_types_across_files() -> None: + day1 = _SAMPLE_EQUITY_BARS.with_columns(pl.col("volume").cast(pl.Float64)) + day2 = _SAMPLE_EQUITY_BARS.with_columns(pl.col("volume").cast(pl.Int64)) + + mock_body_1 = MagicMock() + mock_body_1.read.return_value = _to_parquet_bytes(day1) + mock_body_2 = MagicMock() + mock_body_2.read.return_value = _to_parquet_bytes(day2) + + mock_s3_client = MagicMock() + mock_s3_client.get_object.side_effect = [ + {"Body": mock_body_1}, + {"Body": mock_body_2}, + ] + + result = read_equity_bars_from_s3( + s3_client=mock_s3_client, + bucket_name="test-bucket", + start_date=_TARGET_DATE, + end_date=_TARGET_DATE + timedelta(days=1), + ) + + expected_rows = len(day1) + len(day2) + assert len(result) == expected_rows + assert result["volume"].dtype == pl.Int64 + + def test_read_equity_bars_from_s3_returns_dataframe() -> None: parquet_bytes = _to_parquet_bytes(_SAMPLE_EQUITY_BARS) @@ -128,8 +233,8 @@ def test_read_categories_from_s3_returns_dataframe() -> None: bucket_name="test-bucket", ) - assert len(result) == 1 - assert result["ticker"][0] == "AAPL" + assert len(result) == len(_SAMPLE_CATEGORIES) + assert "AAPL" in result["ticker"].to_list() mock_s3_client.get_object.assert_called_once_with( Bucket="test-bucket", Key="equity/details/details.csv", @@ -153,6 +258,53 @@ def test_write_training_data_to_s3_returns_s3_uri() -> None: assert call_kwargs["Key"] == "training/data.parquet" +def test_prepare_training_data_succeeds_when_raw_data_contains_preferred_tickers() -> ( + None +): + preferred_ticker_bars = pl.DataFrame( + { + "ticker": ["AAPL", "DLNGpB"], + "timestamp": [ + int(_TARGET_DATE.timestamp()) * 1000, + int(_TARGET_DATE.timestamp()) * 1000, + ], + "open_price": [148.0, 20.0], + "high_price": [152.0, 21.0], + "low_price": [147.0, 19.0], + "close_price": [150.0, 20.5], + "volume": [1_000_000, 500_000], + "volume_weighted_average_price": [151.0, 20.3], + "transactions": [5_000, 1_000], + } + ) + parquet_bytes = _to_parquet_bytes(preferred_ticker_bars) + csv_bytes = _to_csv_bytes(_SAMPLE_CATEGORIES) + + mock_body_bars = MagicMock() + mock_body_bars.read.return_value = parquet_bytes + mock_body_categories = MagicMock() + mock_body_categories.read.return_value = csv_bytes + + mock_s3_client = MagicMock() + mock_s3_client.get_object.side_effect = [ + {"Body": mock_body_bars}, + {"Body": mock_body_categories}, + ] + + with patch("tide.tasks.boto3.client", return_value=mock_s3_client): + result = prepare_training_data( + data_bucket_name="test-data-bucket", + model_artifacts_bucket_name="test-artifacts-bucket", + start_date=_TARGET_DATE, + end_date=_TARGET_DATE, + ) + + assert result.startswith("s3://test-artifacts-bucket/") + uploaded_bytes = mock_s3_client.put_object.call_args.kwargs["Body"] + uploaded_df = pl.read_parquet(io.BytesIO(uploaded_bytes)) + assert list(uploaded_df["ticker"].unique().sort()) == ["AAPL"] + + def test_prepare_training_data_returns_s3_uri() -> None: parquet_bytes = _to_parquet_bytes(_SAMPLE_EQUITY_BARS) csv_bytes = _to_csv_bytes(_SAMPLE_CATEGORIES) diff --git a/prefect.yaml b/prefect.yaml index 5f9362406..d4ca69e69 100644 --- a/prefect.yaml +++ b/prefect.yaml @@ -5,14 +5,6 @@ pull: branch: master deployments: - - name: tide-trainer-remote - entrypoint: models/tide/src/tide/workflow.py:training_pipeline - work_pool: - name: fund-models-remote - parameters: {} - build: null - push: null - - name: tide-trainer-local entrypoint: models/tide/src/tide/workflow.py:training_pipeline work_pool: diff --git a/tools/src/tools/build_work_pool_template.py b/tools/src/tools/build_work_pool_template.py new file mode 100644 index 000000000..2308c5bce --- /dev/null +++ b/tools/src/tools/build_work_pool_template.py @@ -0,0 +1,109 @@ +import json +import os +import sys +from dataclasses import dataclass + +import structlog + +logger = structlog.get_logger() + + +@dataclass +class NetworkConfig: + vpc_id: str + private_subnet_1_id: str + private_subnet_2_id: str + ecs_security_group_id: str + + +def build_work_pool_template( + template: dict, + cluster: str, + aws_credentials_block_id: str, + task_definition_arn: str, + network: NetworkConfig, +) -> dict: + """Apply ECS GPU work pool configuration to a Prefect base job template.""" + aws_region = os.environ.get("AWS_REGION", "us-east-1") + + variables = template["variables"]["properties"] + variables["cluster"]["default"] = cluster + variables["aws_credentials"]["default"] = { + "$ref": {"block_document_id": aws_credentials_block_id} + } + variables["capacity_provider_strategy"]["default"] = [ + {"capacityProvider": "fund-models-gpu", "weight": 1} + ] + variables["launch_type"]["default"] = None + variables["task_definition_arn"]["default"] = task_definition_arn + variables["vpc_id"]["default"] = network.vpc_id + variables["network_configuration"]["default"] = { + "subnets": [network.private_subnet_1_id, network.private_subnet_2_id], + "securityGroups": [network.ecs_security_group_id], + "assignPublicIp": "DISABLED", + } + + task_def = template.setdefault("job_configuration", {}).setdefault( + "task_definition", {} + ) + containers = task_def.setdefault("containerDefinitions", [{}]) + if not containers: + containers.append({}) + + for container in containers: + container["resourceRequirements"] = [{"type": "GPU", "value": "1"}] + container["logConfiguration"] = { + "logDriver": "awslogs", + "options": { + "awslogs-group": "/ecs/fund/models", + "awslogs-region": aws_region, + "awslogs-stream-prefix": "tide", + }, + } + + return template + + +if __name__ == "__main__": + expected_arg_count = 8 + if len(sys.argv) != expected_arg_count: + usage = ( + "Usage: python -m tools.build_work_pool_template" + " " + " " + " " + ) + logger.error(usage, args_received=len(sys.argv) - 1) + sys.exit(1) + + ( + _, + cluster, + aws_credentials_block_id, + task_definition_arn, + vpc_id, + private_subnet_1_id, + private_subnet_2_id, + ecs_security_group_id, + ) = sys.argv + + try: + template = json.load(sys.stdin) + except json.JSONDecodeError as e: + logger.exception("Failed to parse work pool template JSON", error=f"{e}") + sys.exit(1) + + result = build_work_pool_template( + template=template, + cluster=cluster, + aws_credentials_block_id=aws_credentials_block_id, + task_definition_arn=task_definition_arn, + network=NetworkConfig( + vpc_id=vpc_id, + private_subnet_1_id=private_subnet_1_id, + private_subnet_2_id=private_subnet_2_id, + ecs_security_group_id=ecs_security_group_id, + ), + ) + + sys.stdout.write(json.dumps(result) + "\n") diff --git a/tools/tests/test_build_work_pool_template.py b/tools/tests/test_build_work_pool_template.py new file mode 100644 index 000000000..9b640d7e8 --- /dev/null +++ b/tools/tests/test_build_work_pool_template.py @@ -0,0 +1,125 @@ +import os + +from tools.build_work_pool_template import NetworkConfig, build_work_pool_template + + +def _minimal_template() -> dict: + return { + "variables": { + "properties": { + "cluster": {}, + "aws_credentials": {}, + "capacity_provider_strategy": {}, + "launch_type": {}, + "task_definition_arn": {}, + "vpc_id": {}, + "network_configuration": {}, + } + } + } + + +def _default_network() -> NetworkConfig: + return NetworkConfig( + vpc_id="vpc-123", + private_subnet_1_id="subnet-1", + private_subnet_2_id="subnet-2", + ecs_security_group_id="sg-123", + ) + + +def test_build_work_pool_template_sets_cluster() -> None: + result = build_work_pool_template( + template=_minimal_template(), + cluster="test-cluster", + aws_credentials_block_id="block-id", + task_definition_arn="arn:aws:ecs:us-east-1:123:task-definition/tide:1", + network=_default_network(), + ) + + assert result["variables"]["properties"]["cluster"]["default"] == "test-cluster" + + +def test_build_work_pool_template_sets_aws_credentials_ref() -> None: + result = build_work_pool_template( + template=_minimal_template(), + cluster="test-cluster", + aws_credentials_block_id="block-abc", + task_definition_arn="arn:aws:ecs:us-east-1:123:task-definition/tide:1", + network=_default_network(), + ) + + credentials = result["variables"]["properties"]["aws_credentials"]["default"] + assert credentials == {"$ref": {"block_document_id": "block-abc"}} + + +def test_build_work_pool_template_sets_capacity_provider_and_clears_launch_type() -> ( + None +): + result = build_work_pool_template( + template=_minimal_template(), + cluster="test-cluster", + aws_credentials_block_id="block-id", + task_definition_arn="arn:aws:ecs:us-east-1:123:task-definition/tide:1", + network=_default_network(), + ) + + assert result["variables"]["properties"]["capacity_provider_strategy"][ + "default" + ] == [{"capacityProvider": "fund-models-gpu", "weight": 1}] + assert result["variables"]["properties"]["launch_type"]["default"] is None + + +def test_build_work_pool_template_sets_network_configuration() -> None: + result = build_work_pool_template( + template=_minimal_template(), + cluster="test-cluster", + aws_credentials_block_id="block-id", + task_definition_arn="arn:aws:ecs:us-east-1:123:task-definition/tide:1", + network=NetworkConfig( + vpc_id="vpc-456", + private_subnet_1_id="subnet-a", + private_subnet_2_id="subnet-b", + ecs_security_group_id="sg-789", + ), + ) + + network = result["variables"]["properties"]["network_configuration"]["default"] + assert network["subnets"] == ["subnet-a", "subnet-b"] + assert network["securityGroups"] == ["sg-789"] + assert network["assignPublicIp"] == "DISABLED" + + +def test_build_work_pool_template_configures_gpu_and_logging() -> None: + result = build_work_pool_template( + template=_minimal_template(), + cluster="test-cluster", + aws_credentials_block_id="block-id", + task_definition_arn="arn:aws:ecs:us-east-1:123:task-definition/tide:1", + network=_default_network(), + ) + + containers = result["job_configuration"]["task_definition"]["containerDefinitions"] + assert len(containers) == 1 + assert containers[0]["resourceRequirements"] == [{"type": "GPU", "value": "1"}] + log_opts = containers[0]["logConfiguration"]["options"] + assert log_opts["awslogs-group"] == "/ecs/fund/models" + assert log_opts["awslogs-stream-prefix"] == "tide" + assert log_opts["awslogs-region"] == os.environ.get("AWS_REGION", "us-east-1") + + +def test_build_work_pool_template_populates_empty_containers_list() -> None: + template = _minimal_template() + template["job_configuration"] = {"task_definition": {"containerDefinitions": []}} + + result = build_work_pool_template( + template=template, + cluster="test-cluster", + aws_credentials_block_id="block-id", + task_definition_arn="arn:aws:ecs:us-east-1:123:task-definition/tide:1", + network=_default_network(), + ) + + containers = result["job_configuration"]["task_definition"]["containerDefinitions"] + assert len(containers) == 1 + assert "resourceRequirements" in containers[0] diff --git a/uv.lock b/uv.lock index f4ca669cb..252efd08c 100644 --- a/uv.lock +++ b/uv.lock @@ -685,6 +685,7 @@ dependencies = [ { name = "internal" }, { name = "pandera", extra = ["polars"] }, { name = "polars" }, + { name = "prometheus-client" }, { name = "requests" }, { name = "sentry-sdk", extra = ["fastapi"] }, { name = "structlog" }, @@ -704,6 +705,7 @@ requires-dist = [ { name = "internal", editable = "libraries/python" }, { name = "pandera", extras = ["polars"], specifier = ">=0.26.0" }, { name = "polars", specifier = ">=1.29.0" }, + { name = "prometheus-client", specifier = ">=0.21.0" }, { name = "requests", specifier = ">=2.32.5" }, { name = "sentry-sdk", extras = ["fastapi"], specifier = ">=2.0.0" }, { name = "structlog", specifier = ">=25.5.0" }, @@ -1851,6 +1853,7 @@ dependencies = [ { name = "internal" }, { name = "pandera", extra = ["polars"] }, { name = "polars" }, + { name = "prometheus-client" }, { name = "pytz" }, { name = "requests" }, { name = "scipy" }, @@ -1867,6 +1870,7 @@ requires-dist = [ { name = "internal", editable = "libraries/python" }, { name = "pandera", extras = ["polars"], specifier = ">=0.26.0" }, { name = "polars", specifier = ">=1.29.0" }, + { name = "prometheus-client", specifier = ">=0.21.0" }, { name = "pytz", specifier = ">=2025.1" }, { name = "requests", specifier = ">=2.32.5" }, { name = "scipy", specifier = ">=1.17.1" },