Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions application/datamanager/src/datamanager/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@
import httpx
import polars as pl
import pyarrow
import pyarrow.lib # for ArrowIOError if using Arrow internally
from duckdb import IOException
import requests

from fastapi import FastAPI, HTTPException, Request, Response, status
from google.api_core import exceptions
from google.api_core.exceptions import GoogleAPIError
from google.cloud import storage # type: ignore
from loguru import logger
from polars.exceptions import ComputeError
from prometheus_fastapi_instrumentator import Instrumentator

from .config import Settings
Expand Down Expand Up @@ -42,7 +48,7 @@ def bars_query(*, bucket: str, start_date: date, end_date: date) -> str:
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
app.state.settings = Settings()
app.state.bucket = storage.Client(os.getenv("GCP_PROJECT")).bucket(
app.state.settings.gcp.bucket.name
app.state.settings.gcp.bucket.name,
)

DUCKDB_ACCESS_KEY = os.getenv("DUCKDB_ACCESS_KEY")
Expand Down Expand Up @@ -75,12 +81,16 @@ async def health_check() -> Response:

@application.get("/equity-bars")
async def get_equity_bars(
request: Request, start_date: date, end_date: date
request: Request,
start_date: date,
end_date: date,
) -> Response:
settings: Settings = request.app.state.settings

query = bars_query(
bucket=settings.gcp.bucket.name, start_date=start_date, end_date=end_date
bucket=settings.gcp.bucket.name,
start_date=start_date,
end_date=end_date,
)

try:
Expand All @@ -105,7 +115,13 @@ async def get_equity_bars(
},
)

except Exception as e:
except (
requests.RequestException,
ComputeError,
IOException,
GoogleAPIError,
pyarrow.lib.ArrowIOError,
) as e:
logger.error(f"Error querying data: {e}")
logger.error(traceback.format_exc())
return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
Expand Down Expand Up @@ -138,17 +154,24 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars
pl.from_epoch("t", time_unit="ms").dt.year().alias("year"),
pl.from_epoch("t", time_unit="ms").dt.month().alias("month"),
pl.from_epoch("t", time_unit="ms").dt.day().alias("day"),
]
],
).write_parquet(
bucket.daily_bars_path, partition_by=["year", "month", "day"]
bucket.daily_bars_path,
partition_by=["year", "month", "day"],
)
except Exception as e:
except (
requests.RequestException,
ComputeError,
IOException,
GoogleAPIError,
pyarrow.lib.ArrowIOError,
) as e:
logger.error(f"Error writing parquet file: {e}")
logger.error(traceback.format_exc())
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to write data",
)
) from e
return BarsSummary(date=summary_date.date.strftime("%Y-%m-%d"), count=count)


Expand Down
12 changes: 9 additions & 3 deletions application/datamanager/src/datamanager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class SummaryDate(BaseModel):
date: datetime.date = Field(
default_factory=lambda: datetime.datetime.utcnow().date()
default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc).date(),
)

@field_validator("date", mode="before")
Expand All @@ -14,7 +14,11 @@ def parse_date(cls, value: datetime.date | str) -> datetime.date:
return value
for fmt in ("%Y-%m-%d", "%Y/%m/%d"):
try:
return datetime.datetime.strptime(value, fmt).date()
return (
datetime.datetime.strptime(value, fmt)
.replace(tzinfo=datetime.timezone.utc)
.date()
)
except ValueError:
continue
raise ValueError("Invalid date format: expected YYYY-MM-DD or YYYY/MM/DD")
Expand All @@ -29,7 +33,9 @@ class DateRange(BaseModel):
@field_validator("end")
@classmethod
def check_end_after_start(
cls, end_value: datetime.datetime, info: core_schema.ValidationInfo
cls,
end_value: datetime.datetime,
info: core_schema.ValidationInfo,
) -> datetime.datetime:
start_value = info.data.get("start")
if start_value and end_value <= start_value:
Expand Down
5 changes: 4 additions & 1 deletion application/positionmanager/src/positionmanager/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
class AlpacaClient:
def __init__(
self,
*,
api_key: str | None = "",
api_secret: str | None = "",
api_key: str | None = None,
api_secret: str | None = None,
paper: bool = True,
Expand Down Expand Up @@ -89,7 +92,7 @@ def get_data(
pl.col("timestamp")
.str.slice(0, 10)
.str.strptime(pl.Date, "%Y-%m-%d")
.alias("date")
.alias("date"),
)

data = (
Expand Down
30 changes: 17 additions & 13 deletions application/positionmanager/src/positionmanager/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from fastapi import FastAPI, HTTPException
import requests
import os
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
import polars as pl
from typing import Dict, Any
from .models import Money, DateRange, PredictionPayload
from .clients import AlpacaClient, DataClient
from .portfolio import PortfolioOptimizer
from prometheus_fastapi_instrumentator import Instrumentator

from alpaca.common.rest import APIError
from pydantic import ValidationError


trading_days_per_year = 252

Expand Down Expand Up @@ -38,21 +42,21 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
try:
cash_balance = alpaca_client.get_cash_balance()

except Exception as e:
except (requests.RequestException, APIError, ValidationError) as e:
raise HTTPException(
status_code=500,
detail=f"Error getting cash balance: {str(e)}",
) from e

date_range = DateRange(
start=datetime.now() - timedelta(days=trading_days_per_year),
end=datetime.now(),
start=datetime.now(tz=timezone.utc) - timedelta(days=trading_days_per_year),
end=datetime.now(tz=timezone.utc),
)

try:
historical_data = data_client.get_data(date_range=date_range)

except Exception as e:
except (requests.RequestException, APIError, ValidationError) as e:
raise HTTPException(
status_code=500,
detail=f"Error getting historical data: {str(e)}",
Expand All @@ -65,7 +69,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
predictions=payload.predictions,
)

except Exception as e:
except (requests.RequestException, APIError, ValidationError) as e:
raise HTTPException(
status_code=500,
detail=f"Error optimizing portfolio: {str(e)}",
Expand All @@ -77,21 +81,21 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
continue

latest_prices = historical_data.filter(pl.col(ticker).is_not_null()).select(
ticker
ticker,
)
if latest_prices.is_empty():
executed_trades.append(
{
"ticker": ticker,
"status": "error",
"error": "No recent price available",
}
},
)
continue
latest_price = latest_prices.tail(1)[0, 0]

notional_amount = Money.from_float(
latest_price * share_count * 0.95
latest_price * share_count * 0.95,
) # 5% buffer

try:
Expand All @@ -103,18 +107,18 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
"share_count": share_count,
"notional_amount": float(notional_amount),
"status": "success",
}
},
)

except Exception as e:
except (requests.RequestException, APIError, ValidationError) as e:
executed_trades.append(
{
"ticker": ticker,
"share_count": share_count,
"notional_amount": float(notional_amount),
"status": "error",
"error": str(e),
}
},
)

final_cash_balance = alpaca_client.get_cash_balance()
Expand All @@ -140,7 +144,7 @@ def delete_positions() -> Dict[str, Any]:
try:
result = alpaca_client.clear_positions()

except Exception as e:
except (requests.RequestException, APIError, ValidationError) as e:
raise HTTPException(status_code=500, detail=str(e)) from e

cash_balance = alpaca_client.get_cash_balance()
Expand Down
7 changes: 5 additions & 2 deletions application/positionmanager/src/positionmanager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

class Money(BaseModel):
amount: Decimal = Field(
..., description="Monetary value in USD with 2 decimal places"
...,
description="Monetary value in USD with 2 decimal places",
)

@field_validator("amount", check_fields=True)
Expand Down Expand Up @@ -42,7 +43,9 @@ class DateRange(BaseModel):
@field_validator("end")
@classmethod
def check_end_after_start(
cls, end_value: datetime, info: core_schema.ValidationInfo
cls,
end_value: datetime,
info: core_schema.ValidationInfo,
) -> datetime:
start_value = info.data.get("start")
if start_value and end_value <= start_value:
Expand Down
4 changes: 3 additions & 1 deletion application/positionmanager/src/positionmanager/portfolio.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def get_optimized_portfolio(

long_only_weight_bounds = (0, 0.2) # 20% max weight per asset
efficient_frontier = EfficientFrontier(
mu, S, weight_bounds=long_only_weight_bounds
mu,
S,
weight_bounds=long_only_weight_bounds,
)

efficient_frontier.max_sharpe(risk_free_rate=0.02) # 2% risk-free rate
Expand Down
16 changes: 12 additions & 4 deletions application/positionmanager/tests/test_positionmanager_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from fastapi.testclient import TestClient
from fastapi import HTTPException

import unittest
from unittest.mock import patch, MagicMock
import polars as pl
Expand Down Expand Up @@ -41,7 +43,7 @@ def test_create_position_success(

mock_data_instance = MagicMock(spec=DataClient)
mock_historical_data = pl.DataFrame(
{"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]}
{"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]},
)
mock_data_instance.get_data.return_value = mock_historical_data
MockDataClient.return_value = mock_data_instance
Expand Down Expand Up @@ -81,7 +83,10 @@ def test_create_position_success(
@patch("application.positionmanager.src.positionmanager.main.AlpacaClient")
def test_create_position_alpaca_error(self, MockAlpacaClient: MagicMock) -> None:
mock_alpaca_instance = MagicMock(spec=AlpacaClient)
mock_alpaca_instance.get_cash_balance.side_effect = Exception("API error")
mock_alpaca_instance.get_cash_balance.side_effect = HTTPException(
status_code=500,
detail="Error getting cash balance",
)
MockAlpacaClient.return_value = mock_alpaca_instance

payload = {"predictions": {"AAPL": 0.8}}
Expand Down Expand Up @@ -121,13 +126,16 @@ def test_delete_positions_error(
MockAlpacaClient: MagicMock,
) -> None:
mock_alpaca_instance = MagicMock(spec=AlpacaClient)
mock_alpaca_instance.clear_positions.side_effect = Exception("API error")
mock_alpaca_instance.clear_positions.side_effect = HTTPException(
status_code=500,
detail="Error getting cash balance",
)
MockAlpacaClient.return_value = mock_alpaca_instance

response = client.delete("/positions")

assert response.status_code == 500
assert "API error" in response.json()["detail"]
assert "Error" in response.json()["detail"]

MockAlpacaClient.assert_called_once()
mock_alpaca_instance.clear_positions.assert_called_once()
7 changes: 4 additions & 3 deletions infrastructure/cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
value=buckets.production_data_bucket.name,
),
cloudrun.ServiceTemplateSpecContainerEnvArgs(
name="DUCKDB_ACCESS_KEY", value=duckdb_access_key
name="DUCKDB_ACCESS_KEY",
value=duckdb_access_key,
),
cloudrun.ServiceTemplateSpecContainerEnvArgs(
name="DUCKDB_SECRET",
Expand All @@ -58,7 +59,7 @@
value=duckdb_secret,
),
],
)
),
],
),
),
Expand All @@ -70,7 +71,7 @@
push_config=pubsub.SubscriptionPushConfigArgs(
push_endpoint=datamanager_service.statuses[0].url,
oidc_token=pubsub.SubscriptionPushConfigOidcTokenArgs(
service_account_email=project.platform_service_account.email
service_account_email=project.platform_service_account.email,
),
),
)
Expand Down
9 changes: 6 additions & 3 deletions infrastructure/images.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from datetime import datetime
from datetime import datetime, timezone
from glob import glob
import pulumi
import pulumi_docker_build as docker_build
Expand All @@ -12,7 +12,10 @@
dockerfile_paths = glob(os.path.join("..", "application", "*", "Dockerfile"))
dockerfile_paths = [os.path.relpath(dockerfile) for dockerfile in dockerfile_paths]

tags = ["latest", datetime.utcnow().strftime("%Y%m%d")]
tags = [
"latest",
datetime.now(tz=timezone.utc).strftime("%Y%m%d"),
]

images = {}
for dockerfile in dockerfile_paths:
Expand All @@ -36,7 +39,7 @@
address="docker.io",
username=dockerhub_username,
password=dockerhub_password,
)
),
],
)

Expand Down
Loading
Loading