diff --git a/Cargo.lock b/Cargo.lock index 4ed7eac45..57bed4a6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1415,6 +1415,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.106", +] + +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.106", +] + [[package]] name = "datamanager" version = "0.1.0" @@ -1431,12 +1466,14 @@ dependencies = [ "reqwest", "serde", "serde_json", + "thiserror", "tokio", "tokio-test", "tower", "tower-http", "tracing", "tracing-subscriber", + "validator", ] [[package]] @@ -2322,6 +2359,22 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "idna" version = "1.1.0" @@ -3620,6 +3673,30 @@ dependencies = [ "toml_edit", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -4623,6 +4700,12 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "strum" version = "0.27.2" @@ -5085,6 +5168,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.19" @@ -5140,7 +5229,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" dependencies = [ "form_urlencoded", - "idna", + "idna 1.1.0", "percent-encoding", "serde", ] @@ -5169,6 +5258,36 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "validator" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db79c75af171630a3148bd3e6d7c4f42b6a9a014c2945bc5ed0020cbb8d9478e" +dependencies = [ + "idna 0.5.0", + "once_cell", + "regex", + "serde", + "serde_derive", + "serde_json", + "url", + "validator_derive", +] + +[[package]] +name = "validator_derive" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0bcf92720c40105ac4b2dda2a4ea3aa717d4d6a862cc217da653a4bd5c6b10" +dependencies = [ + "darling", + "once_cell", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "valuable" version = "0.1.1" diff --git a/applications/datamanager/.claude/settings.local.json b/applications/datamanager/.claude/settings.local.json index 83c7d98f3..868bdffee 100644 --- a/applications/datamanager/.claude/settings.local.json +++ b/applications/datamanager/.claude/settings.local.json @@ -19,4 +19,5 @@ "deny": [], "defaultMode": "acceptEdits" } -} \ No newline at end of file +} + diff --git a/applications/datamanager/Cargo.toml b/applications/datamanager/Cargo.toml index c496d9e3e..4453121a3 100644 --- a/applications/datamanager/Cargo.toml +++ b/applications/datamanager/Cargo.toml @@ -14,7 +14,14 @@ path = "src/main.rs" [dependencies] axum = "0.8.4" chrono = { version = "0.4.41", features = ["serde"] } -polars = { version = "0.50.0", features = ["json", "lazy", "parquet", "temporal"] } +polars = { version = "0.50.0", features = [ + "json", + "lazy", + "parquet", + "temporal", + "serde", + "polars-io", +] } reqwest = "0.12.23" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.143" @@ -27,6 +34,8 @@ aws-config = "1.5.8" aws-sdk-s3 = "1.48.0" aws-credential-types = "1.2.6" duckdb = { version = "1.0", features = ["r2d2", "chrono"] } +validator = { version = "0.18", features = ["derive"] } +thiserror = "2.0.3" [dev-dependencies] tokio-test = "0.4" diff --git a/applications/datamanager/src/lib.rs b/applications/datamanager/src/lib.rs index f2ece13f8..a0586712c 100644 --- a/applications/datamanager/src/lib.rs +++ b/applications/datamanager/src/lib.rs @@ -2,11 +2,11 @@ use aws_sdk_s3::Client as S3Client; use axum::{routing::get, Router}; use reqwest::Client; use tower_http::trace::TraceLayer; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; pub mod routes; use routes::equity; use routes::health; +use routes::prediction; #[derive(Clone)] pub struct AlpacaSecrets { @@ -69,8 +69,8 @@ pub async fn create_app() -> Router { Router::::new() .route("/health", get(health::check)) + .merge(prediction::router()) .merge(equity::router()) .with_state(state) .layer(TraceLayer::new_for_http()) } - diff --git a/applications/datamanager/src/routes/equity.rs b/applications/datamanager/src/routes/equity.rs index 08a0738bf..d78c5a62e 100644 --- a/applications/datamanager/src/routes/equity.rs +++ b/applications/datamanager/src/routes/equity.rs @@ -32,16 +32,14 @@ struct DateRangeQuery { struct BarResult { #[serde(rename = "T")] ticker: String, - // TODO: money types - c: Option, - h: Option, - l: Option, - n: Option, - o: Option, - // otc: bool, - t: i64, - v: Option, - vw: Option, + c: Option, + h: Option, + l: Option, + n: Option, + o: Option, + t: u64, + v: Option, + vw: Option, } #[derive(serde::Deserialize, Debug)] @@ -337,12 +335,12 @@ async fn sync(State(state): State, payload: Json) -> impl I let tickers: Vec = bars.iter().map(|b| b.ticker.clone()).collect(); let volumes: Vec> = bars.iter().map(|b| b.v.map(|v| v as u64)).collect(); - let vw_prices: Vec> = bars.iter().map(|b| b.vw).collect(); - let open_prices: Vec> = bars.iter().map(|b| b.o).collect(); - let close_prices: Vec> = bars.iter().map(|b| b.c).collect(); - let high_prices: Vec> = bars.iter().map(|b| b.h).collect(); - let low_prices: Vec> = bars.iter().map(|b| b.l).collect(); - let timestamps: Vec = bars.iter().map(|b| b.t).collect(); + let vw_prices: Vec> = bars.iter().map(|b| b.vw.map(|vw| vw as f64)).collect(); + let open_prices: Vec> = bars.iter().map(|b| b.o.map(|o| o as f64)).collect(); + let close_prices: Vec> = bars.iter().map(|b| b.c.map(|c| c as f64)).collect(); + let high_prices: Vec> = bars.iter().map(|b| b.h.map(|h| h as f64)).collect(); + let low_prices: Vec> = bars.iter().map(|b| b.l.map(|l| l as f64)).collect(); + let timestamps: Vec = bars.iter().map(|b| b.t as i64).collect(); let num_transactions: Vec> = bars.iter().map(|b| b.n.map(|n| n as u64)).collect(); let df_result = df! { diff --git a/applications/datamanager/src/routes/mod.rs b/applications/datamanager/src/routes/mod.rs index 9dc2d3a56..bf8377c7a 100644 --- a/applications/datamanager/src/routes/mod.rs +++ b/applications/datamanager/src/routes/mod.rs @@ -1,2 +1,3 @@ pub mod equity; pub mod health; +pub mod prediction; diff --git a/applications/datamanager/src/routes/prediction.rs b/applications/datamanager/src/routes/prediction.rs new file mode 100644 index 000000000..15c9fc62a --- /dev/null +++ b/applications/datamanager/src/routes/prediction.rs @@ -0,0 +1,292 @@ +use crate::AppState; +use aws_credential_types::provider::error::CredentialsError; +use aws_credential_types::provider::ProvideCredentials; +use aws_sdk_s3::primitives::ByteStream; +use axum::{ + body::Body, + extract::{Json, State}, + http::{header, StatusCode}, + response::{IntoResponse, Response}, + routing::{get, post}, + Router, +}; +use chrono::{DateTime, Utc}; +use duckdb::{Connection, Error as DuckError}; +use polars::prelude::*; +use serde::Serialize; +use std::io::Cursor; +use thiserror::Error as ThisError; +use tracing::{debug, info}; + +#[derive(ThisError, Debug)] +enum Error { + #[error("DuckDB error: {0}")] + DuckDBError(#[from] DuckError), + #[error("Credentials error: {0}")] + CredentialsError(#[from] CredentialsError), + #[error("Other error: {0}")] + OtherError(String), +} + +#[derive(serde::Deserialize)] +struct SavePredictionsPayload { + data: DataFrame, + timestamp: DateTime, +} + +#[derive(serde::Deserialize)] +struct QueryPredictionsPayload { + positions: Vec, + #[allow(dead_code)] + timestamp: DateTime, +} + +#[derive(serde::Deserialize)] +struct QueryPredictionsPositionPayload { + ticker: String, + timestamp: DateTime, +} + +#[derive(Debug, Serialize)] +struct Prediction { + ticker: String, + timestamp: i64, + quantile_10: f64, + quantile_50: f64, + quantile_90: f64, +} + +async fn save_prediction( + State(state): State, + Json(payload): Json, +) -> impl IntoResponse { + let predictions = payload.data; + + let timestamp = payload.timestamp; + + match upload_dataframe_to_s3(&state, &predictions, ×tamp).await { + Ok(s3_key) => { + info!("Successfully uploaded DataFrame to S3 at key: {}", s3_key); + let response_message = format!( + "DataFrame created with {} rows and uploaded to S3: {}", + predictions.height(), + s3_key + ); + + (StatusCode::OK, response_message) + } + Err(err) => { + info!("Failed to upload to S3: {}", err); + let json_output = predictions.to_string(); + + ( + StatusCode::OK, + format!("S3 upload failed: {}\n\n{}", err, json_output), + ) + } + } +} + +async fn upload_dataframe_to_s3( + state: &AppState, + dataframe: &DataFrame, + date: &DateTime, +) -> Result { + info!("Uploading predictions DataFrame to S3 as parquet"); + + let year = date.format("%Y"); + let month = date.format("%m"); + let day = date.format("%d"); + + let key = format!( + "equity/predictions/daily/year={}/month={}/day={}/data.parquet", + year, month, day, + ); + + let mut buffer = Vec::new(); + { + let cursor = Cursor::new(&mut buffer); + let writer = ParquetWriter::new(cursor); + match writer.finish(&mut dataframe.clone()) { + Ok(_) => { + println!( + "DataFrame successfully converted to parquet, size: {} bytes", + buffer.len() + ); + } + Err(err) => { + return Err(Error::OtherError(format!( + "Failed to write parquet: {}", + err + ))); + } + } + } + + let body = ByteStream::from(buffer); + + match state + .s3_client + .put_object() + .bucket(&state.bucket_name) + .key(&key) + .body(body) + .content_type("application/octet-stream") + .send() + .await + { + Ok(_) => { + info!( + "Successfully uploaded parquet file to s3://{}/{}", + state.bucket_name, key + ); + Ok(key) + } + Err(err) => Err(Error::OtherError(format!( + "Failed to upload to S3: {}", + err + ))), + } +} + +async fn query_prediction( + State(state): State, + Json(payload): Json, +) -> impl IntoResponse { + info!("Fetching equity data from S3 partitioned files"); + + match query_s3_parquet_data(&state, payload.positions).await { + Ok(dataframe) => { + let json_string = dataframe.to_string(); + let mut response = Response::new(Body::from(json_string)); + response + .headers_mut() + .insert(header::CONTENT_TYPE, "application/json".parse().unwrap()); + *response.status_mut() = StatusCode::OK; + response + } + Err(err) => { + info!("Failed to query S3 data: {}", err); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Query failed: {}", err), + ) + .into_response() + } + } +} + +async fn query_s3_parquet_data( + state: &AppState, + positions: Vec, +) -> Result { + let connection = Connection::open_in_memory()?; + + connection.execute_batch("INSTALL httpfs; LOAD httpfs;")?; + + let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await; + let provider = config + .credentials_provider() + .ok_or_else(|| Error::OtherError("No AWS credentials provider found".into()))?; + let credentials = provider.provide_credentials().await?; + let region = config + .region() + .map(|r| r.as_ref().to_string()) + .unwrap_or_else(|| "us-east-1".to_string()); + let session_token = credentials.session_token().unwrap_or_default(); + let s3_config = format!( + " + SET s3_region='{}'; + SET s3_url_style='path'; + SET s3_access_key_id='{}'; + SET s3_secret_access_key='{}'; + SET s3_session_token='{}'; + ", + region, + credentials.access_key_id(), + credentials.secret_access_key(), + session_token + ); + + connection.execute_batch(&s3_config)?; + + let mut s3_paths = Vec::new(); + let mut tickers = Vec::new(); + + for position in positions { + let year = position.timestamp.format("%Y"); + let month = position.timestamp.format("%m"); + let day = position.timestamp.format("%d"); + + let s3_path = format!( + "s3://{}/equity/predictions/daily/year={}/month={}/day={}/data.parquet", + state.bucket_name, year, month, day + ); + + s3_paths.push(s3_path); + + tickers.push(position.ticker); + } + + info!("Querying {} S3 files", s3_paths.len()); + + let s3_paths_query = s3_paths + .iter() + .map(|path| format!("SELECT * FROM '{}'", path)) + .collect::>() + .join(" UNION ALL "); + + let tickers_query = tickers + .iter() + .map(|ticker| format!("'{}'", ticker)) + .collect::>() + .join(", "); + + let query = format!( + " + SELECT + ticker, + timestamp, + quantile_10, + quantile_50, + quantile_90 + FROM ({}) + WHERE ticker IN ({}) + ORDER BY timestamp, ticker + ", + s3_paths_query, tickers_query, + ); + + debug!("Executing export SQL: {}", query); + + let mut statement = connection.prepare(&query)?; + + let predictions_iterator = statement.query_map([], |row| { + Ok(Prediction { + ticker: row.get(0)?, + timestamp: row.get(1)?, + quantile_10: row.get(2)?, + quantile_50: row.get(3)?, + quantile_90: row.get(4)?, + }) + })?; + + let predictions: Vec = predictions_iterator + .collect::, _>>() + .map_err(|e| Error::OtherError(format!("Failed to collect predictions: {}", e)))?; + + df!( + "ticker" => predictions.iter().map(|p| p.ticker.as_str()).collect::>(), + "timestamp" => predictions.iter().map(|p| p.timestamp).collect::>(), + "quantile_10" => predictions.iter().map(|p| p.quantile_10).collect::>(), + "quantile_50" => predictions.iter().map(|p| p.quantile_50).collect::>(), + "quantile_90" => predictions.iter().map(|p| p.quantile_90).collect::>(), + ) + .map_err(|e| Error::OtherError(format!("Failed to create DataFrame: {}", e))) +} + +pub fn router() -> Router { + Router::new() + .route("/predictions", post(save_prediction)) + .route("/predictions", get(query_prediction)) +}