diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 912c6bbca..ce2b77cbc 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -1,24 +1,35 @@ { - "permissions": { - "allow": [ - "Bash(cat:*)", - "Bash(curl:*)", - "Bash(docker:*)", - "Bash(flox activate:*)", - "Bash(gh issue:*)", - "Bash(mask:*)", - "Bash(mkdir:*)", - "Bash(pre-commit:*)", - "Bash(pulumi:*)", - "Bash(rm:*)", - "Bash(timeout 5 curl -I http://localhost:*)", - "Bash(uv:*)", - "WebFetch(domain:docs.github.com)", - "WebSearch" - ], - "deny": [], - "defaultMode": "acceptEdits" - }, - "enableAllProjectMcpServers": false + "permissions": { + "allow": [ + "Bash(cat:*)", + "Bash(curl:*)", + "Bash(docker:*)", + "Bash(flox activate:*)", + "Bash(gh issue:*)", + "Bash(mask:*)", + "Bash(mkdir:*)", + "Bash(pre-commit:*)", + "Bash(pulumi:*)", + "Bash(rm:*)", + "Bash(timeout 5 curl -I http://localhost:*)", + "Bash(uv:*)", + "WebFetch(domain:docs.github.com)", + "WebSearch", + "Bash(aws ecr describe-repositories:*)", + "Bash(aws ecs describe-services:*)", + "Bash(aws logs tail:*)", + "Bash(aws logs:*)", + "Bash(xargs:*)", + "Bash(aws ecs describe-task-definition:*)", + "Bash(aws s3 cp:*)", + "Bash(python3:*)", + "Bash(aws s3 ls:*)", + "Bash(find:*)", + "Bash(gh api:*)", + "Bash(gh pr view:*)" + ], + "deny": [], + "defaultMode": "acceptEdits" + }, + "enableAllProjectMcpServers": false } - diff --git a/.flox/env/manifest.lock b/.flox/env/manifest.lock index e435a1744..faeff6cb7 100644 --- a/.flox/env/manifest.lock +++ b/.flox/env/manifest.lock @@ -3269,4 +3269,4 @@ "priority": 5 } ] -} \ No newline at end of file +} diff --git a/applications/datamanager/Dockerfile b/applications/datamanager/Dockerfile index 69eabd105..472f443ec 100644 --- a/applications/datamanager/Dockerfile +++ b/applications/datamanager/Dockerfile @@ -49,6 +49,9 @@ COPY applications/datamanager/Cargo.toml ./applications/datamanager/Cargo.toml COPY applications/datamanager/src/ applications/datamanager/src/ +ENV DUCKDB_LIB_DIR=/usr/local/lib +ENV DUCKDB_INCLUDE_DIR=/usr/local/include + RUN --mount=type=cache,target=/usr/local/cargo/registry \ --mount=type=cache,target=/app/target \ cargo build --release --bin datamanager && \ diff --git a/applications/datamanager/src/data.rs b/applications/datamanager/src/data.rs index b3f98c65a..38b801c5d 100644 --- a/applications/datamanager/src/data.rs +++ b/applications/datamanager/src/data.rs @@ -2,6 +2,7 @@ use crate::errors::Error; use polars::prelude::*; use serde::Deserialize; use std::io::Cursor; +use tracing::{debug, info, warn}; #[derive(Debug, Deserialize)] pub struct EquityBar { @@ -17,6 +18,11 @@ pub struct EquityBar { } pub fn create_equity_bar_dataframe(equity_bars_rows: Vec) -> Result { + debug!( + "Creating equity bar DataFrame from {} rows", + equity_bars_rows.len() + ); + let equity_bars_dataframe = df!( "ticker" => equity_bars_rows.iter().map(|b| b.ticker.as_str()).collect::>(), "timestamp" => equity_bars_rows.iter().map(|b| b.timestamp).collect::>(), @@ -28,13 +34,23 @@ pub fn create_equity_bar_dataframe(equity_bars_rows: Vec) -> Result equity_bars_rows.iter().map(|b| b.volume_weighted_average_price).collect::>(), "transactions" => equity_bars_rows.iter().map(|b| b.transactions).collect::>(), ) - .map_err(|e| Error::Other(format!("Failed to create DataFrame: {}", e)))?; + .map_err(|e| { + warn!("Failed to create equity bar DataFrame: {}", e); + Error::Other(format!("Failed to create DataFrame: {}", e)) + })?; + debug!("Normalizing ticker column to uppercase"); let 
equity_bars_dataframe = equity_bars_dataframe .lazy() .with_columns([col("ticker").str().to_uppercase().alias("ticker")]) .collect()?; + info!( + "Created equity bar DataFrame: {} rows x {} columns", + equity_bars_dataframe.height(), + equity_bars_dataframe.width() + ); + Ok(equity_bars_dataframe) } @@ -48,6 +64,11 @@ pub struct Prediction { } pub fn create_predictions_dataframe(prediction_rows: Vec) -> Result { + debug!( + "Creating predictions DataFrame from {} rows", + prediction_rows.len() + ); + let prediction_dataframe = df!( "ticker" => prediction_rows.iter().map(|p| p.ticker.as_str()).collect::>(), "timestamp" => prediction_rows.iter().map(|p| p.timestamp).collect::>(), @@ -55,14 +76,24 @@ pub fn create_predictions_dataframe(prediction_rows: Vec) -> Result< "quantile_50" => prediction_rows.iter().map(|p| p.quantile_50).collect::>(), "quantile_90" => prediction_rows.iter().map(|p| p.quantile_90).collect::>(), ) - .map_err(|e| Error::Other(format!("Failed to create DataFrame: {}", e)))?; + .map_err(|e| { + warn!("Failed to create predictions DataFrame: {}", e); + Error::Other(format!("Failed to create DataFrame: {}", e)) + })?; + debug!("Normalizing ticker column to uppercase"); let unfiltered_prediction_dataframe = prediction_dataframe .lazy() .with_columns([col("ticker").str().to_uppercase().alias("ticker")]) .collect()?; + debug!( + "Unfiltered predictions DataFrame has {} rows", + unfiltered_prediction_dataframe.height() + ); + // filtering necessary due to potentially overlapping tickers in predictions parquet files + debug!("Filtering to keep only most recent prediction per ticker"); let filtered_prediction_dataframe = unfiltered_prediction_dataframe .lazy() .with_columns([col("timestamp") @@ -79,6 +110,13 @@ pub fn create_predictions_dataframe(prediction_rows: Vec) -> Result< ]) .collect()?; + info!( + "Created predictions DataFrame: {} rows x {} columns (filtered from {} input rows)", + filtered_prediction_dataframe.height(), + filtered_prediction_dataframe.width(), + prediction_rows.len() + ); + Ok(filtered_prediction_dataframe) } @@ -92,6 +130,11 @@ pub struct Portfolio { } pub fn create_portfolio_dataframe(portfolio_rows: Vec) -> Result { + debug!( + "Creating portfolio DataFrame from {} rows", + portfolio_rows.len() + ); + let portfolio_dataframe = df!( "ticker" => portfolio_rows.iter().map(|p| p.ticker.as_str()).collect::>(), "timestamp" => portfolio_rows.iter().map(|p| p.timestamp).collect::>(), @@ -99,8 +142,12 @@ pub fn create_portfolio_dataframe(portfolio_rows: Vec) -> Result portfolio_rows.iter().map(|p| p.dollar_amount).collect::>(), "action" => portfolio_rows.iter().map(|p| p.action.as_str()).collect::>(), ) - .map_err(|e| Error::Other(format!("Failed to create DataFrame: {}", e)))?; + .map_err(|e| { + warn!("Failed to create portfolio DataFrame: {}", e); + Error::Other(format!("Failed to create DataFrame: {}", e)) + })?; + debug!("Normalizing ticker, side, and action columns to uppercase"); let portfolio_dataframe = portfolio_dataframe .lazy() .with_columns([col("ticker").str().to_uppercase().alias("ticker")]) @@ -108,30 +155,58 @@ pub fn create_portfolio_dataframe(portfolio_rows: Vec) -> Result Result { + debug!( + "Creating equity details DataFrame from CSV ({} bytes)", + csv_content.len() + ); + let cursor = Cursor::new(csv_content.as_bytes()); let mut dataframe = CsvReadOptions::default() .with_has_header(true) .into_reader_with_file_handle(cursor) .finish() - .map_err(|e| Error::Other(format!("Failed to parse CSV: {}", e)))?; + .map_err(|e| { + warn!("Failed 
to parse CSV: {}", e); + Error::Other(format!("Failed to parse CSV: {}", e)) + })?; + + debug!( + "Parsed CSV into DataFrame: {} rows x {} columns", + dataframe.height(), + dataframe.width() + ); let required_columns = vec!["sector", "industry"]; let column_names = dataframe.get_column_names(); + + debug!("Available columns: {:?}", column_names); + debug!("Required columns: {:?}", required_columns); + for column in &required_columns { if !column_names.iter().any(|c| c.as_str() == *column) { let message = format!("CSV missing required column: {}", column); + warn!("{}", message); return Err(Error::Other(message)); } } - dataframe = dataframe - .select(required_columns) - .map_err(|e| Error::Other(format!("Failed to select columns: {}", e)))?; + debug!("All required columns present, selecting subset"); + dataframe = dataframe.select(required_columns).map_err(|e| { + warn!("Failed to select columns: {}", e); + Error::Other(format!("Failed to select columns: {}", e)) + })?; + debug!("Normalizing sector and industry columns to uppercase and filling nulls"); let equity_details_dataframe = dataframe .lazy() .with_columns([ @@ -147,7 +222,16 @@ pub fn create_equity_details_dataframe(csv_content: String) -> Result, - h: Option, - l: Option, + c: Option, + h: Option, + l: Option, n: Option, - o: Option, + o: Option, t: u64, - v: Option, - vw: Option, + v: Option, + vw: Option, } #[derive(Deserialize, Debug)] @@ -54,7 +54,10 @@ pub async fn query( AxumState(state): AxumState, Query(parameters): Query, ) -> impl IntoResponse { - info!("Querying equity data from S3 partitioned files"); + info!( + "Querying equity data from S3 partitioned files, tickers: {:?}, start: {:?}, end: {:?}", + parameters.tickers, parameters.start_timestamp, parameters.end_timestamp + ); let tickers: Option> = match ¶meters.tickers { Some(tickers_str) if !tickers_str.is_empty() => { @@ -63,12 +66,17 @@ pub async fn query( .map(|s| s.trim().to_uppercase()) .collect(); if vec.is_empty() { + debug!("Ticker list was empty after parsing"); None } else { + debug!("Parsed {} tickers: {:?}", vec.len(), vec); Some(vec) } } - _ => None, + _ => { + debug!("No tickers specified, querying all"); + None + } }; match query_equity_bars_parquet_from_s3( @@ -80,6 +88,10 @@ pub async fn query( .await { Ok(parquet_data) => { + info!( + "Query successful, returning {} bytes of parquet data", + parquet_data.len() + ); let mut response = Response::new(Body::from(parquet_data)); response.headers_mut().insert( header::CONTENT_TYPE, @@ -95,7 +107,7 @@ pub async fn query( response } Err(err) => { - info!("Failed to query S3 data: {}", err); + warn!("Failed to query S3 data: {}", err); ( StatusCode::INTERNAL_SERVER_ERROR, format!("Query failed: {}", err), @@ -117,17 +129,21 @@ pub async fn sync( ); info!("url: {}", url); + info!("Sending request to Massive API"); let response = match state .http_client - .get(url) + .get(&url) .header("accept", "application/json") .query(&[("adjusted", "true"), ("apiKey", state.massive.key.as_str())]) .send() .await { - Ok(resp) => resp, + Ok(resp) => { + info!("Received response from Massive API, status: {}", resp.status()); + resp + } Err(err) => { - info!("Failed to send request: {}", err); + warn!("Failed to send request to Massive API: {}", err); return ( StatusCode::INTERNAL_SERVER_ERROR, "Failed to send API request", @@ -138,9 +154,12 @@ pub async fn sync( let text_content = match response.error_for_status() { Ok(response) => match response.text().await { - Ok(text) => text, + Ok(text) => { + info!("Received 
response body, length: {} bytes", text.len()); + text + } Err(err) => { - info!("Failed to read response text: {}", err); + warn!("Failed to read response text: {}", err); return ( StatusCode::INTERNAL_SERVER_ERROR, "Failed to read API response", @@ -149,15 +168,21 @@ pub async fn sync( } }, Err(err) => { - info!("API request failed: {}", err); + warn!("API request failed with error status: {}", err); return (StatusCode::INTERNAL_SERVER_ERROR, "API request failed").into_response(); } }; + info!("Parsing JSON response"); let json_content: serde_json::Value = match serde_json::from_str(&text_content) { - Ok(value) => value, + Ok(value) => { + debug!("JSON parsed successfully"); + value + } Err(err) => { - info!("Failed to parse JSON response: {}", err); + warn!("Failed to parse JSON response: {}", err); + let truncated: String = text_content.chars().take(500).collect(); + warn!("Raw response (first 500 chars): {}", truncated); return ( StatusCode::INTERNAL_SERVER_ERROR, "Invalid JSON response from API", @@ -167,9 +192,13 @@ pub async fn sync( }; let results = match json_content.get("results") { - Some(results) => results, + Some(results) => { + info!("Found results field in response"); + results + } None => { - info!("No results field found in response"); + warn!("No results field found in response"); + debug!("Response keys: {:?}", json_content.as_object().map(|o| o.keys().collect::>())); return ( StatusCode::NO_CONTENT, "No market data available for this date", @@ -178,22 +207,29 @@ pub async fn sync( } }; - let bars: Vec = match serde_json::from_value(results.clone()) { - Ok(bars) => bars, + info!("Parsing results into BarResult structs"); + let bars: Vec = match serde_json::from_value::>(results.clone()) { + Ok(bars) => { + info!("Successfully parsed {} bar results", bars.len()); + bars + } Err(err) => { - info!("Failed to parse results into BarResult structs: {}", err); + warn!("Failed to parse results into BarResult structs: {}", err); + warn!("Results type: {:?}", results.as_array().map(|a| a.len())); + if let Some(first_result) = results.as_array().and_then(|a| a.first()) { + warn!("First result sample: {}", first_result); + } return (StatusCode::BAD_GATEWAY, text_content).into_response(); } }; let tickers: Vec = bars.iter().map(|b| b.ticker.clone()).collect(); - let volumes: Vec> = bars.iter().map(|b| b.v).collect(); - let volume_weighted_average_prices: Vec> = - bars.iter().map(|b| b.vw.map(|vw| vw as f64)).collect(); - let open_prices: Vec> = bars.iter().map(|b| b.o.map(|o| o as f64)).collect(); - let close_prices: Vec> = bars.iter().map(|b| b.c.map(|c| c as f64)).collect(); - let high_prices: Vec> = bars.iter().map(|b| b.h.map(|h| h as f64)).collect(); - let low_prices: Vec> = bars.iter().map(|b| b.l.map(|l| l as f64)).collect(); + let volumes: Vec> = bars.iter().map(|b| b.v).collect(); + let volume_weighted_average_prices: Vec> = bars.iter().map(|b| b.vw).collect(); + let open_prices: Vec> = bars.iter().map(|b| b.o).collect(); + let close_prices: Vec> = bars.iter().map(|b| b.c).collect(); + let high_prices: Vec> = bars.iter().map(|b| b.h).collect(); + let low_prices: Vec> = bars.iter().map(|b| b.l).collect(); let timestamps: Vec = bars.iter().map(|b| b.t as i64).collect(); let transactions: Vec> = bars.iter().map(|b| b.n).collect(); @@ -209,6 +245,7 @@ pub async fn sync( "transactions" => transactions, }; + info!("Creating DataFrame from bar data"); match bars_data { Ok(data) => { info!( @@ -218,6 +255,7 @@ pub async fn sync( ); debug!("DataFrame schema: {:?}", data.schema()); + 
info!("Uploading DataFrame to S3"); match write_equity_bars_dataframe_to_s3(&state, &data, &payload.date).await { Ok(s3_key) => { info!("Successfully uploaded DataFrame to S3 at key: {}", s3_key); @@ -229,7 +267,7 @@ pub async fn sync( (StatusCode::OK, response_message).into_response() } Err(err) => { - info!("Failed to upload to S3: {}", err); + warn!("Failed to upload to S3: {}", err); let json_output = data.to_string(); ( StatusCode::BAD_GATEWAY, @@ -243,7 +281,7 @@ pub async fn sync( } } Err(err) => { - info!("Failed to create DataFrame: {}", err); + warn!("Failed to create DataFrame: {}", err); (StatusCode::INTERNAL_SERVER_ERROR, text_content).into_response() } } diff --git a/applications/datamanager/src/health.rs b/applications/datamanager/src/health.rs index c08692239..a69eef970 100644 --- a/applications/datamanager/src/health.rs +++ b/applications/datamanager/src/health.rs @@ -1,5 +1,7 @@ use axum::{http::StatusCode, response::IntoResponse}; +use tracing::debug; pub async fn get_health() -> impl IntoResponse { + debug!("Health check endpoint called"); (StatusCode::OK).into_response() } diff --git a/applications/datamanager/src/main.rs b/applications/datamanager/src/main.rs index 3dc79a42c..46897cbd4 100644 --- a/applications/datamanager/src/main.rs +++ b/applications/datamanager/src/main.rs @@ -17,11 +17,13 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example=debug,tower_http=debug,axum=debug".into()), + .unwrap_or_else(|_| "datamanager=debug,tower_http=debug,axum=debug".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); + tracing::info!("Starting datamanager service"); + let app = create_app().await; let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await.unwrap(); diff --git a/applications/datamanager/src/state.rs b/applications/datamanager/src/state.rs index 25da37e71..779a5f7f7 100644 --- a/applications/datamanager/src/state.rs +++ b/applications/datamanager/src/state.rs @@ -1,5 +1,6 @@ use aws_sdk_s3::Client as S3Client; use reqwest::Client as HTTPClient; +use tracing::{debug, info}; #[derive(Clone)] pub struct MassiveSecrets { @@ -17,23 +18,44 @@ pub struct State { impl State { pub async fn from_env() -> Self { + info!("Initializing application state from environment"); + + debug!("Creating HTTP client with 10s timeout"); let http_client = HTTPClient::builder() .timeout(std::time::Duration::from_secs(10)) .build() .unwrap(); + debug!("Loading AWS configuration"); let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await; + + let region = config + .region() + .map(|r| r.as_ref().to_string()) + .unwrap_or_else(|| "not configured".to_string()); + info!("AWS region: {}", region); + let s3_client = S3Client::new(&config); - let bucket_name = - std::env::var("AWS_S3_DATA_BUCKET_NAME").unwrap_or("pocketsizefund-data".to_string()); + + let bucket_name = std::env::var("AWS_S3_DATA_BUCKET_NAME") + .expect("AWS_S3_DATA_BUCKET_NAME must be set in environment"); + info!("Using S3 bucket from environment: {}", bucket_name); + + let massive_base = std::env::var("MASSIVE_BASE_URL") + .expect("MASSIVE_BASE_URL must be set in environment"); + info!("Using Massive API base URL from environment: {}", massive_base); + + let massive_key = std::env::var("MASSIVE_API_KEY") + .expect("MASSIVE_API_KEY must be set in environment"); + debug!("MASSIVE_API_KEY loaded (length: {} chars)", massive_key.len()); + + info!("Application state initialized successfully"); Self { 
http_client, massive: MassiveSecrets { - base: std::env::var("MASSIVE_BASE_URL") - .unwrap_or("https://api.massive.io".to_string()), - key: std::env::var("MASSIVE_API_KEY") - .expect("MASSIVE_API_KEY must be set in environment"), + base: massive_base, + key: massive_key, }, s3_client, bucket_name, diff --git a/applications/datamanager/src/storage.rs b/applications/datamanager/src/storage.rs index aa62aba9d..c2b873d64 100644 --- a/applications/datamanager/src/storage.rs +++ b/applications/datamanager/src/storage.rs @@ -11,7 +11,10 @@ use duckdb::Connection; use polars::prelude::*; use serde::Deserialize; use std::io::Cursor; -use tracing::{debug, info}; +use tracing::{debug, info, warn}; + +const MIN_DATE_INT: i32 = 0; +const MAX_DATE_INT: i32 = 99999999; pub async fn write_equity_bars_dataframe_to_s3( state: &State, @@ -95,19 +98,33 @@ async fn write_dataframe_to_s3( } async fn create_duckdb_connection() -> Result { + debug!("Opening in-memory DuckDB connection"); let connection = Connection::open_in_memory()?; + debug!("Installing and loading httpfs extension"); connection.execute_batch("INSTALL httpfs; LOAD httpfs;")?; + debug!("Loading AWS configuration for DuckDB S3 access"); let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await; - let provider = config - .credentials_provider() - .ok_or_else(|| Error::Other("No AWS credentials provider found".into()))?; + let provider = config.credentials_provider().ok_or_else(|| { + warn!("No AWS credentials provider found"); + Error::Other("No AWS credentials provider found".into()) + })?; + + debug!("Fetching AWS credentials"); let credentials = provider.provide_credentials().await?; + let region = config .region() .map(|r| r.as_ref().to_string()) - .unwrap_or_else(|| "us-east-1".to_string()); + .ok_or_else(|| Error::Other("AWS region must be configured".to_string()))?; + + let has_session_token = credentials.session_token().is_some(); + debug!( + "AWS credentials loaded: region={}, has_session_token={}", + region, has_session_token + ); + let session_token = credentials.session_token().unwrap_or_default(); let s3_config = format!( " @@ -123,8 +140,10 @@ async fn create_duckdb_connection() -> Result { session_token ); + debug!("Configuring DuckDB S3 settings"); connection.execute_batch(&s3_config)?; + info!("DuckDB connection established with S3 access"); Ok(connection) } @@ -141,65 +160,58 @@ pub async fn query_equity_bars_parquet_from_s3( _ => { let end_date = chrono::Utc::now(); let start_date = end_date - chrono::Duration::days(7); + info!( + "No date range specified, using default: {} to {}", + start_date, end_date + ); (start_date, end_date) } }; info!( - "Querying data from {} to {}", - start_timestamp, end_timestamp + "Querying equity bars from {} to {}, bucket: {}", + start_timestamp, end_timestamp, state.bucket_name ); - let mut s3_paths = Vec::new(); - let mut current_timestamp = start_timestamp; - - while current_timestamp <= end_timestamp { - let year = current_timestamp.format("%Y"); - let month = current_timestamp.format("%m"); - let day = current_timestamp.format("%d"); - - let s3_path = format!( - "s3://{}/equity/bars/daily/year={}/month={}/day={}/data.parquet", - state.bucket_name, year, month, day - ); - s3_paths.push(s3_path); - - current_timestamp += chrono::Duration::days(1); - } + let s3_glob = format!( + "s3://{}/equity/bars/daily/**/*.parquet", + state.bucket_name + ); - if s3_paths.is_empty() { - return Err(Error::Other( - "No files to query for the given date range".to_string(), - )); - } + 
info!("Using S3 glob pattern: {}", s3_glob); - info!("Querying {} S3 files", s3_paths.len()); + let start_date_int = start_timestamp.format("%Y%m%d").to_string().parse::().unwrap_or(MIN_DATE_INT); + let end_date_int = end_timestamp.format("%Y%m%d").to_string().parse::().unwrap_or(MAX_DATE_INT); - let s3_paths_str = s3_paths - .iter() - .map(|path| format!("SELECT * FROM '{}'", path)) - .collect::>() - .join(" UNION ALL "); + debug!( + "Date range filter: {} to {} (as integers)", + start_date_int, end_date_int + ); let ticker_filter = match &tickers { Some(ticker_list) if !ticker_list.is_empty() => { - // Validate ticker format to prevent SQL injection + debug!("Validating {} tickers for query filter", ticker_list.len()); for ticker in ticker_list { if !ticker .chars() .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-') { + warn!("Invalid ticker format rejected: {}", ticker); return Err(Error::Other(format!("Invalid ticker format: {}", ticker))); } } + debug!("Ticker validation passed: {:?}", ticker_list); let ticker_values = ticker_list .iter() .map(|t| format!("'{}'", t.replace('\'', "''"))) .collect::>() .join(", "); - format!("WHERE ticker IN ({})", ticker_values) + format!("AND ticker IN ({})", ticker_values) + } + _ => { + debug!("No ticker filter applied, querying all tickers"); + String::new() } - _ => String::new(), }; let query_sql = format!( @@ -214,16 +226,20 @@ pub async fn query_equity_bars_parquet_from_s3( volume, volume_weighted_average_price, transactions - FROM ({}) + FROM read_parquet('{}', hive_partitioning=1) + WHERE (year::int * 10000 + month::int * 100 + day::int) BETWEEN {} AND {} {} ORDER BY timestamp, ticker ", - s3_paths_str, ticker_filter + s3_glob, start_date_int, end_date_int, ticker_filter ); debug!("Executing query SQL: {}", query_sql); + info!("Preparing DuckDB statement"); let mut statement = connection.prepare(&query_sql)?; + + info!("Executing query and mapping results"); let equity_bars: Vec = statement .query_map([], |row| { Ok(EquityBar { @@ -239,8 +255,21 @@ pub async fn query_equity_bars_parquet_from_s3( }) })? 
.collect::, _>>() - .map_err(|e| Error::Other(format!("Failed to map query results: {}", e)))?; + .map_err(|e| { + warn!("Failed to map query results: {}", e); + Error::Other(format!("Failed to map query results: {}", e)) + })?; + info!("Query returned {} equity bar records", equity_bars.len()); + + if equity_bars.is_empty() { + warn!( + "No equity bar data found for date range {} to {}", + start_timestamp, end_timestamp + ); + } + + debug!("Creating DataFrame from equity bars"); let equity_bars_dataframe = create_equity_bar_dataframe(equity_bars); let mut buffer = Vec::new(); @@ -267,6 +296,10 @@ pub async fn query_predictions_dataframe_from_s3( state: &State, predictions_query: Vec, ) -> Result { + info!( + "Querying predictions for {} ticker/timestamp pairs", + predictions_query.len() + ); let connection = create_duckdb_connection().await?; let mut s3_paths = Vec::new(); @@ -282,16 +315,22 @@ pub async fn query_predictions_dataframe_from_s3( state.bucket_name, year, month, day ); + debug!( + "Adding S3 path for ticker {} at {}/{}/{}: {}", + prediction_query.ticker, year, month, day, s3_path + ); + s3_paths.push(s3_path); tickers.push(prediction_query.ticker.clone()); } if s3_paths.is_empty() { + warn!("No prediction query positions provided"); return Err(Error::Other("No positions provided".into())); } - info!("Querying {} S3 files", s3_paths.len()); + info!("Querying {} S3 files for tickers: {:?}", s3_paths.len(), tickers); let s3_paths_query = s3_paths .iter() @@ -322,8 +361,10 @@ pub async fn query_predictions_dataframe_from_s3( debug!("Executing export SQL: {}", query); + info!("Preparing predictions query statement"); let mut statement = connection.prepare(&query)?; + info!("Executing predictions query and mapping results"); let predictions: Vec = statement .query_map([], |row| { Ok(Prediction { @@ -335,10 +376,21 @@ pub async fn query_predictions_dataframe_from_s3( }) })? .collect::, _>>() - .map_err(|e| Error::Other(format!("Failed to map query results: {}", e)))?; + .map_err(|e| { + warn!("Failed to map predictions query results: {}", e); + Error::Other(format!("Failed to map query results: {}", e)) + })?; + + info!("Query returned {} prediction records", predictions.len()); + debug!("Creating predictions DataFrame"); let predictions_dataframe = create_predictions_dataframe(predictions)?; + info!( + "Predictions DataFrame created with {} rows", + predictions_dataframe.height() + ); + Ok(predictions_dataframe) } @@ -346,6 +398,10 @@ pub async fn query_portfolio_dataframe_from_s3( state: &State, timestamp: Option>, ) -> Result { + info!( + "Querying portfolio data, timestamp filter: {:?}", + timestamp.map(|ts| ts.to_string()) + ); let connection = create_duckdb_connection().await?; let query = match timestamp { @@ -381,7 +437,10 @@ pub async fn query_portfolio_dataframe_from_s3( "s3://{}/equity/portfolios/daily/**/*.parquet", state.bucket_name ); - info!("Querying most recent portfolio from all files"); + info!( + "Querying most recent portfolio using hive partitioning: {}", + s3_wildcard + ); format!( " @@ -417,8 +476,10 @@ pub async fn query_portfolio_dataframe_from_s3( debug!("Executing query SQL: {}", query); + info!("Preparing portfolio query statement"); let mut statement = connection.prepare(&query)?; + info!("Executing portfolio query and mapping results"); let portfolios: Vec = statement .query_map([], |row| { Ok(Portfolio { @@ -430,10 +491,21 @@ pub async fn query_portfolio_dataframe_from_s3( }) })? 
.collect::, _>>() - .map_err(|e| Error::Other(format!("Failed to map query results: {}", e)))?; + .map_err(|e| { + warn!("Failed to map portfolio query results: {}", e); + Error::Other(format!("Failed to map query results: {}", e)) + })?; + + info!("Query returned {} portfolio records", portfolios.len()); + debug!("Creating portfolio DataFrame"); let portfolio_dataframe = create_portfolio_dataframe(portfolios)?; + info!( + "Portfolio DataFrame created with {} rows", + portfolio_dataframe.height() + ); + Ok(portfolio_dataframe) } diff --git a/applications/equitypricemodel/src/equitypricemodel/preprocess.py b/applications/equitypricemodel/src/equitypricemodel/preprocess.py index 373b5aecb..01d5c1386 100644 --- a/applications/equitypricemodel/src/equitypricemodel/preprocess.py +++ b/applications/equitypricemodel/src/equitypricemodel/preprocess.py @@ -6,9 +6,7 @@ def filter_equity_bars( minimum_average_close_price: float = 10.0, minimum_average_volume: float = 1_000_000.0, ) -> pl.DataFrame: - data = data.clone() - - return ( + valid_tickers = ( data.group_by("ticker") .agg( average_close_price=pl.col("close_price").mean(), @@ -18,5 +16,7 @@ def filter_equity_bars( (pl.col("average_close_price") > minimum_average_close_price) & (pl.col("average_volume") > minimum_average_volume) ) - .drop(["average_close_price", "average_volume"]) + .select("ticker") ) + + return data.join(valid_tickers, on="ticker", how="semi") diff --git a/infrastructure/Pulumi.production.yaml b/infrastructure/Pulumi.production.yaml index b0df5cac6..d4a15c0e4 100644 --- a/infrastructure/Pulumi.production.yaml +++ b/infrastructure/Pulumi.production.yaml @@ -1,3 +1,3 @@ config: aws:region: - secure: AAABALPeEekY8m4V3bzTX5idnUTmjjjRWfJQit7uhk+w4mxTWyDK5r0= + secure: AAABACXJ+P/p3xcsiwJR6jbVwG3oilgmRIio62xIEJHyIzdid94Sf6M= diff --git a/infrastructure/__main__.py b/infrastructure/__main__.py index 8a70454a3..4ee8e9135 100644 --- a/infrastructure/__main__.py +++ b/infrastructure/__main__.py @@ -7,7 +7,7 @@ account_id = current_identity.account_id -region = aws.get_region().name +region = aws.get_region().region secret = aws.secretsmanager.get_secret( name="pocketsizefund/production/environment_variables", @@ -16,26 +16,95 @@ availability_zone_a = f"{region}a" availability_zone_b = f"{region}b" -datamanager_image = aws.ecr.get_image( - repository_name="pocketsizefund/datamanager-server", - image_tag="latest", +tags = { + "project": "pocketsizefund", + "stack": pulumi.get_stack(), + "manager": "pulumi", +} + +# S3 Data Bucket for storing equity bars, predictions, portfolios +data_bucket = aws.s3.BucketV2( + "data_bucket", + bucket_prefix="pocketsizefund-data-", + tags=tags, + opts=pulumi.ResourceOptions(protect=True), ) -portfoliomanager_image = aws.ecr.get_image( - repository_name="pocketsizefund/portfoliomanager-server", - image_tag="latest", +aws.s3.BucketVersioningV2( + "data_bucket_versioning", + bucket=data_bucket.id, + versioning_configuration=aws.s3.BucketVersioningV2VersioningConfigurationArgs( + status="Enabled", + ), ) -equitypricemodel_image = aws.ecr.get_image( - repository_name="pocketsizefund/equitypricemodel-server", - image_tag="latest", +# S3 Model Artifacts Bucket for storing trained model weights and checkpoints +model_artifacts_bucket = aws.s3.BucketV2( + "model_artifacts_bucket", + bucket_prefix="pocketsizefund-model-artifacts-", + tags=tags, + opts=pulumi.ResourceOptions(protect=True), ) -tags = { - "project": "pocketsizefund", - "stack": pulumi.get_stack(), - "manager": "pulumi", -} 
+aws.s3.BucketVersioningV2( + "model_artifacts_bucket_versioning", + bucket=model_artifacts_bucket.id, + versioning_configuration=aws.s3.BucketVersioningV2VersioningConfigurationArgs( + status="Enabled", + ), +) + +# ECR Repositories - these must exist before images can be pushed +datamanager_repository = aws.ecr.Repository( + "datamanager_repository", + name="pocketsizefund/datamanager-server", + image_tag_mutability="MUTABLE", + image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( + scan_on_push=True, + ), + tags=tags, + opts=pulumi.ResourceOptions(protect=True), +) + +portfoliomanager_repository = aws.ecr.Repository( + "portfoliomanager_repository", + name="pocketsizefund/portfoliomanager-server", + image_tag_mutability="MUTABLE", + image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( + scan_on_push=True, + ), + tags=tags, + opts=pulumi.ResourceOptions(protect=True), +) + +equitypricemodel_repository = aws.ecr.Repository( + "equitypricemodel_repository", + name="pocketsizefund/equitypricemodel-server", + image_tag_mutability="MUTABLE", + image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( + scan_on_push=True, + ), + tags=tags, + opts=pulumi.ResourceOptions(protect=True), +) + +equitypricemodel_trainer_repository = aws.ecr.Repository( + "equitypricemodel_trainer_repository", + name="pocketsizefund/equitypricemodel-trainer", + image_tag_mutability="MUTABLE", + image_scanning_configuration=aws.ecr.RepositoryImageScanningConfigurationArgs( + scan_on_push=True, + ), + tags=tags, + opts=pulumi.ResourceOptions(protect=True), +) + +# Generate image URIs - these will be used in task definitions +# For initial deployment, use a placeholder that will be updated when images are pushed +datamanager_image_uri = datamanager_repository.repository_url.apply(lambda url: f"{url}:latest") +portfoliomanager_image_uri = portfoliomanager_repository.repository_url.apply(lambda url: f"{url}:latest") +equitypricemodel_image_uri = equitypricemodel_repository.repository_url.apply(lambda url: f"{url}:latest") +equitypricemodel_trainer_image_uri = equitypricemodel_trainer_repository.repository_url.apply(lambda url: f"{url}:latest") vpc = aws.ec2.Vpc( "vpc", @@ -520,17 +589,12 @@ tags=tags, ) -secret_version = aws.secretsmanager.get_secret_version(secret_id=secret.id) -data_bucket_name = pulumi.Output.secret(secret_version.secret_string).apply( - lambda s: json.loads(s)["AWS_S3_DATA_BUCKET_NAME"] -) - aws.iam.RolePolicy( "task_role_s3_policy", name="pocketsizefund-ecs-task-role-s3-policy", role=task_role.id, - policy=data_bucket_name.apply( - lambda name: json.dumps( + policy=pulumi.Output.all(data_bucket.arn, model_artifacts_bucket.arn).apply( + lambda args: json.dumps( { "Version": "2012-10-17", "Statement": [ @@ -538,8 +602,10 @@ "Effect": "Allow", "Action": ["s3:GetObject", "s3:PutObject", "s3:ListBucket"], "Resource": [ - f"arn:aws:s3:::{name}", - f"arn:aws:s3:::{name}/*", + args[0], + f"{args[0]}/*", + args[1], + f"{args[1]}/*", ], } ], @@ -548,6 +614,107 @@ ), ) +# SageMaker Execution Role for training jobs +sagemaker_execution_role = aws.iam.Role( + "sagemaker_execution_role", + name="pocketsizefund-sagemaker-execution-role", + assume_role_policy=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Action": "sts:AssumeRole", + "Effect": "Allow", + "Principal": {"Service": "sagemaker.amazonaws.com"}, + } + ], + } + ), + tags=tags, +) + +aws.iam.RolePolicy( + "sagemaker_s3_policy", + 
name="pocketsizefund-sagemaker-s3-policy", + role=sagemaker_execution_role.id, + policy=pulumi.Output.all(data_bucket.arn, model_artifacts_bucket.arn).apply( + lambda args: json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket", + ], + "Resource": [ + args[0], + f"{args[0]}/*", + args[1], + f"{args[1]}/*", + ], + } + ], + } + ) + ), +) + +aws.iam.RolePolicy( + "sagemaker_ecr_policy", + name="pocketsizefund-sagemaker-ecr-policy", + role=sagemaker_execution_role.id, + policy=pulumi.Output.all(equitypricemodel_trainer_repository.arn).apply( + lambda args: json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "ecr:GetDownloadUrlForLayer", + "ecr:BatchGetImage", + "ecr:BatchCheckLayerAvailability", + ], + "Resource": args[0], + }, + { + "Effect": "Allow", + "Action": "ecr:GetAuthorizationToken", + "Resource": "*", + }, + ], + } + ) + ), +) + +aws.iam.RolePolicy( + "sagemaker_cloudwatch_policy", + name="pocketsizefund-sagemaker-cloudwatch-policy", + role=sagemaker_execution_role.id, + policy=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "logs:CreateLogGroup", + "logs:CreateLogStream", + "logs:PutLogEvents", + "logs:DescribeLogStreams", + ], + "Resource": "arn:aws:logs:*:*:log-group:/aws/sagemaker/*", + } + ], + } + ), +) + datamanager_log_group = aws.cloudwatch.LogGroup( "datamanager_logs", name="/ecs/pocketsizefund/datamanager", @@ -580,8 +747,9 @@ task_role_arn=task_role.arn, container_definitions=pulumi.Output.all( datamanager_log_group.name, - datamanager_image.image_uri, + datamanager_image_uri, secret.arn, + data_bucket.bucket, ).apply( lambda args: json.dumps( [ @@ -589,11 +757,17 @@ "name": "datamanager", "image": args[1], "portMappings": [{"containerPort": 8080, "protocol": "tcp"}], - "secrets": [ + "environment": [ + { + "name": "MASSIVE_BASE_URL", + "value": "https://api.massive.io", + }, { "name": "AWS_S3_DATA_BUCKET_NAME", - "valueFrom": f"{args[2]}:AWS_S3_DATA_BUCKET_NAME::", + "value": args[3], }, + ], + "secrets": [ { "name": "MASSIVE_API_KEY", "valueFrom": f"{args[2]}:MASSIVE_API_KEY::", @@ -626,7 +800,7 @@ container_definitions=pulumi.Output.all( portfoliomanager_log_group.name, service_discovery_namespace.name, - portfoliomanager_image.image_uri, + portfoliomanager_image_uri, secret.arn, ).apply( lambda args: json.dumps( @@ -686,7 +860,7 @@ container_definitions=pulumi.Output.all( equitypricemodel_log_group.name, service_discovery_namespace.name, - equitypricemodel_image.image_uri, + equitypricemodel_image_uri, ).apply( lambda args: json.dumps( [ @@ -843,8 +1017,16 @@ pulumi.export("aws_alb_dns_name", alb.dns_name) pulumi.export("aws_alb_url", pulumi.Output.concat(protocol, alb.dns_name)) pulumi.export("aws_service_discovery_namespace", service_discovery_namespace.name) -pulumi.export("aws_ecr_datamanager_image", datamanager_image.image_uri) -pulumi.export("aws_ecr_portfoliomanager_image", portfoliomanager_image.image_uri) -pulumi.export("aws_ecr_equitypricemodel_image", equitypricemodel_image.image_uri) +pulumi.export("aws_ecr_datamanager_image", datamanager_image_uri) +pulumi.export("aws_ecr_portfoliomanager_image", portfoliomanager_image_uri) +pulumi.export("aws_ecr_equitypricemodel_image", equitypricemodel_image_uri) +pulumi.export("aws_ecr_datamanager_repository", datamanager_repository.repository_url) +pulumi.export("aws_ecr_portfoliomanager_repository", 
portfoliomanager_repository.repository_url) +pulumi.export("aws_ecr_equitypricemodel_repository", equitypricemodel_repository.repository_url) +pulumi.export("aws_s3_data_bucket", data_bucket.bucket) +pulumi.export("aws_s3_model_artifacts_bucket", model_artifacts_bucket.bucket) +pulumi.export("aws_ecr_equitypricemodel_trainer_repository", equitypricemodel_trainer_repository.repository_url) +pulumi.export("aws_ecr_equitypricemodel_trainer_image", equitypricemodel_trainer_image_uri) +pulumi.export("aws_iam_sagemaker_role_arn", sagemaker_execution_role.arn) pulumi.export("psf_base_url", psf_base_url) pulumi.export("readme", pulumi.Output.format(readme_content, psf_base_url)) diff --git a/maskfile.md b/maskfile.md index 6def9688a..21d57940e 100644 --- a/maskfile.md +++ b/maskfile.md @@ -497,10 +497,58 @@ mask development python test echo "Python development checks completed successfully" ``` +## data + +> Data management commands + +### sync-categories + +> Sync equity categories (sector/industry) from Polygon API to S3 + +```bash +set -euo pipefail + +echo "Syncing equity categories from Polygon API" + +cd infrastructure +export AWS_S3_DATA_BUCKET="$(pulumi stack output aws_s3_data_bucket)" + +cd ../ + +# Get API key from AWS Secrets Manager +export MASSIVE_API_KEY=$(aws secretsmanager get-secret-value \ + --secret-id pocketsizefund/production/environment_variables \ + --query 'SecretString' \ + --output text | jq -r '.MASSIVE_API_KEY') + +uv run python tools/sync_equity_categories.py + +echo "Categories sync complete" +``` + ## models > Model management commands +### prepare (application_name) + +> Prepare training data by consolidating equity bars with categories + +```bash +set -euo pipefail + +export APPLICATION_NAME="${application_name}" + +cd infrastructure +export AWS_S3_DATA_BUCKET="$(pulumi stack output aws_s3_data_bucket)" +export AWS_S3_MODEL_ARTIFACTS_BUCKET="$(pulumi stack output aws_s3_model_artifacts_bucket)" +export LOOKBACK_DAYS="${LOOKBACK_DAYS:-365}" + +cd ../ + +uv run python tools/prepare_training_data.py +``` + ### train (application_name) > Train machine learning model @@ -510,6 +558,15 @@ set -euo pipefail export APPLICATION_NAME="${application_name}" +cd infrastructure +export AWS_ECR_EQUITY_PRICE_MODEL_TRAINER_IMAGE_ARN="$(pulumi stack output aws_ecr_equitypricemodel_trainer_image)" +export AWS_IAM_SAGEMAKER_ROLE_ARN="$(pulumi stack output aws_iam_sagemaker_role_arn)" +export AWS_S3_MODEL_ARTIFACTS_BUCKET="$(pulumi stack output aws_s3_model_artifacts_bucket)" +export AWS_S3_EQUITY_PRICE_MODEL_ARTIFACT_OUTPUT_PATH="s3://${AWS_S3_MODEL_ARTIFACTS_BUCKET}/artifacts" +export AWS_S3_EQUITY_PRICE_MODEL_TRAINING_DATA_PATH="s3://${AWS_S3_MODEL_ARTIFACTS_BUCKET}/training" + +cd ../ + uv run python tools/run_training_job.py ``` diff --git a/tools/prepare_training_data.py b/tools/prepare_training_data.py new file mode 100644 index 000000000..7b0b10e04 --- /dev/null +++ b/tools/prepare_training_data.py @@ -0,0 +1,260 @@ +"""Prepare consolidated training data from equity bars and categories. + +This script: +1. Reads equity bars from S3 (partitioned parquet) +2. Reads categories CSV from S3 +3. Joins them on ticker +4. Filters by minimum price/volume thresholds +5. 
Outputs consolidated parquet to S3 for SageMaker training +""" + +import io +import os +import sys +from datetime import UTC, datetime, timedelta + +import boto3 +import polars as pl +import structlog + +logger = structlog.get_logger() + +MINIMUM_CLOSE_PRICE = 1.0 +MINIMUM_VOLUME = 100_000 + + +def read_equity_bars_from_s3( + s3_client: boto3.client, + bucket_name: str, + start_date: datetime, + end_date: datetime, +) -> pl.DataFrame: + """Read equity bars parquet files from S3 for date range.""" + logger.info( + "Reading equity bars from S3", + bucket=bucket_name, + start_date=start_date.strftime("%Y-%m-%d"), + end_date=end_date.strftime("%Y-%m-%d"), + ) + + all_dataframes = [] + current_date = start_date + + while current_date <= end_date: + year = current_date.strftime("%Y") + month = current_date.strftime("%m") + day = current_date.strftime("%d") + + key = f"equity/bars/daily/year={year}/month={month}/day={day}/data.parquet" + + try: + response = s3_client.get_object(Bucket=bucket_name, Key=key) + parquet_bytes = response["Body"].read() + dataframe = pl.read_parquet(parquet_bytes) + all_dataframes.append(dataframe) + logger.debug("Read parquet file", key=key, rows=dataframe.height) + except s3_client.exceptions.NoSuchKey: + logger.debug("No data for date", date=current_date.strftime("%Y-%m-%d")) + except Exception as e: + logger.warning( + "Failed to read parquet file", key=key, error=str(e) + ) + + current_date += timedelta(days=1) + + if not all_dataframes: + message = "No equity bars data found for date range" + raise ValueError(message) + + combined = pl.concat(all_dataframes) + logger.info("Combined equity bars", total_rows=combined.height) + + return combined + + +def read_categories_from_s3( + s3_client: boto3.client, + bucket_name: str, +) -> pl.DataFrame: + """Read categories CSV from S3.""" + key = "equity/details/categories.csv" + + logger.info("Reading categories from S3", bucket=bucket_name, key=key) + + response = s3_client.get_object(Bucket=bucket_name, Key=key) + csv_bytes = response["Body"].read() + categories = pl.read_csv(csv_bytes) + + logger.info("Read categories", rows=categories.height) + + return categories + + +def filter_equity_bars( + data: pl.DataFrame, + minimum_close_price: float = MINIMUM_CLOSE_PRICE, + minimum_volume: int = MINIMUM_VOLUME, +) -> pl.DataFrame: + """Filter equity bars by minimum price and volume thresholds.""" + logger.info( + "Filtering equity bars", + minimum_close_price=minimum_close_price, + minimum_volume=minimum_volume, + input_rows=data.height, + ) + + filtered = data.filter( + (pl.col("close_price") >= minimum_close_price) + & (pl.col("volume") >= minimum_volume) + ) + + logger.info("Filtered equity bars", output_rows=filtered.height) + + return filtered + + +def consolidate_data( + equity_bars: pl.DataFrame, + categories: pl.DataFrame, +) -> pl.DataFrame: + """Join equity bars with categories on ticker.""" + logger.info( + "Consolidating data", + equity_bars_rows=equity_bars.height, + categories_rows=categories.height, + ) + + consolidated = equity_bars.join(categories, on="ticker", how="inner") + + retained_columns = [ + "ticker", + "timestamp", + "open_price", + "high_price", + "low_price", + "close_price", + "volume", + "volume_weighted_average_price", + "sector", + "industry", + ] + + available_columns = [col for col in retained_columns if col in consolidated.columns] + missing_columns = [col for col in retained_columns if col not in consolidated.columns] + + if missing_columns: + logger.warning("Missing columns in 
consolidated data", missing=missing_columns) + + result = consolidated.select(available_columns) + + logger.info("Consolidated data", output_rows=result.height, columns=available_columns) + + return result + + +def write_training_data_to_s3( + s3_client: boto3.client, + bucket_name: str, + data: pl.DataFrame, + output_key: str, +) -> str: + """Write consolidated training data to S3 as parquet.""" + logger.info( + "Writing training data to S3", + bucket=bucket_name, + key=output_key, + rows=data.height, + ) + + buffer = io.BytesIO() + data.write_parquet(buffer) + parquet_bytes = buffer.getvalue() + + s3_client.put_object( + Bucket=bucket_name, + Key=output_key, + Body=parquet_bytes, + ContentType="application/octet-stream", + ) + + s3_uri = f"s3://{bucket_name}/{output_key}" + logger.info("Wrote training data", s3_uri=s3_uri, size_bytes=len(parquet_bytes)) + + return s3_uri + + +def prepare_training_data( + data_bucket_name: str, + model_artifacts_bucket_name: str, + start_date: datetime, + end_date: datetime, + output_key: str = "training/filtered_tft_training_data.parquet", +) -> str: + """Main function to prepare training data.""" + logger.info( + "Preparing training data", + data_bucket=data_bucket_name, + model_artifacts_bucket=model_artifacts_bucket_name, + start_date=start_date.strftime("%Y-%m-%d"), + end_date=end_date.strftime("%Y-%m-%d"), + ) + + s3_client = boto3.client("s3") + + equity_bars = read_equity_bars_from_s3( + s3_client=s3_client, + bucket_name=data_bucket_name, + start_date=start_date, + end_date=end_date, + ) + + categories = read_categories_from_s3( + s3_client=s3_client, + bucket_name=data_bucket_name, + ) + + filtered_bars = filter_equity_bars(equity_bars) + + consolidated = consolidate_data( + equity_bars=filtered_bars, + categories=categories, + ) + + s3_uri = write_training_data_to_s3( + s3_client=s3_client, + bucket_name=model_artifacts_bucket_name, + data=consolidated, + output_key=output_key, + ) + + return s3_uri + + +if __name__ == "__main__": + data_bucket = os.getenv("AWS_S3_DATA_BUCKET") + model_artifacts_bucket = os.getenv("AWS_S3_MODEL_ARTIFACTS_BUCKET") + lookback_days = int(os.getenv("LOOKBACK_DAYS", "365")) + + if not data_bucket or not model_artifacts_bucket: + logger.error( + "Missing required environment variables", + AWS_S3_DATA_BUCKET=data_bucket, + AWS_S3_MODEL_ARTIFACTS_BUCKET=model_artifacts_bucket, + ) + sys.exit(1) + + end_date = datetime.now(tz=UTC).replace(hour=0, minute=0, second=0, microsecond=0) + start_date = end_date - timedelta(days=lookback_days) + + try: + output_uri = prepare_training_data( + data_bucket_name=data_bucket, + model_artifacts_bucket_name=model_artifacts_bucket, + start_date=start_date, + end_date=end_date, + ) + logger.info("Training data preparation complete", output_uri=output_uri) + + except Exception as e: + logger.exception("Failed to prepare training data", error=str(e)) + sys.exit(1) diff --git a/tools/sync_equity_categories.py b/tools/sync_equity_categories.py new file mode 100644 index 000000000..3241e5245 --- /dev/null +++ b/tools/sync_equity_categories.py @@ -0,0 +1,162 @@ +"""Sync equity categories (sector/industry) from Polygon API to S3. + +This script fetches ticker reference data from Polygon's API and uploads +a categories CSV to S3 for use in training data preparation. 
+ +The CSV contains: ticker, sector, industry +""" + +import os +import sys +import time + +import boto3 +import polars as pl +import requests +import structlog + +logger = structlog.get_logger() + + +def fetch_all_tickers(api_key: str, base_url: str) -> list[dict]: + """Fetch all US stock tickers from Polygon API with pagination.""" + logger.info("Fetching tickers from Polygon API") + + all_tickers = [] + url = f"{base_url}/v3/reference/tickers" + params = { + "market": "stocks", + "active": "true", + "limit": 1000, + "apiKey": api_key, + } + + while url: + logger.debug("Fetching page", url=url) + + response = requests.get(url, params=params, timeout=30) + response.raise_for_status() + + data = response.json() + results = data.get("results", []) + all_tickers.extend(results) + + logger.info("Fetched tickers", count=len(results), total=len(all_tickers)) + + next_url = data.get("next_url") + if next_url: + url = next_url + params = {"apiKey": api_key} + time.sleep(0.25) + else: + url = None + + logger.info("Finished fetching tickers", total=len(all_tickers)) + return all_tickers + + +def extract_categories(tickers: list[dict]) -> pl.DataFrame: + """Extract ticker, sector, industry from ticker data.""" + logger.info("Extracting categories from ticker data") + + rows = [] + for ticker_data in tickers: + ticker = ticker_data.get("ticker", "") + if ticker_data.get("type") not in ("CS", "ADRC"): + continue + + sector = ticker_data.get("sector", "") + industry = ticker_data.get("industry", "") + + if not sector: + sector = "NOT AVAILABLE" + if not industry: + industry = "NOT AVAILABLE" + + rows.append({ + "ticker": ticker.upper(), + "sector": sector.upper(), + "industry": industry.upper(), + }) + + dataframe = pl.DataFrame(rows) + logger.info("Extracted categories", rows=dataframe.height) + + return dataframe + + +def upload_categories_to_s3( + s3_client: boto3.client, + bucket_name: str, + categories: pl.DataFrame, +) -> str: + """Upload categories CSV to S3.""" + key = "equity/details/categories.csv" + + logger.info( + "Uploading categories to S3", + bucket=bucket_name, + key=key, + rows=categories.height, + ) + + csv_bytes = categories.write_csv().encode("utf-8") + + s3_client.put_object( + Bucket=bucket_name, + Key=key, + Body=csv_bytes, + ContentType="text/csv", + ) + + s3_uri = f"s3://{bucket_name}/{key}" + logger.info("Uploaded categories", s3_uri=s3_uri) + + return s3_uri + + +def sync_equity_categories( + api_key: str, + base_url: str, + bucket_name: str, +) -> str: + """Main function to sync equity categories.""" + logger.info("Syncing equity categories", bucket=bucket_name) + + tickers = fetch_all_tickers(api_key, base_url) + categories = extract_categories(tickers) + + s3_client = boto3.client("s3") + s3_uri = upload_categories_to_s3(s3_client, bucket_name, categories) + + return s3_uri + + +if __name__ == "__main__": + api_key = os.getenv("MASSIVE_API_KEY") + base_url = os.getenv("MASSIVE_BASE_URL") + bucket_name = os.getenv("AWS_S3_DATA_BUCKET") + + if not api_key: + logger.error("MASSIVE_API_KEY environment variable not set") + sys.exit(1) + + if not base_url: + logger.error("MASSIVE_BASE_URL environment variable not set") + sys.exit(1) + + if not bucket_name: + logger.error("AWS_S3_DATA_BUCKET environment variable not set") + sys.exit(1) + + try: + output_uri = sync_equity_categories( + api_key=api_key, + base_url=base_url, + bucket_name=bucket_name, + ) + logger.info("Sync complete", output_uri=output_uri) + + except Exception as e: + logger.exception("Failed to sync equity 
categories", error=str(e)) + sys.exit(1)