Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .claude/skills/autotrain/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ For code changes (model architecture, loss function, layer structure), run
**one experiment at a time**:

1. **Hypothesize** a single architecture change
2. **Modify** the model code (tide_model.py or tide_data.py)
2. **Modify** the model code (model.py or data.py)
3. **Commit**: `git add -A && git commit -m "autoresearch: <description>"`
4. **Run training**:
```bash
Expand Down Expand Up @@ -153,16 +153,16 @@ Priority order: architecture > hyperparameters > epoch depth > loss function > d
## Files Modified During Loop

Primarily:
- `models/tide/src/tide/tide_model.py` (model architecture)
- `models/tide/src/tide/tide_data.py` (data processing)
- `models/tide/src/tide/model.py` (model architecture)
- `models/tide/src/tide/data.py` (data processing)
- `models/tide/src/tide/trainer.py` (DEFAULT_CONFIGURATION, training hyperparameters)

## Key Files

- **Trainer**: `models/tide/src/tide/trainer.py` (DEFAULT_CONFIGURATION lives here)
- **Workflow**: `models/tide/src/tide/workflow.py`
- **Model**: `models/tide/src/tide/tide_model.py`
- **Data**: `models/tide/src/tide/tide_data.py`
- **Model**: `models/tide/src/tide/model.py`
- **Data**: `models/tide/src/tide/data.py`
- **Experiment log**: `autotrain/<model-name>/experiments.jsonl`
- **Prefect config**: `prefect.yaml`
- **devenv tasks**: `devenv.nix`
11 changes: 3 additions & 8 deletions .github/workflows/launch_infrastructure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,11 @@ jobs:
- name: Install Flox
if: steps.changes.outputs.service == 'true'
uses: flox/install-flox-action@v2
- name: Build ${{ matrix.application }}-${{ matrix.stage }} image
- name: Build and push ${{ matrix.application }}-${{ matrix.stage }} image
if: steps.changes.outputs.service == 'true'
uses: flox/activate-action@v1
with:
command: mask infrastructure image build ${{ matrix.application }} ${{ matrix.stage }}
- name: Push ${{ matrix.application }}-${{ matrix.stage }} image
if: steps.changes.outputs.service == 'true'
uses: flox/activate-action@v1
with:
command: mask infrastructure image push ${{ matrix.application }} ${{ matrix.stage }}
command: mask infrastructure image build-and-push ${{ matrix.application }} ${{ matrix.stage }}
launch_infrastructure:
name: Deploy with Pulumi
needs: build_and_push_images
Expand Down Expand Up @@ -142,4 +137,4 @@ jobs:
env:
PULUMI_ACCESS_TOKEN: ${{ secrets.PULUMI_ACCESS_TOKEN }}
with:
command: mask infrastructure image deploy ${{ matrix.application }} ${{ matrix.stage }}
command: mask infrastructure service update ${{ matrix.application }}
5 changes: 3 additions & 2 deletions applications/data_manager/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::io::Cursor;
use tracing::{debug, error, info, warn};

const EQUITY_DETAILS_KEY: &str = "equity/details/details.csv";
pub const DUCKDB_CONFIG_VALUE_MAX_LENGTH: usize = 4096;

pub async fn write_equity_bars_dataframe_to_s3(
state: &State,
Expand Down Expand Up @@ -130,8 +131,8 @@ pub fn sanitize_duckdb_config_value(value: &str) -> Result<String, Error> {
return Err(Error::Other(message));
}

// Reasonable length limit
if value.len() > 512 {
// Reasonable length limit (4096 accommodates AWS session tokens which can exceed 1000 characters)
if value.len() > DUCKDB_CONFIG_VALUE_MAX_LENGTH {
let message = format!("Configuration value too long: {} characters", value.len());
error!("{}", message);
return Err(Error::Other(message));
Expand Down
6 changes: 4 additions & 2 deletions applications/data_manager/tests/test_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use data_manager::{
query_predictions_dataframe_from_s3, read_equity_details_dataframe_from_s3,
sanitize_duckdb_config_value, write_equity_bars_dataframe_to_s3,
write_equity_details_dataframe_to_s3, write_portfolio_dataframe_to_s3,
write_predictions_dataframe_to_s3, PredictionQuery,
write_predictions_dataframe_to_s3, PredictionQuery, DUCKDB_CONFIG_VALUE_MAX_LENGTH,
},
};
use polars::prelude::*;
Expand Down Expand Up @@ -466,6 +466,8 @@ fn test_sanitize_duckdb_config_value_valid() {
assert!(sanitize_duckdb_config_value("true").is_ok());
assert!(sanitize_duckdb_config_value("false").is_ok());
assert!(sanitize_duckdb_config_value("http://127.0.0.1:9000").is_ok());
// AWS ECS session tokens can exceed 1000 characters
assert!(sanitize_duckdb_config_value(&"a".repeat(1052)).is_ok());
}

#[test]
Expand All @@ -474,7 +476,7 @@ fn test_sanitize_duckdb_config_value_rejects_injection() {
assert!(sanitize_duckdb_config_value("localhost'; --").is_err());
assert!(sanitize_duckdb_config_value("\"malicious\"").is_err());
assert!(sanitize_duckdb_config_value("").is_err());
assert!(sanitize_duckdb_config_value(&"a".repeat(513)).is_err());
assert!(sanitize_duckdb_config_value(&"a".repeat(DUCKDB_CONFIG_VALUE_MAX_LENGTH + 1)).is_err());
assert!(sanitize_duckdb_config_value("value;another").is_err());
}

Expand Down
1 change: 1 addition & 0 deletions applications/ensemble_manager/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"requests>=2.32.5",
"sentry-sdk[fastapi]>=2.0.0",
"structlog>=25.5.0",
"prometheus-client>=0.21.0",
]

[dependency-groups]
Expand Down
41 changes: 21 additions & 20 deletions applications/ensemble_manager/src/ensemble_manager/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from mypy_boto3_s3 import S3Client

from internal.equity_details_schema import equity_details_schema
from tide.tide_data import Data
from tide.tide_model import Model
from tide.data import Data
from tide.model import Model

from .metrics import (
get_metrics,
Expand Down Expand Up @@ -342,30 +342,31 @@ def create_predictions(request: Request) -> Response: # noqa: PLR0915

tide_data.preprocess_and_set_data(data=data)

batches = tide_data.get_batches(data_type="predict")
from tinygrad.tensor import Tensor # noqa: PLC0415

if not batches:
dataset = tide_data.get_dataset(data_type="predict")

if len(dataset) == 0:
prediction_errors_total.labels(stage="batch_creation").inc()
logger.error("No data batches available for prediction")
logger.error("No data samples available for prediction")
observe_duration(timer_start)
return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

logger.info(
"Processing prediction batches",
num_batches=len(batches),
)
logger.info("Processing prediction dataset", num_samples=len(dataset))

all_predictions = []
for batch in batches:
raw_predictions = request.app.state.tide_model.predict(inputs=batch)
batch_predictions = tide_data.postprocess_predictions(
input_batch=batch,
predictions=raw_predictions,
current_datetime=current_timestamp,
)
all_predictions.append(batch_predictions)
batch = {
"past_continuous_features": Tensor(dataset.past_continuous),
"past_categorical_features": Tensor(dataset.past_categorical),
"future_categorical_features": Tensor(dataset.future_categorical),
"static_categorical_features": Tensor(dataset.static_categorical),
}

predictions = pl.concat(all_predictions)
raw_predictions = request.app.state.tide_model.predict(inputs=batch)
predictions = tide_data.postprocess_predictions(
input_batch=batch,
predictions=raw_predictions,
current_datetime=current_timestamp,
)
logger.info(
"Combined predictions from all batches",
total_predictions=predictions.height,
Expand Down Expand Up @@ -417,7 +418,7 @@ def create_predictions(request: Request) -> Response: # noqa: PLR0915
observe_duration(timer_start)
Comment thread
greptile-apps[bot] marked this conversation as resolved.
raise

prediction_batch_count.set(len(batches))
prediction_batch_count.set(1)
prediction_row_count.set(validated_predictions.height)
observe_duration(timer_start)

Expand Down
1 change: 1 addition & 0 deletions applications/portfolio_manager/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"sentry-sdk[fastapi]>=2.0.0",
"structlog>=25.5.0",
"scipy>=1.17.1",
"prometheus-client>=0.21.0",
]

[tool.uv]
Expand Down
10 changes: 8 additions & 2 deletions infrastructure/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from compute import acm_certificate_arn, alb, cluster, service_discovery_namespace
from config import account_id
from iam import github_actions_infrastructure_role, github_actions_oidc_provider
from networking import vpc
from networking import ecs_security_group, private_subnet_1, private_subnet_2, vpc
from storage import (
data_bucket,
data_manager_image_uri,
Expand All @@ -15,7 +15,7 @@
tide_model_runner_image_uri,
tide_model_runner_repository,
)
from training import models_cluster
from training import models_cluster, tide_trainer_task_definition

protocol = "https://" if acm_certificate_arn else "http://"

Expand All @@ -33,8 +33,14 @@

pulumi.export("aws_account_id", account_id)
pulumi.export("aws_vpc_id", vpc.id)
pulumi.export("aws_ecs_private_subnet_1_id", private_subnet_1.id)
pulumi.export("aws_ecs_private_subnet_2_id", private_subnet_2.id)
pulumi.export("aws_ecs_security_group_id", ecs_security_group.id)
pulumi.export("aws_ecs_cluster_name", cluster.name)
pulumi.export("aws_ecs_models_cluster_name", models_cluster.name)
pulumi.export(
"aws_ecs_tide_trainer_task_definition_arn", tide_trainer_task_definition.arn
)
pulumi.export("aws_alb_dns_name", alb.dns_name)
pulumi.export("aws_alb_url", pulumi.Output.concat(protocol, alb.dns_name))
pulumi.export("aws_service_discovery_namespace", service_discovery_namespace.name)
Expand Down
34 changes: 34 additions & 0 deletions infrastructure/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from config import tags
from iam import execution_role, task_role
from networking import ecs_security_group, private_subnet_1, private_subnet_2
from storage import tide_model_runner_image_uri

models_cluster = aws.ecs.Cluster(
"models_cluster",
Expand Down Expand Up @@ -154,9 +155,42 @@
tags=tags,
)

tide_trainer_task_definition = aws.ecs.TaskDefinition(
"tide_trainer_task_definition",
family="fund-tide-trainer",
requires_compatibilities=["EC2"],
network_mode="awsvpc",
cpu="4096",
memory="14336",
execution_role_arn=execution_role.arn,
task_role_arn=task_role.arn,
container_definitions=tide_model_runner_image_uri.apply(
lambda image_uri: json.dumps(
[
{
"name": "prefect",
"image": image_uri,
"essential": True,
"resourceRequirements": [{"type": "GPU", "value": "1"}],
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/ecs/fund/models",
"awslogs-region": "us-east-1",
"awslogs-stream-prefix": "tide",
},
},
}
]
)
),
tags=tags,
)

__all__ = [
"execution_role",
"models_cluster",
"models_log_group",
"task_role",
"tide_trainer_task_definition",
]
Loading
Loading