diff --git a/.claude/commands/update-pull-request.md b/.claude/commands/update-pull-request.md index f2d70253..35fc7a5b 100644 --- a/.claude/commands/update-pull-request.md +++ b/.claude/commands/update-pull-request.md @@ -13,6 +13,10 @@ If you need to accept edits during execution: - Choose "accept edits and continue" (NOT "clear context") - Or wait until the "Commit and Push Changes" step to accept all edits at once +## Important: API Usage + +**Do NOT use MCP GitHub tools during this command.** They have shown session reliability issues. All GitHub API interactions must use `gh api` (REST) or Python `urllib.request` (GraphQL mutations only). See step 8 for details on which operations require Python. + ## Instructions Analyze and address all feedback and failing checks on a GitHub pull request, then respond to and resolve all comments. @@ -49,172 +53,104 @@ Follow these steps: echo "Repository: ${OWNER}/${REPO}" - # Fetch PR data - gh api graphql -f query=' - query($owner: String!, $repo: String!, $number: Int!) { - repository(owner: $owner, name: $repo) { - pullRequest(number: $number) { - id - title - headRefName - headRefOid - baseRefName - - reviewThreads(first: 100) { - nodes { - id - isResolved - isOutdated - comments(first: 100) { - nodes { - id - databaseId - body - author { login } - path - position - createdAt - } - } - } - } - - comments(first: 100) { - nodes { - id - databaseId - body - author { login } - createdAt - } - } - - reviews(first: 50) { - nodes { - id - state - body - author { login } - submittedAt - } - } - - commits(last: 1) { - nodes { - commit { - checkSuites(first: 50) { - nodes { - workflowRun { - id - databaseId - } - checkRuns(first: 50) { - nodes { - name - status - conclusion - detailsUrl - } - } - } - } - } - } - } - } - } - } - ' -f owner="${OWNER}" -f repo="${REPO}" -F number=${ARGUMENTS} > "${SCRATCHPAD}/pr_data.json" - - echo "PR data fetched to ${SCRATCHPAD}/pr_data.json" + # Fetch PR data via REST — raw responses saved to scratchpad, never read directly into context + gh api "repos/${OWNER}/${REPO}/pulls/${ARGUMENTS}" > "${SCRATCHPAD}/pr_meta_raw.json" + gh api --paginate "repos/${OWNER}/${REPO}/issues/${ARGUMENTS}/comments" | jq -s 'add' > "${SCRATCHPAD}/pr_comments_raw.json" + gh api --paginate "repos/${OWNER}/${REPO}/pulls/${ARGUMENTS}/reviews" | jq -s 'add' > "${SCRATCHPAD}/pr_reviews_raw.json" + gh api --paginate "repos/${OWNER}/${REPO}/pulls/${ARGUMENTS}/comments" | jq -s 'add' > "${SCRATCHPAD}/pr_review_comments_raw.json" + + HEAD_SHA=$(jq -r '.head.sha' "${SCRATCHPAD}/pr_meta_raw.json") + gh api --paginate "repos/${OWNER}/${REPO}/commits/${HEAD_SHA}/check-runs" | jq -s 'add' > "${SCRATCHPAD}/check_runs_raw.json" + + # Fetch review thread node IDs (PRRT_* format) needed for resolving threads later. + # Uses an inline GraphQL query with values embedded as literals (no $variable syntax), + # which avoids the bash ! history-expansion issue that breaks parameterized GraphQL queries. + # Capped at 100 threads — cursor pagination requires parameterized GraphQL ($variables with !), + # which bash history expansion breaks, so this limit is intentional for practical use. + INLINE_QUERY=$(printf '{repository(owner:"%s",name:"%s"){pullRequest(number:%s){reviewThreads(first:100){nodes{id isResolved comments(first:1){nodes{databaseId}}}}}}}' \ + "${OWNER}" "${REPO}" "${ARGUMENTS}") + gh api graphql -f query="${INLINE_QUERY}" > "${SCRATCHPAD}/thread_ids_raw.json" + + echo "PR data fetched to ${SCRATCHPAD}/" + echo "SCRATCHPAD=${SCRATCHPAD}" ``` - Validate the fetched data: ```bash - # Check jq is installed + SCRATCHPAD="" + if ! command -v jq >/dev/null 2>&1; then echo "Error: jq is required but not installed. Install with: brew install jq (macOS) or apt-get install jq (Linux)" exit 1 fi - # Check file exists and contains valid JSON - if [ ! -f "${SCRATCHPAD}/pr_data.json" ] || ! jq empty "${SCRATCHPAD}/pr_data.json" 2>/dev/null; then - echo "Error: Failed to fetch valid pull request data. Check that pull request #${ARGUMENTS} exists, run 'gh auth status' to verify authentication, then retry." - exit 1 - fi + for f in pr_meta_raw pr_comments_raw pr_reviews_raw pr_review_comments_raw check_runs_raw thread_ids_raw; do + if [ ! -f "${SCRATCHPAD}/${f}.json" ] || ! jq empty "${SCRATCHPAD}/${f}.json" 2>/dev/null; then + echo "Error: Failed to fetch ${f}.json. Check that pull request #${ARGUMENTS} exists, run 'gh auth status' to verify authentication, then retry." + exit 1 + fi + done - # Validate GraphQL response structure - if ! jq -e '.data.repository.pullRequest.id and (.errors | not)' "${SCRATCHPAD}/pr_data.json" >/dev/null 2>&1; then - echo "Error: Pull request data missing expected fields or contains GraphQL errors. Check pull request #${ARGUMENTS} exists and you have access." - jq -r '.errors[]?.message // "No specific error message available"' "${SCRATCHPAD}/pr_data.json" + if ! jq -e '.number' "${SCRATCHPAD}/pr_meta_raw.json" >/dev/null 2>&1; then + echo "Error: Pull request data missing expected fields. Check pull request #${ARGUMENTS} exists and you have access." exit 1 fi + + echo "Validation passed" ``` -- **Critical**: The pull request data file will be too large to read directly with the Read tool. Extract structured data once into smaller files: +- **Critical**: The raw files will be too large to read directly with the Read tool. Extract structured data once into smaller files: ```bash - echo "Extracting structured data from PR response..." + SCRATCHPAD="" - # Extract PR metadata - jq '.data.repository.pullRequest | { - id: .id, - title: .title, - headRefName: .headRefName, - baseRefName: .baseRefName - }' "${SCRATCHPAD}/pr_data.json" > "${SCRATCHPAD}/metadata.json" + echo "Extracting structured data from raw responses..." - # Extract unresolved review threads - jq '[.data.repository.pullRequest.reviewThreads.nodes[] | - select(.isResolved == false and .isOutdated == false) | { - threadId: .id, - comments: [.comments.nodes[] | { - id: .id, - databaseId: .databaseId, - body: .body, - author: .author.login, - path: .path, - position: .position - }] - }]' "${SCRATCHPAD}/pr_data.json" > "${SCRATCHPAD}/review_threads.json" - - # Extract outdated threads - jq '[.data.repository.pullRequest.reviewThreads.nodes[] | - select(.isResolved == false and .isOutdated == true) | { - threadId: .id, - comments: [.comments.nodes[] | { - id: .id, - body: .body, - author: .author.login, - path: .path - }] - }]' "${SCRATCHPAD}/pr_data.json" > "${SCRATCHPAD}/outdated_threads.json" - - # Extract PR-level comments - jq '[.data.repository.pullRequest.comments.nodes[] | { - id: .id, - databaseId: .databaseId, + # PR metadata + jq '{ + number: .number, + title: .title, + headRefName: .head.ref, + baseRefName: .base.ref, + headSha: .head.sha + }' "${SCRATCHPAD}/pr_meta_raw.json" > "${SCRATCHPAD}/metadata.json" + + # PR-level (issue) comments — REST uses .user.login, not .author.login + jq '[.[] | {id: .id, databaseId: .id, body: .body, author: .user.login}]' \ + "${SCRATCHPAD}/pr_comments_raw.json" > "${SCRATCHPAD}/pr_comments.json" + + # Inline review comments — root threads only (in_reply_to_id == null) + # rootCommentId is the integer ID used to post replies via REST + jq '[.[] | select(.in_reply_to_id == null) | { + rootCommentId: .id, + path: .path, + line: (.line // .original_line), body: .body, - author: .author.login - }]' "${SCRATCHPAD}/pr_data.json" > "${SCRATCHPAD}/pr_comments.json" + author: .user.login + }]' "${SCRATCHPAD}/pr_review_comments_raw.json" > "${SCRATCHPAD}/review_comments.json" - # Extract check failures - jq '[.data.repository.pullRequest.commits.nodes[].commit.checkSuites.nodes[] as $suite | - $suite.checkRuns.nodes[] | - select(.conclusion == "FAILURE" or .conclusion == "TIMED_OUT") | { + # Check failures — REST uses lowercase conclusion values ("failure", "timed_out") + jq '[.check_runs[] | + select(.conclusion == "failure" or .conclusion == "timed_out") | { name: .name, conclusion: .conclusion, - detailsUrl: .detailsUrl, - workflowRunId: ($suite.workflowRun.databaseId // null) - }] | unique_by(.name)' "${SCRATCHPAD}/pr_data.json" > "${SCRATCHPAD}/check_failures.json" + detailsUrl: .details_url + }] | unique_by(.name)' "${SCRATCHPAD}/check_runs_raw.json" > "${SCRATCHPAD}/check_failures.json" + + # Thread node IDs — links PRRT_* IDs to root comment database IDs for resolution + jq '[.data.repository.pullRequest.reviewThreads.nodes[] | { + threadId: .id, + isResolved: .isResolved, + rootCommentId: .comments.nodes[0].databaseId + }]' "${SCRATCHPAD}/thread_ids_raw.json" > "${SCRATCHPAD}/thread_ids.json" echo "Data extraction complete:" - echo " - metadata.json (PR info)" - echo " - review_threads.json ($(jq 'length' "${SCRATCHPAD}/review_threads.json") unresolved threads)" - echo " - outdated_threads.json ($(jq 'length' "${SCRATCHPAD}/outdated_threads.json") outdated threads)" - echo " - pr_comments.json ($(jq 'length' "${SCRATCHPAD}/pr_comments.json") PR comments)" - echo " - check_failures.json ($(jq 'length' "${SCRATCHPAD}/check_failures.json") failed checks)" + echo " - metadata.json" + echo " - pr_comments.json ($(jq 'length' "${SCRATCHPAD}/pr_comments.json") comments)" + echo " - review_comments.json ($(jq 'length' "${SCRATCHPAD}/review_comments.json") root review threads)" + echo " - check_failures.json ($(jq 'length' "${SCRATCHPAD}/check_failures.json") failed checks)" + echo " - thread_ids.json ($(jq 'length' "${SCRATCHPAD}/thread_ids.json") threads)" ``` - These smaller structured files can be read with the Read tool if needed, and eliminate redundant jq parsing throughout the command. @@ -224,35 +160,42 @@ Follow these steps: - Review all extracted data to build a complete picture of what needs to be addressed: ```bash + SCRATCHPAD="" + echo "=== Check Failures ===" jq -r '.[] | "[\(.conclusion)] \(.name) - \(.detailsUrl)"' "${SCRATCHPAD}/check_failures.json" echo "" - echo "=== Unresolved Review Threads ===" - jq -r '.[] | "\(.threadId) | \(.comments[0].path // "N/A") | \(.comments[0].author)"' "${SCRATCHPAD}/review_threads.json" - - echo "" - echo "=== Outdated Threads (require manual review) ===" - jq -r '.[] | "\(.threadId) | \(.comments[0].path // "N/A") | \(.comments[0].author) | \(.comments[0].body[:80])"' "${SCRATCHPAD}/outdated_threads.json" + echo "=== Review Comments (root threads only) ===" + jq -r '.[] | "[\(.rootCommentId)] @\(.author) on \(.path):\(.line // "?"): \(.body[:80])"' \ + "${SCRATCHPAD}/review_comments.json" echo "" echo "=== PR-level Comments ===" jq -r '.[] | "\(.databaseId) | \(.author) | \(.body[:80])"' "${SCRATCHPAD}/pr_comments.json" + + echo "" + echo "=== Thread Resolution Status ===" + jq -r '.[] | "\(.threadId) | rootComment=\(.rootCommentId) | resolved=\(.isResolved)"' \ + "${SCRATCHPAD}/thread_ids.json" ``` -- For detailed review of specific threads, use: +- For full comment bodies: ```bash - jq '.[] | select(.threadId == "PRRT_xxx")' "${SCRATCHPAD}/review_threads.json" + SCRATCHPAD="" + jq -r '.[] | "=== [\(.rootCommentId)] @\(.author) on \(.path) ===\n\(.body)\n"' \ + "${SCRATCHPAD}/review_comments.json" ``` - The structured files contain all necessary metadata: - - `review_threads.json`: Thread ID, comment IDs (both node and database), body, author, file path/position - - `outdated_threads.json`: Thread ID and comment metadata (body, author, file path); "outdated" means the code was modified, not that the feedback is irrelevant - review each to determine if it still applies + - `review_comments.json`: Root comment ID (integer), path, line, body, author — one entry per thread + - `thread_ids.json`: Thread node IDs (`PRRT_*`), resolution status, root comment database ID — join with `review_comments.json` on `rootCommentId` - `pr_comments.json`: Comment IDs, body, author - `check_failures.json`: Check name, conclusion, details URL + - `pr_review_comments_raw.json`: Full flat array of all review comments including replies (use for full context if needed) -- Note that check-runs and workflow runs are distinct; to fetch logs, first obtain the workflow run ID, then use `gh api repos/${OWNER}/${REPO}/actions/runs/{run_id}/logs`. If logs are inaccessible via API, run `mask development python all` or `mask development rust all` locally to replicate the errors. +- Note that check-runs and workflow runs are distinct; to fetch logs, obtain the workflow run ID from `check_runs_raw.json` and use `gh api repos/${OWNER}/${REPO}/actions/runs/{run_id}/logs`. If logs are inaccessible via API, run `mask development python all` or `mask development rust all` locally to replicate the errors. - Group all feedback (check failures, review threads, outdated threads, PR-level comments) using judgement: by file, by theme, by type of change, or whatever makes most sense for the specific pull request; ensure each group maintains the full metadata for all items it contains. - Analyze dependencies between feedback groups to determine which are independent (can be worked in parallel) and which are interdependent (must be handled sequentially). - For each piece of feedback, evaluate whether to address it (make code changes) or reject it (explain why the feedback doesn't apply); provide clear reasoning for each decision. @@ -322,77 +265,84 @@ Follow these steps: ### 8. Respond to and Resolve Comments - For each piece of feedback (both addressed and rejected), draft a response comment explaining what was done or why it was rejected, using the commenter name for personalization. -- Post all response comments to their respective threads: - - For review comments (code-level), use GraphQL `addPullRequestReviewThreadReply` mutation to post comments directly to threads (NOT as pending review): - - ```bash - # IMPORTANT: Keep response text simple - avoid newlines, code blocks, and special characters - # GraphQL string literals cannot contain raw newlines; use spaces or simple sentences - # If complex formatting is needed, save response to a variable first and ensure proper escaping - - gh api graphql -f query=' - mutation($pullRequestReviewThreadId: ID!, $body: String!) { - addPullRequestReviewThreadReply(input: { - pullRequestReviewThreadId: $pullRequestReviewThreadId, - body: $body - }) { - comment { id } - } - } - ' -f pullRequestReviewThreadId="" -f body="" - ``` - - Use the thread ID (format: `PRRT_*`) from `review_threads.json` for `pullRequestReviewThreadId` parameter. - - Example to get thread ID: - - ```bash - THREAD_ID=$(jq -r '.[] | select(.comments[0].body | contains("some text")) | .threadId' "${SCRATCHPAD}/review_threads.json") - ``` - - **Response formatting guidelines**: - - Keep responses concise and single-line when possible - - Avoid embedding code blocks or complex markdown in mutation strings - - Use simple sentences: "Fixed in the latest commit" or "Updated to use GraphQL approach" - - For longer responses, reference line numbers or file paths instead of quoting code - - - For issue comments (pull request-level), use REST API: - - ```bash - gh api repos/${OWNER}/${REPO}/issues/"${ARGUMENTS}"/comments -f body="" - ``` - -- For each response posted, capture the returned comment ID for verification. -- Auto-resolve all comment threads after posting responses: - - For review comment threads: - - Use the thread ID (format: `PRRT_*`) from `review_threads.json`. - - Resolve thread using GraphQL mutation: - - ```bash - gh api graphql -f query=' - mutation($threadId: ID!) { - resolveReviewThread(input: {threadId: $threadId}) { - thread { - id - isResolved - } - } - } - ' -f threadId="" - ``` - - - Map each comment back to its parent thread using the structured files from step 1 (particularly review_threads.json). - - Resolve both addressed and rejected feedback threads (explanation provided in response). - - - For issue comments (pull request-level): - - No resolution mechanism (issue comments don't have thread states). - - Only post response; no resolution step needed. + +#### Posting replies to review threads + +Use the REST reply endpoint with the root comment's integer ID from `review_comments.json`: + +```bash +SCRATCHPAD="" +OWNER="..."; REPO="..."; PR="${ARGUMENTS}" + +post_reply() { + local comment_id="$1" + local body="$2" + gh api "repos/${OWNER}/${REPO}/pulls/${PR}/comments/${comment_id}/replies" \ + -f body="${body}" --jq '.id' +} + +# Example — call once per thread: +post_reply 1234567 "Fixed in abc1234. Updated the Dockerfile path to translate hyphens to underscores." +post_reply 1234568 "Intentional — kept as-is by design." +``` + +**Response formatting guidelines**: +- Keep responses concise and single-line when possible +- Avoid newlines, code blocks, or special characters in the body string +- Reference commit hashes or file paths rather than quoting code inline + +#### Posting PR-level (issue) comments + +```bash +gh api "repos/${OWNER}/${REPO}/issues/${ARGUMENTS}/comments" -f body="" +``` + +Issue comments have no thread state and do not need a resolution step. + +#### Resolving review threads + +**Use Python `urllib.request`** — `gh api graphql` is broken for parameterized GraphQL mutations because bash history expansion mangles the `!` in type annotations like `ID!`. Python string literals are not subject to this, making it the reliable path for mutations. + +```bash +SCRATCHPAD="" + +python3 - << 'PYEOF' +import json, os, urllib.request, subprocess +from pathlib import Path + +token = subprocess.check_output(["gh", "auth", "token"]).decode().strip() +mutation = "mutation($threadId:ID!){resolveReviewThread(input:{threadId:$threadId}){thread{id isResolved}}}" + +scratchpad = Path(os.environ["SCRATCHPAD"]) +threads_data = json.loads((scratchpad / "thread_ids.json").read_text()) +threads = [t["threadId"] for t in threads_data if not t.get("isResolved")] + +for thread_id in threads: + body = json.dumps({"query": mutation, "variables": {"threadId": thread_id}}).encode() + req = urllib.request.Request( + "https://api.github.com/graphql", + data=body, + headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req, timeout=30) as resp: + data = json.loads(resp.read()) + if "errors" in data: + print(f"{thread_id}: ERROR — {data['errors']}") + continue + thread = data.get("data", {}).get("resolveReviewThread", {}).get("thread", {}) + print(f"{thread_id}: isResolved={thread.get('isResolved')}") +PYEOF +``` + +The script reads `thread_ids.json` from the scratchpad automatically to determine which threads to resolve. + +Resolve both addressed and rejected threads — the response comment explains the outcome either way. ### 9. Final Summary - Provide a comprehensive summary showing: - - Total feedback items processed (with count of addressed vs rejected), including any outdated threads that were reviewed. + - Total feedback items processed (with count of addressed vs rejected). - Which checks were fixed. - Confirmation that all comments have been responded to and resolved. - Final verification status (all local checks passing; remote continuous integration is now running against the pushed changes). -- For check failures that were fixed, note that no comments were posted - the fixes will be reflected in re-run checks which are now in progress. +- For check failures that were fixed, note that no comments were posted — the fixes will be reflected in re-run checks which are now in progress. diff --git a/.github/workflows/launch_infrastructure.yaml b/.github/workflows/launch_infrastructure.yaml index 9dcf7b15..0b1ff265 100644 --- a/.github/workflows/launch_infrastructure.yaml +++ b/.github/workflows/launch_infrastructure.yaml @@ -2,20 +2,15 @@ name: Launch infrastructure run-name: Launch infrastructure on: - schedule: - # Runs at 12:00 UTC weekdays (7:00 AM EST / 8:00 AM EDT) - # Launches infrastructure 2.5 hours before market open (EST) or 1.5 hours before (EDT) - - cron: 0 12 * * 1-5 push: branches: - master - workflow_dispatch: concurrency: group: infrastructure-deployment cancel-in-progress: false jobs: build_and_push_images: - name: Build and push ${{ matrix.service }} + name: Build and push ${{ matrix.application }}-${{ matrix.stage }} runs-on: ubuntu-latest environment: pulumi permissions: @@ -24,12 +19,18 @@ jobs: strategy: matrix: include: - - service: data_manager + - application: data-manager + stage: server paths: applications/data_manager/** - - service: portfolio_manager + - application: portfolio-manager + stage: server paths: applications/portfolio_manager/** - - service: ensemble_manager + - application: ensemble-manager + stage: server paths: applications/ensemble_manager/** + - application: model-trainer + stage: server-worker + paths: models/** steps: - name: Checkout code uses: actions/checkout@v4 @@ -46,27 +47,27 @@ jobs: - 'pyproject.toml' - 'uv.lock' - name: Configure AWS credentials - if: steps.changes.outputs.service == 'true' || github.event_name == 'schedule' + if: steps.changes.outputs.service == 'true' uses: aws-actions/configure-aws-credentials@v5 with: role-to-assume: ${{ secrets.AWS_IAM_INFRASTRUCTURE_ROLE_ARN }} aws-region: ${{ secrets.AWS_REGION }} - name: Set up Docker Buildx - if: steps.changes.outputs.service == 'true' || github.event_name == 'schedule' + if: steps.changes.outputs.service == 'true' uses: docker/setup-buildx-action@v3 - name: Install Flox - if: steps.changes.outputs.service == 'true' || github.event_name == 'schedule' + if: steps.changes.outputs.service == 'true' uses: flox/install-flox-action@v2 - - name: Build ${{ matrix.service }} image - if: steps.changes.outputs.service == 'true' || github.event_name == 'schedule' + - name: Build ${{ matrix.application }}-${{ matrix.stage }} image + if: steps.changes.outputs.service == 'true' uses: flox/activate-action@v1 with: - command: mask infrastructure images build ${{ matrix.service }} server - - name: Push ${{ matrix.service }} image - if: steps.changes.outputs.service == 'true' || github.event_name == 'schedule' + command: mask infrastructure image build ${{ matrix.application }} ${{ matrix.stage }} + - name: Push ${{ matrix.application }}-${{ matrix.stage }} image + if: steps.changes.outputs.service == 'true' uses: flox/activate-action@v1 with: - command: mask infrastructure images push ${{ matrix.service }} server + command: mask infrastructure image push ${{ matrix.application }} ${{ matrix.stage }} launch_infrastructure: name: Deploy with Pulumi needs: build_and_push_images @@ -91,15 +92,60 @@ jobs: PULUMI_ACCESS_TOKEN: ${{ secrets.PULUMI_ACCESS_TOKEN }} with: command: mask infrastructure stack up - trigger_data_sync: - name: Trigger sync data workflow + deploy_images: + name: Deploy ${{ matrix.application }}-${{ matrix.stage }} needs: launch_infrastructure - if: github.event_name == 'schedule' runs-on: ubuntu-latest + environment: pulumi permissions: - contents: write + id-token: write + contents: read + strategy: + matrix: + include: + - application: data-manager + stage: server + paths: applications/data_manager/** + - application: portfolio-manager + stage: server + paths: applications/portfolio_manager/** + - application: ensemble-manager + stage: server + paths: applications/ensemble_manager/** + - application: model-trainer + stage: server + paths: models/** + - application: model-trainer + stage: worker + paths: models/** steps: - - name: Trigger sync data workflow - uses: peter-evans/repository-dispatch@v4 + - name: Checkout code + uses: actions/checkout@v4 + - name: Check for service changes + uses: dorny/paths-filter@v3 + id: changes + with: + filters: | + service: + - '${{ matrix.paths }}' + - 'libraries/python/**' + - 'Cargo.toml' + - 'Cargo.lock' + - 'pyproject.toml' + - 'uv.lock' + - name: Configure AWS credentials + if: steps.changes.outputs.service == 'true' + uses: aws-actions/configure-aws-credentials@v5 + with: + role-to-assume: ${{ secrets.AWS_IAM_INFRASTRUCTURE_ROLE_ARN }} + aws-region: ${{ secrets.AWS_REGION }} + - name: Install Flox + if: steps.changes.outputs.service == 'true' + uses: flox/install-flox-action@v2 + - name: Deploy ${{ matrix.application }}-${{ matrix.stage }} + if: steps.changes.outputs.service == 'true' + uses: flox/activate-action@v1 + env: + PULUMI_ACCESS_TOKEN: ${{ secrets.PULUMI_ACCESS_TOKEN }} with: - event-type: sync-data-after-scheduled-launch-infrastructure + command: mask infrastructure image deploy ${{ matrix.application }} ${{ matrix.stage }} diff --git a/.github/workflows/sync_data.yaml b/.github/workflows/sync_data.yaml deleted file mode 100644 index fe3177db..00000000 --- a/.github/workflows/sync_data.yaml +++ /dev/null @@ -1,23 +0,0 @@ ---- -name: Sync data -run-name: Sync data -on: - repository_dispatch: - types: - - sync-data-after-scheduled-launch-infrastructure -jobs: - fetch_data: - name: Sync data on weekday schedule - runs-on: ubuntu-latest - environment: pulumi - steps: - - name: Checkout code - uses: actions/checkout@v4 - - name: Install Flox - uses: flox/install-flox-action@v2 - - name: Fetch with Pulumi - uses: flox/activate-action@v1 - env: - PULUMI_ACCESS_TOKEN: ${{ secrets.PULUMI_ACCESS_TOKEN }} - with: - command: mask infrastructure services invoke datamanager --data-type equity-bars diff --git a/.github/workflows/teardown_infrastructure.yaml b/.github/workflows/teardown_infrastructure.yaml index 13661b7b..7b2684aa 100644 --- a/.github/workflows/teardown_infrastructure.yaml +++ b/.github/workflows/teardown_infrastructure.yaml @@ -2,10 +2,6 @@ name: Teardown infrastructure run-name: Teardown infrastructure on: - schedule: - # Runs at 23:00 UTC weekdays (6:00 PM EST / 7:00 PM EDT) - # Tears down infrastructure 2 hours after market close (EST) or 3 hours after (EDT) - - cron: 0 23 * * 1-5 workflow_dispatch: jobs: teardown_infrastructure: diff --git a/applications/data_manager/Dockerfile b/applications/data_manager/Dockerfile index 8261ed0d..bfd5510f 100644 --- a/applications/data_manager/Dockerfile +++ b/applications/data_manager/Dockerfile @@ -37,6 +37,8 @@ RUN cargo chef prepare --recipe-path recipe.json FROM chef AS builder +ENV CARGO_BUILD_JOBS=4 + COPY --from=planner /app/recipe.json recipe.json RUN --mount=type=cache,target=/usr/local/cargo/registry \ diff --git a/applications/ensemble_manager/Dockerfile b/applications/ensemble_manager/Dockerfile index 6c4ba4aa..4f20d02f 100644 --- a/applications/ensemble_manager/Dockerfile +++ b/applications/ensemble_manager/Dockerfile @@ -11,6 +11,7 @@ COPY applications/ensemble_manager/ applications/ensemble_manager/ COPY libraries/python/ libraries/python/ COPY models/tide/ models/tide/ +COPY tools/ tools/ RUN uv sync --no-dev diff --git a/applications/portfolio_manager/pyproject.toml b/applications/portfolio_manager/pyproject.toml index f94dd87f..546f76f2 100644 --- a/applications/portfolio_manager/pyproject.toml +++ b/applications/portfolio_manager/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "polars>=1.29.0", "requests>=2.32.5", "alpaca-py>=0.42.1", + "pytz>=2025.1", "sentry-sdk[fastapi]>=2.0.0", "structlog>=25.5.0", "scipy>=1.17.1", diff --git a/infrastructure/Pulumi.production.yaml b/infrastructure/Pulumi.production.yaml index 36a40072..dab2470c 100644 --- a/infrastructure/Pulumi.production.yaml +++ b/infrastructure/Pulumi.production.yaml @@ -7,7 +7,7 @@ config: fund:sagemakerExecutionRoleName: fund-sagemaker-execution-role fund:githubRepository: oscmcompany/fund fund:githubBranch: master - fund:monthlyBudgetLimitUsd: "25" + fund:monthlyBudgetLimitUsd: "250" fund:githubWorkflowFiles: - launch_infrastructure.yaml - teardown_infrastructure.yaml diff --git a/infrastructure/__main__.py b/infrastructure/__main__.py index 6a78996c..9c6dc2a6 100644 --- a/infrastructure/__main__.py +++ b/infrastructure/__main__.py @@ -10,12 +10,12 @@ ensemble_manager_image_uri, ensemble_manager_repository, model_artifacts_bucket, + model_trainer_server_worker_image_uri, + model_trainer_server_worker_repository, portfolio_manager_image_uri, portfolio_manager_repository, - tide_trainer_image_uri, - tide_trainer_repository, - training_worker_image_uri, - training_worker_repository, + tide_runner_image_uri, + tide_runner_repository, ) protocol = "https://" if acm_certificate_arn else "http://" @@ -54,18 +54,21 @@ pulumi.Output.unsecret(model_artifacts_bucket.bucket), ) pulumi.export( - "aws_ecr_tide_trainer_repository", - tide_trainer_repository.repository_url, + "aws_ecr_tide_runner_repository", + tide_runner_repository.repository_url, ) -pulumi.export("aws_ecr_tide_trainer_image", tide_trainer_image_uri) +pulumi.export("aws_ecr_tide_runner_image", tide_runner_image_uri) pulumi.export( - "aws_ecr_training_worker_repository", training_worker_repository.repository_url + "aws_ecr_model_trainer_server_worker_repository", + model_trainer_server_worker_repository.repository_url, +) +pulumi.export( + "aws_ecr_model_trainer_server_worker_image", model_trainer_server_worker_image_uri ) -pulumi.export("aws_ecr_training_worker_image", training_worker_image_uri) pulumi.export( "training_api_url", pulumi.Output.concat( - "http://training-server.", service_discovery_namespace.name, ":4200/api" + "http://model-trainer-server.", service_discovery_namespace.name, ":4200/api" ), ) training_ui_url = ( diff --git a/infrastructure/compute.py b/infrastructure/compute.py index 7850006b..259e786b 100644 --- a/infrastructure/compute.py +++ b/infrastructure/compute.py @@ -25,9 +25,8 @@ data_manager_image_uri, ensemble_manager_image_uri, model_artifacts_bucket, + model_trainer_server_worker_image_uri, portfolio_manager_image_uri, - training_server_image_uri, - training_worker_image_uri, ) cluster = aws.ecs.Cluster( @@ -53,12 +52,14 @@ security_groups=[alb_security_group.id], internal=False, load_balancer_type="application", + ip_address_type="ipv4", + enable_cross_zone_load_balancing=True, tags=tags, ) data_manager_tg = aws.lb.TargetGroup( "data_manager_tg", - name="fund-data-manager", + name="fund-data-manager-server", port=8080, protocol="HTTP", vpc_id=vpc.id, @@ -75,7 +76,7 @@ portfolio_manager_tg = aws.lb.TargetGroup( "portfolio_manager_tg", - name="fund-portfolio-manager", + name="fund-portfolio-manager-server", port=8080, protocol="HTTP", vpc_id=vpc.id, @@ -92,7 +93,7 @@ ensemble_manager_tg = aws.lb.TargetGroup( "ensemble_manager_tg", - name="fund-ensemble-manager", + name="fund-ensemble-manager-server", port=8080, protocol="HTTP", vpc_id=vpc.id, @@ -107,9 +108,9 @@ tags=tags, ) -training_tg = aws.lb.TargetGroup( - "training_tg", - name="fund-training", +model_trainer_tg = aws.lb.TargetGroup( + "model_trainer_tg", + name="fund-model-trainer", port=4200, protocol="HTTP", vpc_id=vpc.id, @@ -139,7 +140,7 @@ default_actions=[ aws.lb.ListenerDefaultActionArgs( type="forward", - target_group_arn=training_tg.arn, + target_group_arn=model_trainer_tg.arn, ) ], tags=tags, @@ -153,7 +154,7 @@ default_actions=[ aws.lb.ListenerDefaultActionArgs( type="forward", - target_group_arn=training_tg.arn, + target_group_arn=model_trainer_tg.arn, ) ], tags=tags, @@ -293,7 +294,7 @@ # RDS Security Group - allows inbound Postgres from ECS tasks prefect_rds_security_group = aws.ec2.SecurityGroup( "prefect_rds_sg", - name="fund-prefect-rds", + name="fund-model-trainer-state", vpc_id=vpc.id, description="Security group for Prefect RDS database", tags=tags, @@ -324,7 +325,7 @@ # Redis Security Group - allows inbound Redis from ECS tasks prefect_redis_security_group = aws.ec2.SecurityGroup( "prefect_redis_sg", - name="fund-prefect-redis", + name="fund-model-trainer-broker", vpc_id=vpc.id, description="Security group for Prefect Redis cache", tags=tags, @@ -355,15 +356,15 @@ # RDS Subnet Group prefect_rds_subnet_group = aws.rds.SubnetGroup( "prefect_rds_subnet_group", - name="fund-prefect-rds", + name="fund-model-trainer-state", subnet_ids=[private_subnet_1.id, private_subnet_2.id], tags=tags, ) # RDS PostgreSQL for Prefect database -prefect_database = aws.rds.Instance( - "prefect_database", - identifier="fund-prefect", +model_trainer_state = aws.rds.Instance( + "model_trainer_state", + identifier="fund-model-trainer-state", engine="postgres", engine_version="14", instance_class="db.t3.micro", @@ -374,7 +375,7 @@ db_subnet_group_name=prefect_rds_subnet_group.name, vpc_security_group_ids=[prefect_rds_security_group.id], skip_final_snapshot=False, - final_snapshot_identifier=f"fund-prefect-final-{pulumi.get_stack()}", + final_snapshot_identifier=f"fund-model-trainer-state-final-{pulumi.get_stack()}", backup_retention_period=7, storage_encrypted=True, deletion_protection=True, @@ -386,7 +387,7 @@ "execution_role_prefect_db_secret_policy", name="fund-ecs-execution-role-prefect-db-secret", role=execution_role.id, - policy=prefect_database.master_user_secrets[0]["secret_arn"].apply( + policy=model_trainer_state.master_user_secrets[0]["secret_arn"].apply( lambda arn: json.dumps( { "Version": "2012-10-17", @@ -406,15 +407,15 @@ # ElastiCache Subnet Group prefect_elasticache_subnet_group = aws.elasticache.SubnetGroup( "prefect_elasticache_subnet_group", - name="fund-prefect-redis", + name="fund-model-trainer-broker", subnet_ids=[private_subnet_1.id, private_subnet_2.id], tags=tags, ) # ElastiCache Redis for Prefect messaging -prefect_redis = aws.elasticache.Cluster( - "prefect_redis", - cluster_id="fund-prefect-redis", +model_trainer_broker = aws.elasticache.Cluster( + "model_trainer_broker", + cluster_id="fund-model-trainer-broker", engine="redis", engine_version="7.0", node_type="cache.t3.micro", @@ -451,7 +452,7 @@ # Prefect Server Log Group training_server_log_group = aws.cloudwatch.LogGroup( "training_server_logs", - name="/ecs/fund/training-server", + name="/ecs/fund/model-trainer-server", retention_in_days=7, tags=tags, ) @@ -459,7 +460,7 @@ # Prefect Worker Log Group training_worker_log_group = aws.cloudwatch.LogGroup( "training_worker_logs", - name="/ecs/fund/training-worker", + name="/ecs/fund/model-trainer-worker", retention_in_days=7, tags=tags, ) @@ -467,7 +468,7 @@ # Prefect Server Task Definition training_server_task_definition = aws.ecs.TaskDefinition( "training_server_task", - family="training-server", + family="model-trainer-server", cpu="512", memory="1024", network_mode="awsvpc", @@ -476,15 +477,15 @@ task_role_arn=task_role.arn, container_definitions=pulumi.Output.all( training_server_log_group.name, - prefect_database.endpoint, - prefect_database.master_user_secrets[0]["secret_arn"], - training_server_image_uri, + model_trainer_state.endpoint, + model_trainer_state.master_user_secrets[0]["secret_arn"], + model_trainer_server_worker_image_uri, alb.dns_name, ).apply( lambda args: json.dumps( [ { - "name": "training-server", + "name": "model-trainer-server", "image": args[3], # Inline bash/python constructs the database URL at runtime # because the password comes from Secrets Manager and must be @@ -525,7 +526,7 @@ "options": { "awslogs-group": args[0], "awslogs-region": region, - "awslogs-stream-prefix": "training-server", + "awslogs-stream-prefix": "model-trainer-server", }, }, "essential": True, @@ -540,7 +541,7 @@ # Prefect Server Service Discovery training_server_sd_service = aws.servicediscovery.Service( "training_server_sd", - name="training-server", + name="model-trainer-server", dns_config=aws.servicediscovery.ServiceDnsConfigArgs( namespace_id=service_discovery_namespace.id, dns_records=[ @@ -553,7 +554,7 @@ # Prefect Server ECS Service training_server_service = aws.ecs.Service( "training_server_service", - name="fund-training-server", + name="fund-model-trainer-server", cluster=cluster.arn, task_definition=training_server_task_definition.arn, desired_count=1, @@ -565,8 +566,8 @@ ), load_balancers=[ aws.ecs.ServiceLoadBalancerArgs( - target_group_arn=training_tg.arn, - container_name="training-server", + target_group_arn=model_trainer_tg.arn, + container_name="model-trainer-server", container_port=4200, ) ], @@ -574,7 +575,7 @@ registry_arn=training_server_sd_service.arn ), opts=pulumi.ResourceOptions( - depends_on=[prefect_database, prefect_redis, prefect_listener], + depends_on=[model_trainer_state, model_trainer_broker, prefect_listener], ), tags=tags, ) @@ -582,7 +583,7 @@ # Prefect Worker Task Definition training_worker_task_definition = aws.ecs.TaskDefinition( "training_worker_task", - family="training-worker", + family="model-trainer-worker", cpu="4096", memory="8192", network_mode="awsvpc", @@ -594,19 +595,19 @@ service_discovery_namespace.name, data_bucket.bucket, model_artifacts_bucket.bucket, - training_worker_image_uri, + model_trainer_server_worker_image_uri, training_notification_sender_email_parameter.arn, training_notification_recipients_parameter.arn, ).apply( lambda args: json.dumps( [ { - "name": "training-worker", + "name": "model-trainer-worker", "image": args[4], "environment": [ { "name": "PREFECT_API_URL", - "value": f"http://training-server.{args[1]}:4200/api", + "value": f"http://model-trainer-server.{args[1]}:4200/api", }, { "name": "AWS_S3_DATA_BUCKET_NAME", @@ -618,7 +619,7 @@ }, { "name": "FUND_DATAMANAGER_BASE_URL", - "value": f"http://data-manager.{args[1]}:8080", + "value": f"http://data-manager-server.{args[1]}:8080", }, { "name": "FUND_LOOKBACK_DAYS", @@ -640,7 +641,7 @@ "options": { "awslogs-group": args[0], "awslogs-region": region, - "awslogs-stream-prefix": "training-worker", + "awslogs-stream-prefix": "model-trainer-worker", }, }, "essential": True, @@ -655,7 +656,7 @@ # Prefect Worker ECS Service training_worker_service = aws.ecs.Service( "training_worker_service", - name="fund-training-worker", + name="fund-model-trainer-worker", cluster=cluster.arn, task_definition=training_worker_task_definition.arn, desired_count=1, @@ -673,28 +674,28 @@ data_manager_log_group = aws.cloudwatch.LogGroup( "data_manager_logs", - name="/ecs/fund/data-manager", + name="/ecs/fund/data-manager-server", retention_in_days=7, tags=tags, ) portfolio_manager_log_group = aws.cloudwatch.LogGroup( "portfolio_manager_logs", - name="/ecs/fund/portfolio-manager", + name="/ecs/fund/portfolio-manager-server", retention_in_days=7, tags=tags, ) ensemble_manager_log_group = aws.cloudwatch.LogGroup( "ensemble_manager_logs", - name="/ecs/fund/ensemble-manager", + name="/ecs/fund/ensemble-manager-server", retention_in_days=7, tags=tags, ) data_manager_task_definition = aws.ecs.TaskDefinition( "data_manager_task", - family="data-manager", + family="data-manager-server", cpu="256", memory="512", network_mode="awsvpc", @@ -711,7 +712,7 @@ lambda args: json.dumps( [ { - "name": "data-manager", + "name": "data-manager-server", "image": args[1], "portMappings": [{"containerPort": 8080, "protocol": "tcp"}], "environment": [ @@ -747,7 +748,7 @@ "options": { "awslogs-group": args[0], "awslogs-region": region, - "awslogs-stream-prefix": "data-manager", + "awslogs-stream-prefix": "data-manager-server", }, }, "essential": True, @@ -761,7 +762,7 @@ portfolio_manager_task_definition = aws.ecs.TaskDefinition( "portfolio_manager_task", - family="portfolio-manager", + family="portfolio-manager-server", cpu="256", memory="512", network_mode="awsvpc", @@ -779,17 +780,17 @@ lambda args: json.dumps( [ { - "name": "portfolio-manager", + "name": "portfolio-manager-server", "image": args[2], "portMappings": [{"containerPort": 8080, "protocol": "tcp"}], "environment": [ { "name": "FUND_DATAMANAGER_BASE_URL", - "value": f"http://data-manager.{args[1]}:8080", + "value": f"http://data-manager-server.{args[1]}:8080", }, { "name": "FUND_ENSEMBLE_MANAGER_BASE_URL", - "value": f"http://ensemble-manager.{args[1]}:8080", + "value": f"http://ensemble-manager-server.{args[1]}:8080", }, { "name": "FUND_ENVIRONMENT", @@ -823,7 +824,7 @@ "options": { "awslogs-group": args[0], "awslogs-region": region, - "awslogs-stream-prefix": "portfolio-manager", + "awslogs-stream-prefix": "portfolio-manager-server", }, }, "essential": True, @@ -837,7 +838,7 @@ ensemble_manager_task_definition = aws.ecs.TaskDefinition( "ensemble_manager_task", - family="ensemble-manager", + family="ensemble-manager-server", cpu="256", memory="512", network_mode="awsvpc", @@ -854,13 +855,13 @@ lambda args: json.dumps( [ { - "name": "ensemble-manager", + "name": "ensemble-manager-server", "image": args[2], "portMappings": [{"containerPort": 8080, "protocol": "tcp"}], "environment": [ { "name": "FUND_DATAMANAGER_BASE_URL", - "value": f"http://data-manager.{args[1]}:8080", + "value": f"http://data-manager-server.{args[1]}:8080", }, { "name": "AWS_S3_MODEL_ARTIFACTS_BUCKET_NAME", @@ -886,7 +887,7 @@ "options": { "awslogs-group": args[0], "awslogs-region": region, - "awslogs-stream-prefix": "ensemble-manager", + "awslogs-stream-prefix": "ensemble-manager-server", }, }, "essential": True, @@ -900,7 +901,7 @@ data_manager_sd_service = aws.servicediscovery.Service( "data_manager_sd", - name="data-manager", + name="data-manager-server", dns_config=aws.servicediscovery.ServiceDnsConfigArgs( namespace_id=service_discovery_namespace.id, dns_records=[ @@ -912,7 +913,7 @@ portfolio_manager_sd_service = aws.servicediscovery.Service( "portfolio_manager_sd", - name="portfolio-manager", + name="portfolio-manager-server", dns_config=aws.servicediscovery.ServiceDnsConfigArgs( namespace_id=service_discovery_namespace.id, dns_records=[ @@ -924,7 +925,7 @@ ensemble_manager_sd_service = aws.servicediscovery.Service( "ensemble_manager_sd", - name="ensemble-manager", + name="ensemble-manager-server", dns_config=aws.servicediscovery.ServiceDnsConfigArgs( namespace_id=service_discovery_namespace.id, dns_records=[ @@ -936,7 +937,7 @@ data_manager_service = aws.ecs.Service( "data_manager_service", - name="fund-data-manager", + name="fund-data-manager-server", cluster=cluster.arn, task_definition=data_manager_task_definition.arn, desired_count=1, @@ -949,7 +950,7 @@ load_balancers=[ aws.ecs.ServiceLoadBalancerArgs( target_group_arn=data_manager_tg.arn, - container_name="data-manager", + container_name="data-manager-server", container_port=8080, ) ], @@ -962,7 +963,7 @@ portfolio_manager_service = aws.ecs.Service( "portfolio_manager_service", - name="fund-portfolio-manager", + name="fund-portfolio-manager-server", cluster=cluster.arn, task_definition=portfolio_manager_task_definition.arn, desired_count=1, @@ -975,7 +976,7 @@ load_balancers=[ aws.ecs.ServiceLoadBalancerArgs( target_group_arn=portfolio_manager_tg.arn, - container_name="portfolio-manager", + container_name="portfolio-manager-server", container_port=8080, ) ], @@ -988,7 +989,7 @@ ensemble_manager_service = aws.ecs.Service( "ensemble_manager_service", - name="fund-ensemble-manager", + name="fund-ensemble-manager-server", cluster=cluster.arn, task_definition=ensemble_manager_task_definition.arn, desired_count=1, @@ -1001,7 +1002,7 @@ load_balancers=[ aws.ecs.ServiceLoadBalancerArgs( target_group_arn=ensemble_manager_tg.arn, - container_name="ensemble-manager", + container_name="ensemble-manager-server", container_port=8080, ) ], diff --git a/infrastructure/github_environment_runbook.md b/infrastructure/github_environment_runbook.md index 11e14d88..55c90f17 100644 --- a/infrastructure/github_environment_runbook.md +++ b/infrastructure/github_environment_runbook.md @@ -8,29 +8,8 @@ Required environment secrets for operations with Pulumi: - `AWS_REGION` - `PULUMI_ACCESS_TOKEN` -## Update `AWS_IAM_INFRASTRUCTURE_ROLE_ARN` from Pulumi output - -Run from repository root: - -```bash -cd infrastructure -pulumi stack select "$(pulumi org get-default)/fund/production" -role_arn="$(pulumi stack output aws_iam_github_actions_infrastructure_role_arn --stack production)" -cd .. -gh secret set AWS_IAM_INFRASTRUCTURE_ROLE_ARN --env pulumi --body "$role_arn" -``` - -## Update `AWS_REGION` from Pulumi stack config - -Run from repository root: - -```bash -cd infrastructure -pulumi stack select "$(pulumi org get-default)/fund/production" -region="$(pulumi config get aws:region --stack production --show-secrets)" -cd .. -gh secret set AWS_REGION --env pulumi --body "$region" -``` +`AWS_IAM_INFRASTRUCTURE_ROLE_ARN` and `AWS_REGION` are set automatically when running +`mask infrastructure stack up --bootstrap` from a local machine with `gh` authenticated. ## Update `PULUMI_ACCESS_TOKEN` from Pulumi account diff --git a/infrastructure/notifications.py b/infrastructure/notifications.py index e94143c2..f6f3e1f8 100644 --- a/infrastructure/notifications.py +++ b/infrastructure/notifications.py @@ -23,6 +23,8 @@ endpoint=notification_email_address, ) +# This can be updated by setting the monthlyBudgetLimitUsd Pulumi configuration +# variable. aws.budgets.Budget( "production_cost_budget", account_id=account_id, diff --git a/infrastructure/storage.py b/infrastructure/storage.py index d56f5701..da012832 100644 --- a/infrastructure/storage.py +++ b/infrastructure/storage.py @@ -14,7 +14,7 @@ "tagStatus": "untagged", "countType": "sinceImagePushed", "countUnit": "days", - "countNumber": 0, + "countNumber": 1, }, "action": {"type": "expire"}, } @@ -23,14 +23,10 @@ ) # S3 Data Bucket for storing equity bars, predictions, portfolios -# alias: migrated from aws:s3/bucket:Bucket to aws:s3/bucketV2:BucketV2 -data_bucket = aws.s3.BucketV2( +data_bucket = aws.s3.Bucket( "data_bucket", bucket=pulumi.Output.concat("fund-data-", random_suffix), - opts=pulumi.ResourceOptions( - retain_on_delete=True, - aliases=[pulumi.Alias(type_="aws:s3/bucket:Bucket")], - ), + opts=pulumi.ResourceOptions(retain_on_delete=True), tags=tags, ) @@ -67,14 +63,10 @@ ) # S3 Model Artifacts Bucket for storing trained model weights and checkpoints -# alias: migrated from aws:s3/bucket:Bucket to aws:s3/bucketV2:BucketV2 -model_artifacts_bucket = aws.s3.BucketV2( +model_artifacts_bucket = aws.s3.Bucket( "model_artifacts_bucket", bucket=pulumi.Output.concat("fund-model-artifacts-", random_suffix), - opts=pulumi.ResourceOptions( - retain_on_delete=True, - aliases=[pulumi.Alias(type_="aws:s3/bucket:Bucket")], - ), + opts=pulumi.ResourceOptions(retain_on_delete=True), tags=tags, ) @@ -116,7 +108,7 @@ # retain_on_delete=True and add pulumi import statements to the maskfile up command. data_manager_repository = aws.ecr.Repository( "data_manager_repository", - name="fund/data_manager-server", + name="fund/data-manager-server", image_tag_mutability="MUTABLE", force_delete=True, image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( @@ -133,7 +125,7 @@ portfolio_manager_repository = aws.ecr.Repository( "portfolio_manager_repository", - name="fund/portfolio_manager-server", + name="fund/portfolio-manager-server", image_tag_mutability="MUTABLE", force_delete=True, image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( @@ -150,7 +142,7 @@ ensemble_manager_repository = aws.ecr.Repository( "ensemble_manager_repository", - name="fund/ensemble_manager-server", + name="fund/ensemble-manager-server", image_tag_mutability="MUTABLE", force_delete=True, image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( @@ -165,9 +157,9 @@ policy=_ecr_lifecycle_policy, ) -tide_trainer_repository = aws.ecr.Repository( - "tide_trainer_repository", - name="fund/tide-trainer", +tide_runner_repository = aws.ecr.Repository( + "tide_runner_repository", + name="fund/tide-runner", image_tag_mutability="MUTABLE", force_delete=True, image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( @@ -177,14 +169,14 @@ ) aws.ecr.LifecyclePolicy( - "tide_trainer_repository_lifecycle", - repository=tide_trainer_repository.name, + "tide_runner_repository_lifecycle", + repository=tide_runner_repository.name, policy=_ecr_lifecycle_policy, ) -training_server_repository = aws.ecr.Repository( - "training_server_repository", - name="fund/training-server", +model_trainer_server_worker_repository = aws.ecr.Repository( + "model_trainer_server_worker_repository", + name="fund/model-trainer-server-worker", image_tag_mutability="MUTABLE", force_delete=True, image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( @@ -194,25 +186,8 @@ ) aws.ecr.LifecyclePolicy( - "training_server_repository_lifecycle", - repository=training_server_repository.name, - policy=_ecr_lifecycle_policy, -) - -training_worker_repository = aws.ecr.Repository( - "training_worker_repository", - name="fund/training-worker", - image_tag_mutability="MUTABLE", - force_delete=True, - image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( - scan_on_push=True, - ), - tags=tags, -) - -aws.ecr.LifecyclePolicy( - "training_worker_repository_lifecycle", - repository=training_worker_repository.name, + "model_trainer_server_worker_repository_lifecycle", + repository=model_trainer_server_worker_repository.name, policy=_ecr_lifecycle_policy, ) @@ -227,12 +202,11 @@ ensemble_manager_image_uri = ensemble_manager_repository.repository_url.apply( lambda url: f"{url}:latest" ) -tide_trainer_image_uri = tide_trainer_repository.repository_url.apply( - lambda url: f"{url}:latest" -) -training_server_image_uri = training_server_repository.repository_url.apply( +tide_runner_image_uri = tide_runner_repository.repository_url.apply( lambda url: f"{url}:latest" ) -training_worker_image_uri = training_worker_repository.repository_url.apply( - lambda url: f"{url}:latest" +model_trainer_server_worker_image_uri = ( + model_trainer_server_worker_repository.repository_url.apply( + lambda url: f"{url}:latest" + ) ) diff --git a/libraries/python/tests/test_infrastructure_storage.py b/libraries/python/tests/test_infrastructure_storage.py index 82de9d24..8c2f3b7e 100644 --- a/libraries/python/tests/test_infrastructure_storage.py +++ b/libraries/python/tests/test_infrastructure_storage.py @@ -28,6 +28,7 @@ def test_storage_contains_ecr_lifecycle_policy_resources() -> None: assert '"data_manager_repository_lifecycle"' in infrastructure_storage assert '"portfolio_manager_repository_lifecycle"' in infrastructure_storage assert '"ensemble_manager_repository_lifecycle"' in infrastructure_storage - assert '"tide_trainer_repository_lifecycle"' in infrastructure_storage - assert '"training_server_repository_lifecycle"' in infrastructure_storage - assert '"training_worker_repository_lifecycle"' in infrastructure_storage + assert '"tide_runner_repository_lifecycle"' in infrastructure_storage + assert ( + '"model_trainer_server_worker_repository_lifecycle"' in infrastructure_storage + ) diff --git a/maskfile.md b/maskfile.md index 175bd84c..243cff02 100644 --- a/maskfile.md +++ b/maskfile.md @@ -48,13 +48,13 @@ echo "Development environment setup completed successfully" > Manage infrastructure resources -### images +### image > Manage Docker images for applications -#### build (application_name) (stage_name) +#### build (package_name) (stage_name) -> Build Docker images with optional cache pull +> Build Docker images with optional cache pull (e.g. `portfolio-manager server`, `tide runner`) ```bash set -euo pipefail @@ -62,23 +62,51 @@ set -euo pipefail echo "Building image" aws_account_id=$(aws sts get-caller-identity --query Account --output text) -aws_region=${AWS_REGION} +aws_region="${AWS_REGION:-}" if [ -z "$aws_region" ]; then echo "AWS_REGION environment variable is not set" exit 1 fi -if [ "${application_name}" = "training" ]; then - dockerfile="tools/Dockerfile" +commit_hash=$(git rev-parse --short HEAD) +repository_name="fund/${package_name}-${stage_name}" +if [ "${package_name}" = "model-trainer" ]; then + repository_name="fund/model-trainer-server-worker" +fi +image_reference="${aws_account_id}.dkr.ecr.${aws_region}.amazonaws.com/${repository_name}" + +echo "Logging into ECR" +aws ecr get-login-password --region ${aws_region} | docker login \ + --username AWS \ + --password-stdin ${aws_account_id}.dkr.ecr.${aws_region}.amazonaws.com 2>/dev/null || echo "Could not authenticate to ECR (will build without cache)" + +echo "Checking if image for commit ${commit_hash} already exists in ECR" +existing_image=$(aws ecr describe-images \ + --repository-name "${repository_name}" \ + --image-ids "imageTag=git-${commit_hash}" \ + --query 'imageDetails[0].imageDigest' \ + --output text 2>/dev/null || echo "NONE") +if [ "$existing_image" != "NONE" ] && [ "$existing_image" != "None" ] && [ -n "$existing_image" ]; then + echo "Image for commit ${commit_hash} already exists in ECR, skipping build" + exit 0 +fi + +if [ "${package_name}" = "model-trainer" ]; then + dockerfile="models/Dockerfile" + build_target="server-worker" +elif [ -f "models/${package_name}/Dockerfile" ]; then + dockerfile="models/${package_name}/Dockerfile" + build_target="${stage_name}" else - dockerfile="applications/${application_name}/Dockerfile" + resolved_name=$(echo "${package_name}" | tr '-' '_') + dockerfile="applications/${resolved_name}/Dockerfile" + build_target="${stage_name}" fi -image_reference="${aws_account_id}.dkr.ecr.${aws_region}.amazonaws.com/fund/${application_name}-${stage_name}" cache_reference="${image_reference}:buildcache" # Use GHA backend for caching when running in GitHub Actions if [ -n "${GITHUB_ACTIONS:-}" ]; then - scope="${application_name}-${stage_name}" + scope="${package_name}-${stage_name}" echo "Running in GitHub Actions - using hybrid cache (gha + registry) with scope: ${scope}" cache_from_arguments="--cache-from type=gha,scope=${scope} --cache-from type=registry,ref=${cache_reference}" cache_to_arguments="--cache-to type=gha,scope=${scope},mode=max --cache-to type=registry,ref=${cache_reference},mode=max" @@ -95,15 +123,10 @@ else docker buildx create --use --name fund-builder 2>/dev/null || docker buildx use fund-builder || (echo "Using default buildx builder" && docker buildx use default) fi -echo "Logging into ECR (to pull cache if available)" -aws ecr get-login-password --region ${aws_region} | docker login \ - --username AWS \ - --password-stdin ${aws_account_id}.dkr.ecr.${aws_region}.amazonaws.com 2>/dev/null || echo "Could not authenticate to ECR for cache (will build without cache)" - echo "Building with caching (will continue if cache doesn't exist)" docker buildx build \ --platform linux/amd64 \ - --target ${stage_name} \ + --target ${build_target} \ --file ${dockerfile} \ --tag ${image_reference}:latest \ ${cache_from_arguments} \ @@ -111,12 +134,12 @@ docker buildx build \ --load \ . -echo "Image built: ${application_name} ${stage_name}" +echo "Image built: ${package_name} ${stage_name}" ``` -#### push (application_name) (stage_name) +#### push (package_name) (stage_name) -> Push Docker image to ECR +> Push Docker image to ECR (e.g. `portfolio-manager server`, `tide runner`) ```bash set -euo pipefail @@ -124,13 +147,16 @@ set -euo pipefail echo "Pushing image to ECR" aws_account_id=$(aws sts get-caller-identity --query Account --output text) -aws_region=${AWS_REGION} +aws_region="${AWS_REGION:-}" if [ -z "$aws_region" ]; then echo "AWS_REGION environment variable is not set" exit 1 fi -repository_name="fund/${application_name}-${stage_name}" +repository_name="fund/${package_name}-${stage_name}" +if [ "${package_name}" = "model-trainer" ]; then + repository_name="fund/model-trainer-server-worker" +fi image_reference="${aws_account_id}.dkr.ecr.${aws_region}.amazonaws.com/${repository_name}" commit_hash=$(git rev-parse --short HEAD) @@ -151,7 +177,7 @@ fi if [ "$existing_tag" != "NONE" ] && [ "$existing_tag" != "None" ] && [ -n "$existing_tag" ]; then echo "Image for commit ${commit_hash} already exists in ECR, skipping push" - echo "Image pushed: ${application_name} ${stage_name} (cached)" + echo "Image pushed: ${package_name} ${stage_name} (cached)" exit 0 fi @@ -160,7 +186,47 @@ docker tag "${image_reference}:latest" "${image_reference}:git-${commit_hash}" docker push "${image_reference}:latest" docker push "${image_reference}:git-${commit_hash}" -echo "Image pushed: ${application_name} ${stage_name} (commit: ${commit_hash})" +echo "Image pushed: ${package_name} ${stage_name} (commit: ${commit_hash})" +``` + +#### deploy (package_name) (stage_name) + +> Deploy ECS service with latest image (e.g. `portfolio-manager server`, `data-manager server`) + +```bash +set -euo pipefail + +echo "Deploying ${package_name} ${stage_name}" + +case "${package_name}-${stage_name}" in + data-manager-server) service="fund-data-manager-server" ;; + portfolio-manager-server) service="fund-portfolio-manager-server" ;; + ensemble-manager-server) service="fund-ensemble-manager-server" ;; + model-trainer-server) service="fund-model-trainer-server" ;; + model-trainer-worker) service="fund-model-trainer-worker" ;; + tide-runner) echo "No ECS service for tide runner" && exit 0 ;; + *) echo "Unknown service: ${package_name}-${stage_name}" && exit 1 ;; +esac + +cd infrastructure/ + +if ! organization_name=$(pulumi org get-default 2>/dev/null) || [ -z "${organization_name}" ]; then + echo "Error: Pulumi default organization not set. Run: pulumi org set-default " + exit 1 +fi +pulumi stack select "${organization_name}/fund/production" +cluster=$(pulumi stack output aws_ecs_cluster_name) + +cd "${MASKFILE_DIR}" + +aws ecs update-service --cluster "$cluster" --service "$service" --force-new-deployment --no-cli-pager > /dev/null +echo "Deployment started: ${service}" + +echo "Waiting for ${service} to stabilize" + +aws ecs wait services-stable --cluster "$cluster" --services "$service" + +echo "Deployment complete: ${service} (${package_name} ${stage_name})" ``` ### stack @@ -213,18 +279,16 @@ if [ -n "$GITHUB_POLICY_ARN" ]; then pulumi import --yes --generate-code=false aws:iam/policy:Policy github_actions_infrastructure_policy "$GITHUB_POLICY_ARN" 2>/dev/null || true fi -pulumi import --yes --generate-code=false aws:s3/bucketV2:BucketV2 data_bucket "fund-data-${RANDOM_SUFFIX}" 2>/dev/null || true +pulumi import --yes --generate-code=false aws:s3/bucket:Bucket data_bucket "fund-data-${RANDOM_SUFFIX}" 2>/dev/null || true pulumi import --yes --generate-code=false aws:s3/bucketServerSideEncryptionConfiguration:BucketServerSideEncryptionConfiguration data_bucket_encryption "fund-data-${RANDOM_SUFFIX}" 2>/dev/null || true pulumi import --yes --generate-code=false aws:s3/bucketPublicAccessBlock:BucketPublicAccessBlock data_bucket_public_access_block "fund-data-${RANDOM_SUFFIX}" 2>/dev/null || true pulumi import --yes --generate-code=false aws:s3/bucketVersioning:BucketVersioning data_bucket_versioning "fund-data-${RANDOM_SUFFIX}" 2>/dev/null || true -pulumi import --yes --generate-code=false aws:s3/bucketV2:BucketV2 model_artifacts_bucket "fund-model-artifacts-${RANDOM_SUFFIX}" 2>/dev/null || true +pulumi import --yes --generate-code=false aws:s3/bucket:Bucket model_artifacts_bucket "fund-model-artifacts-${RANDOM_SUFFIX}" 2>/dev/null || true pulumi import --yes --generate-code=false aws:s3/bucketServerSideEncryptionConfiguration:BucketServerSideEncryptionConfiguration model_artifacts_bucket_encryption "fund-model-artifacts-${RANDOM_SUFFIX}" 2>/dev/null || true pulumi import --yes --generate-code=false aws:s3/bucketPublicAccessBlock:BucketPublicAccessBlock model_artifacts_bucket_public_access_block "fund-model-artifacts-${RANDOM_SUFFIX}" 2>/dev/null || true pulumi import --yes --generate-code=false aws:s3/bucketVersioning:BucketVersioning model_artifacts_bucket_versioning "fund-model-artifacts-${RANDOM_SUFFIX}" 2>/dev/null || true -pulumi import --yes --generate-code=false aws:ssm/parameter:Parameter ssm_ensemble_manager_model_version "/fund/production/ensemble-manager/model-version" 2>/dev/null || true - echo "Importing resources complete" pulumi up --diff --yes @@ -260,63 +324,6 @@ if [[ "$BOOTSTRAP" == "true" ]]; then fi fi -echo "Forcing ECS service deployments to pull latest images" - -cluster=$(pulumi stack output aws_ecs_cluster_name --stack production 2>/dev/null || echo "") - -if [ -z "$cluster" ]; then - echo "Cluster not found - skipping service deployments (initial setup)" -else - # Note: Service names use 'fund' prefix matching the Pulumi project name. - # These must exactly match the ECS service names created by the infrastructure code. - # The AWS account provides environment context (one account = one environment). - for service in fund-data-manager fund-portfolio-manager fund-ensemble-manager fund-training-server fund-training-worker; do - echo "Checking if $service exists and is ready" - - # Wait up to 60 seconds for service to be active - retry_count=0 - maximum_retries=12 - retry_wait_seconds=5 - service_is_ready=false - - while [ $retry_count -lt $maximum_retries ]; do - service_status=$(aws ecs describe-services \ - --cluster "$cluster" \ - --services "$service" \ - --query 'services[0].status' \ - --output text 2>/dev/null || echo "NONE") - - if [ "$service_status" = "ACTIVE" ]; then - service_is_ready=true - echo "Service $service is ACTIVE" - break - elif [ "$service_status" = "NONE" ]; then - echo "Service not found, waiting ($((retry_count + 1))/$maximum_retries)" - else - echo "Service status: $service_status, waiting ($((retry_count + 1))/$maximum_retries)" - fi - - sleep $retry_wait_seconds - retry_count=$((retry_count + 1)) - done - - if [ "$service_is_ready" = true ]; then - echo "Forcing new deployment for $service" - aws ecs update-service \ - --cluster "$cluster" \ - --service "$service" \ - --force-new-deployment \ - --no-cli-pager \ - --output text > /dev/null 2>&1 && echo "Deployment initiated" || echo "Failed to force deployment" - else - echo "Skipping $service (not ready after 60s - may be initial deployment)" - fi - done - - echo "Stack update complete - ECS is performing rolling deployments" - echo "Monitor progress: aws ecs describe-services --cluster $cluster --services fund-portfolio-manager" -fi - echo "Infrastructure launched successfully" ``` @@ -336,7 +343,7 @@ pulumi down --yes --stack production echo "Infrastructure torn down successfully" ``` -### services +### service > Manage infrastructure services @@ -770,7 +777,7 @@ echo "YAML development checks completed successfully" > Model management commands -### train +### train (model_name) > Train model via Prefect training pipeline @@ -794,10 +801,18 @@ export FUND_LOOKBACK_DAYS="${FUND_LOOKBACK_DAYS:-365}" cd ../ -uv run python -m tide.run +case "${model_name}" in + tide) + uv run python -m tide.run + ;; + *) + echo "Unknown model: ${model_name}" + exit 1 + ;; +esac ``` -### deploy +### deploy (model_name) > Register flow deployment with Prefect server @@ -813,7 +828,7 @@ fi pulumi stack select ${organization_name}/fund/production -export FUND_DATAMANAGER_BASE_URL="http://data-manager.$(pulumi stack output aws_service_discovery_namespace):8080" +export FUND_DATAMANAGER_BASE_URL="http://data-manager-server.$(pulumi stack output aws_service_discovery_namespace):8080" export AWS_S3_DATA_BUCKET_NAME="$(pulumi stack output aws_s3_data_bucket_name)" export AWS_S3_MODEL_ARTIFACTS_BUCKET_NAME="$(pulumi stack output aws_s3_model_artifacts_bucket_name)" export PREFECT_API_URL="$(pulumi stack output training_api_url)" @@ -821,19 +836,34 @@ export FUND_LOOKBACK_DAYS="${FUND_LOOKBACK_DAYS:-365}" cd ../ -uv run python -m tide.deploy +case "${model_name}" in + tide) + uv run python -m tide.deploy + ;; + *) + echo "Unknown model: ${model_name}" + exit 1 + ;; +esac ``` -### download (application_name) +### download (model_name) > Download model artifacts ```bash set -euo pipefail -export APPLICATION_NAME="${application_name}" - -uv run python -m tools.download_model_artifacts +case "${model_name}" in + tide) + export APPLICATION_NAME="${model_name}" + uv run python -m tools.download_model_artifacts + ;; + *) + echo "Unknown model: ${model_name}" + exit 1 + ;; +esac ``` ## mcp diff --git a/tools/Dockerfile b/models/Dockerfile similarity index 96% rename from tools/Dockerfile rename to models/Dockerfile index 873d2e97..d74356b3 100644 --- a/tools/Dockerfile +++ b/models/Dockerfile @@ -14,11 +14,13 @@ COPY tools/ tools/ COPY applications/ensemble_manager/ applications/ensemble_manager/ +COPY models/tide/ models/tide/ + COPY libraries/python/ libraries/python/ RUN uv sync --no-dev -FROM python:3.12.10-slim AS worker +FROM python:3.12.10-slim AS server-worker WORKDIR /app diff --git a/models/tide/Dockerfile b/models/tide/Dockerfile index 582150f5..28d46086 100644 --- a/models/tide/Dockerfile +++ b/models/tide/Dockerfile @@ -14,7 +14,7 @@ COPY tools/ tools/ RUN uv sync --no-dev --package tide -FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS trainer +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS runner ENV DEBIAN_FRONTEND=noninteractive ENV TZ=UTC diff --git a/models/tide/src/tide/deploy.py b/models/tide/src/tide/deploy.py index f147b1c1..c8a233ce 100644 --- a/models/tide/src/tide/deploy.py +++ b/models/tide/src/tide/deploy.py @@ -27,7 +27,8 @@ def deploy_training_flow( training_pipeline.deploy( name="daily-training", work_pool_name="training-pool", - cron="0 22 * * *", + cron="0 22 * * 1-5", + timezone="America/New_York", parameters={ "base_url": base_url, "data_bucket": data_bucket, diff --git a/models/tide/src/tide/trainer.py b/models/tide/src/tide/trainer.py deleted file mode 100644 index 44e22b51..00000000 --- a/models/tide/src/tide/trainer.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import cast - -import polars as pl -import structlog - -from tide.tide_data import Data -from tide.tide_model import Model - -logger = structlog.get_logger() - -DEFAULT_CONFIGURATION = { - "architecture": "TiDE", - "learning_rate": 0.003, - "epoch_count": 20, - "validation_split": 0.8, - "input_length": 35, - "output_length": 7, - "hidden_size": 64, - "num_encoder_layers": 2, - "num_decoder_layers": 2, - "dropout_rate": 0.1, - "batch_size": 256, -} - - -def train_model( - training_data: pl.DataFrame, - configuration: dict | None = None, - checkpoint_directory: str | None = None, -) -> tuple[Model, Data]: - """Train TiDE model and return model + data processor.""" - merged_configuration = dict(DEFAULT_CONFIGURATION) - if configuration is not None: - merged_configuration.update(configuration) - configuration = merged_configuration - - logger.info("Configuration loaded", **configuration) - - logger.info("Initializing data processor") - tide_data = Data() - - logger.info("Preprocessing training data") - tide_data.preprocess_and_set_data(data=training_data) - - logger.info("Getting data dimensions") - dimensions = tide_data.get_dimensions() - logger.info("Data dimensions", **dimensions) - - logger.info("Creating training batches") - train_batches = tide_data.get_batches( - data_type="train", - validation_split=float(configuration["validation_split"]), - input_length=int(configuration["input_length"]), - output_length=int(configuration["output_length"]), - batch_size=int(configuration["batch_size"]), - ) - - logger.info("Training batches created", batch_count=len(train_batches)) - - if not train_batches: - logger.error( - "No training batches created", - validation_split=configuration["validation_split"], - input_length=configuration["input_length"], - output_length=configuration["output_length"], - batch_size=configuration["batch_size"], - training_data_rows=training_data.height, - ) - message = ( - "No training batches created - check input data and configuration. " - f"Training data has {training_data.height} rows, " - f"input_length={configuration['input_length']}, " - f"output_length={configuration['output_length']}, " - f"batch_size={configuration['batch_size']}" - ) - raise ValueError(message) - - sample_batch = train_batches[0] - - batch_size = sample_batch["past_continuous_features"].shape[0] - logger.info("Batch size determined", batch_size=batch_size) - - past_continuous_size = ( - sample_batch["past_continuous_features"].reshape(batch_size, -1).shape[1] - ) - past_categorical_size = ( - sample_batch["past_categorical_features"].reshape(batch_size, -1).shape[1] - ) - future_categorical_size = ( - sample_batch["future_categorical_features"].reshape(batch_size, -1).shape[1] - ) - static_categorical_size = ( - sample_batch["static_categorical_features"].reshape(batch_size, -1).shape[1] - ) - - input_size = cast( - "int", - past_continuous_size - + past_categorical_size - + future_categorical_size - + static_categorical_size, - ) - - logger.info("Input size calculated", input_size=input_size) - - logger.info("Creating model") - tide_model = Model( - input_size=input_size, - hidden_size=int(configuration["hidden_size"]), - num_encoder_layers=int(configuration["num_encoder_layers"]), - num_decoder_layers=int(configuration["num_decoder_layers"]), - output_length=int(configuration["output_length"]), - dropout_rate=float(configuration["dropout_rate"]), - quantiles=[0.1, 0.5, 0.9], - ) - - logger.info("Training started", epochs=configuration["epoch_count"]) - - losses = tide_model.train( - train_batches=train_batches, - epochs=int(configuration["epoch_count"]), - learning_rate=float(configuration["learning_rate"]), - checkpoint_directory=checkpoint_directory, - ) - - logger.info( - "Training complete", - final_loss=losses[-1] if losses else None, - all_losses=losses, - ) - - return tide_model, tide_data diff --git a/models/tide/src/tide/workflow.py b/models/tide/src/tide/workflow.py index 716baeb9..bef055e4 100644 --- a/models/tide/src/tide/workflow.py +++ b/models/tide/src/tide/workflow.py @@ -5,6 +5,7 @@ import tempfile from datetime import UTC, datetime, timedelta from pathlib import Path +from typing import TYPE_CHECKING, Any, cast import boto3 import polars as pl @@ -18,6 +19,137 @@ logger = structlog.get_logger() +if TYPE_CHECKING: + from tide.tide_data import Data + from tide.tide_model import Model + +DEFAULT_CONFIGURATION = { + "architecture": "TiDE", + "learning_rate": 0.003, + "epoch_count": 20, + "validation_split": 0.8, + "input_length": 35, + "output_length": 7, + "hidden_size": 64, + "num_encoder_layers": 2, + "num_decoder_layers": 2, + "dropout_rate": 0.1, + "batch_size": 256, +} + + +def train_model( + training_data: pl.DataFrame, + configuration: dict[str, Any] | None = None, + checkpoint_directory: str | None = None, +) -> "tuple[Model, Data]": + """Train TiDE model and return model + data processor.""" + # Defer imports to avoid loading tinygrad at module level (heavy GPU dependency) + from tide.tide_data import Data # noqa: PLC0415 + from tide.tide_model import Model # noqa: PLC0415 + + merged_configuration = dict(DEFAULT_CONFIGURATION) + if configuration is not None: + merged_configuration.update(configuration) + configuration = merged_configuration + + logger.info("Configuration loaded", **configuration) + + logger.info("Initializing data processor") + tide_data = Data() + + logger.info("Preprocessing training data") + tide_data.preprocess_and_set_data(data=training_data) + + logger.info("Getting data dimensions") + dimensions = tide_data.get_dimensions() + logger.info("Data dimensions", **dimensions) + + logger.info("Creating training batches") + train_batches = tide_data.get_batches( + data_type="train", + validation_split=float(configuration["validation_split"]), + input_length=int(configuration["input_length"]), + output_length=int(configuration["output_length"]), + batch_size=int(configuration["batch_size"]), + ) + + logger.info("Training batches created", batch_count=len(train_batches)) + + if not train_batches: + logger.error( + "No training batches created", + validation_split=configuration["validation_split"], + input_length=configuration["input_length"], + output_length=configuration["output_length"], + batch_size=configuration["batch_size"], + training_data_rows=training_data.height, + ) + message = ( + "No training batches created - check input data and configuration. " + f"Training data has {training_data.height} rows, " + f"input_length={configuration['input_length']}, " + f"output_length={configuration['output_length']}, " + f"batch_size={configuration['batch_size']}" + ) + raise ValueError(message) + + sample_batch = train_batches[0] + + batch_size = sample_batch["past_continuous_features"].shape[0] + logger.info("Batch size determined", batch_size=batch_size) + + past_continuous_size = ( + sample_batch["past_continuous_features"].reshape(batch_size, -1).shape[1] + ) + past_categorical_size = ( + sample_batch["past_categorical_features"].reshape(batch_size, -1).shape[1] + ) + future_categorical_size = ( + sample_batch["future_categorical_features"].reshape(batch_size, -1).shape[1] + ) + static_categorical_size = ( + sample_batch["static_categorical_features"].reshape(batch_size, -1).shape[1] + ) + + input_size = cast( + "int", + past_continuous_size + + past_categorical_size + + future_categorical_size + + static_categorical_size, + ) + + logger.info("Input size calculated", input_size=input_size) + + logger.info("Creating model") + tide_model = Model( + input_size=input_size, + hidden_size=int(configuration["hidden_size"]), + num_encoder_layers=int(configuration["num_encoder_layers"]), + num_decoder_layers=int(configuration["num_decoder_layers"]), + output_length=int(configuration["output_length"]), + dropout_rate=float(configuration["dropout_rate"]), + quantiles=[0.1, 0.5, 0.9], + ) + + logger.info("Training started", epochs=configuration["epoch_count"]) + + losses = tide_model.train( + train_batches=train_batches, + epochs=int(configuration["epoch_count"]), + learning_rate=float(configuration["learning_rate"]), + checkpoint_directory=checkpoint_directory, + ) + + logger.info( + "Training complete", + final_loss=losses[-1] if losses else None, + all_losses=losses, + ) + + return tide_model, tide_data + def get_training_date_range(lookback_days: int) -> tuple[datetime, datetime]: """Build a UTC date range used by sync + prepare steps.""" @@ -102,9 +234,6 @@ def train_tide_model( training_data_key: str = "training/filtered_tide_training_data.parquet", ) -> str: """Download training data from S3, train model, upload artifact to S3.""" - # Defer import to avoid importing tinygrad at module level (heavy GPU dependency) - from tide.trainer import train_model # noqa: PLC0415 - resolved_training_data_key = training_data_key bucket_prefix = f"s3://{artifacts_bucket}/" if training_data_key.startswith(bucket_prefix): diff --git a/models/tide/tests/test_deploy.py b/models/tide/tests/test_deploy.py index 7da5c987..10eeb11d 100644 --- a/models/tide/tests/test_deploy.py +++ b/models/tide/tests/test_deploy.py @@ -22,7 +22,8 @@ def test_deploy_training_flow_calls_deploy(mock_pipeline: MagicMock) -> None: call_kwargs = mock_deploy.call_args.kwargs assert call_kwargs["name"] == "daily-training" assert call_kwargs["work_pool_name"] == "training-pool" - assert call_kwargs["cron"] == "0 22 * * *" + assert call_kwargs["cron"] == "0 22 * * 1-5" + assert call_kwargs["timezone"] == "America/New_York" assert call_kwargs["parameters"]["base_url"] == "http://example.com" assert call_kwargs["parameters"]["lookback_days"] == LOOKBACK_DAYS diff --git a/models/tide/tests/test_trainer.py b/models/tide/tests/test_trainer.py deleted file mode 100644 index 64244e98..00000000 --- a/models/tide/tests/test_trainer.py +++ /dev/null @@ -1,62 +0,0 @@ -from collections.abc import Callable - -import polars as pl -import pytest -from tide.trainer import DEFAULT_CONFIGURATION, train_model - -PARTIAL_HIDDEN_SIZE = 16 - - -def test_train_model_returns_model_and_data( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - training_data = make_raw_data(days=90) - model, data = train_model(training_data) - assert model is not None - assert data is not None - assert hasattr(data, "scaler") - assert hasattr(data, "mappings") - - -def test_train_model_uses_custom_configuration( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - training_data = make_raw_data(days=90) - custom_config = dict(DEFAULT_CONFIGURATION) - custom_hidden_size = 32 - custom_config["epoch_count"] = 1 - custom_config["hidden_size"] = custom_hidden_size - model, _data = train_model(training_data, configuration=custom_config) - assert model.hidden_size == custom_hidden_size - - -def test_train_model_raises_on_insufficient_data( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - short_data = make_raw_data(tickers=["AAPL"], days=5) - with pytest.raises(ValueError, match="Total days available"): - train_model(short_data) - - -def test_train_model_uses_default_configuration( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - training_data = make_raw_data(days=90) - model, _ = train_model(training_data) - assert model.hidden_size == DEFAULT_CONFIGURATION["hidden_size"] - assert model.output_length == DEFAULT_CONFIGURATION["output_length"] - - -def test_train_model_merges_partial_configuration( - make_raw_data: Callable[..., pl.DataFrame], -) -> None: - training_data = make_raw_data(days=90) - model, _ = train_model( - training_data, - configuration={ - "epoch_count": 1, - "hidden_size": PARTIAL_HIDDEN_SIZE, - }, - ) - assert model.hidden_size == PARTIAL_HIDDEN_SIZE - assert model.output_length == DEFAULT_CONFIGURATION["output_length"] diff --git a/models/tide/tests/test_workflow.py b/models/tide/tests/test_workflow.py index cc766105..282182ea 100644 --- a/models/tide/tests/test_workflow.py +++ b/models/tide/tests/test_workflow.py @@ -1,18 +1,22 @@ import io +from collections.abc import Callable from datetime import UTC, datetime from unittest.mock import MagicMock, patch import polars as pl import pytest from tide.workflow import ( + DEFAULT_CONFIGURATION, prepare_data, sync_equity_bars, sync_equity_details, + train_model, train_tide_model, training_pipeline, ) LOOKBACK_DAYS = 30 +PARTIAL_HIDDEN_SIZE = 16 @patch("tide.workflow.sync_equity_bars_data") @@ -110,7 +114,7 @@ def test_train_tide_model_downloads_trains_uploads(mock_boto3: MagicMock) -> Non mock_model = MagicMock() mock_data = MagicMock() - with patch("tide.trainer.train_model") as mock_train: + with patch("tide.workflow.train_model") as mock_train: mock_train.return_value = (mock_model, mock_data) result = train_tide_model.fn( artifacts_bucket="artifacts-bucket", @@ -162,3 +166,57 @@ def test_training_pipeline_threads_data_key( "training/data.parquet", ) assert result == "s3://bucket/model" + + +def test_train_model_returns_model_and_data( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + training_data = make_raw_data(days=90) + model, data = train_model(training_data, configuration={"epoch_count": 1}) + assert model is not None + assert data is not None + assert hasattr(data, "scaler") + assert hasattr(data, "mappings") + + +def test_train_model_uses_custom_configuration( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + training_data = make_raw_data(days=90) + custom_config = dict(DEFAULT_CONFIGURATION) + custom_config["epoch_count"] = 1 + custom_config["hidden_size"] = PARTIAL_HIDDEN_SIZE * 2 + model, _data = train_model(training_data, configuration=custom_config) + assert model.hidden_size == PARTIAL_HIDDEN_SIZE * 2 + + +def test_train_model_raises_on_insufficient_data( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + short_data = make_raw_data(tickers=["AAPL"], days=5) + with pytest.raises(ValueError, match="Total days available"): + train_model(short_data) + + +def test_train_model_uses_default_configuration( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + training_data = make_raw_data(days=90) + model, _ = train_model(training_data, configuration={"epoch_count": 1}) + assert model.hidden_size == DEFAULT_CONFIGURATION["hidden_size"] + assert model.output_length == DEFAULT_CONFIGURATION["output_length"] + + +def test_train_model_merges_partial_configuration( + make_raw_data: Callable[..., pl.DataFrame], +) -> None: + training_data = make_raw_data(days=90) + model, _ = train_model( + training_data, + configuration={ + "epoch_count": 1, + "hidden_size": PARTIAL_HIDDEN_SIZE, + }, + ) + assert model.hidden_size == PARTIAL_HIDDEN_SIZE + assert model.output_length == DEFAULT_CONFIGURATION["output_length"] diff --git a/uv.lock b/uv.lock index f5d3a6bd..512e56dd 100644 --- a/uv.lock +++ b/uv.lock @@ -1310,6 +1310,7 @@ dependencies = [ { name = "internal" }, { name = "pandera", extra = ["polars"] }, { name = "polars" }, + { name = "pytz" }, { name = "requests" }, { name = "scipy" }, { name = "sentry-sdk", extra = ["fastapi"] }, @@ -1325,6 +1326,7 @@ requires-dist = [ { name = "internal", editable = "libraries/python" }, { name = "pandera", extras = ["polars"], specifier = ">=0.26.0" }, { name = "polars", specifier = ">=1.29.0" }, + { name = "pytz", specifier = ">=2025.1" }, { name = "requests", specifier = ">=2.32.5" }, { name = "scipy", specifier = ">=1.17.1" }, { name = "sentry-sdk", extras = ["fastapi"], specifier = ">=2.0.0" },