diff --git a/application/datamanager/src/datamanager/main.py b/application/datamanager/src/datamanager/main.py index 8d4694072..40768bec3 100644 --- a/application/datamanager/src/datamanager/main.py +++ b/application/datamanager/src/datamanager/main.py @@ -8,10 +8,16 @@ import httpx import polars as pl import pyarrow +import pyarrow.lib # for ArrowIOError if using Arrow internally +from duckdb import IOException +import requests + from fastapi import FastAPI, HTTPException, Request, Response, status from google.api_core import exceptions +from google.api_core.exceptions import GoogleAPIError from google.cloud import storage # type: ignore from loguru import logger +from polars.exceptions import ComputeError from prometheus_fastapi_instrumentator import Instrumentator from .config import Settings @@ -42,7 +48,7 @@ def bars_query(*, bucket: str, start_date: date, end_date: date) -> str: async def lifespan(app: FastAPI) -> AsyncGenerator[None]: app.state.settings = Settings() app.state.bucket = storage.Client(os.getenv("GCP_PROJECT")).bucket( - app.state.settings.gcp.bucket.name + app.state.settings.gcp.bucket.name, ) DUCKDB_ACCESS_KEY = os.getenv("DUCKDB_ACCESS_KEY") @@ -75,12 +81,16 @@ async def health_check() -> Response: @application.get("/equity-bars") async def get_equity_bars( - request: Request, start_date: date, end_date: date + request: Request, + start_date: date, + end_date: date, ) -> Response: settings: Settings = request.app.state.settings query = bars_query( - bucket=settings.gcp.bucket.name, start_date=start_date, end_date=end_date + bucket=settings.gcp.bucket.name, + start_date=start_date, + end_date=end_date, ) try: @@ -105,7 +115,13 @@ async def get_equity_bars( }, ) - except Exception as e: + except ( + requests.RequestException, + ComputeError, + IOException, + GoogleAPIError, + pyarrow.lib.ArrowIOError, + ) as e: logger.error(f"Error querying data: {e}") logger.error(traceback.format_exc()) return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -138,17 +154,24 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars pl.from_epoch("t", time_unit="ms").dt.year().alias("year"), pl.from_epoch("t", time_unit="ms").dt.month().alias("month"), pl.from_epoch("t", time_unit="ms").dt.day().alias("day"), - ] + ], ).write_parquet( - bucket.daily_bars_path, partition_by=["year", "month", "day"] + bucket.daily_bars_path, + partition_by=["year", "month", "day"], ) - except Exception as e: + except ( + requests.RequestException, + ComputeError, + IOException, + GoogleAPIError, + pyarrow.lib.ArrowIOError, + ) as e: logger.error(f"Error writing parquet file: {e}") logger.error(traceback.format_exc()) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to write data", - ) + ) from e return BarsSummary(date=summary_date.date.strftime("%Y-%m-%d"), count=count) diff --git a/application/datamanager/src/datamanager/models.py b/application/datamanager/src/datamanager/models.py index efd255525..36f1261ed 100644 --- a/application/datamanager/src/datamanager/models.py +++ b/application/datamanager/src/datamanager/models.py @@ -5,7 +5,7 @@ class SummaryDate(BaseModel): date: datetime.date = Field( - default_factory=lambda: datetime.datetime.utcnow().date() + default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc).date(), ) @field_validator("date", mode="before") @@ -14,7 +14,11 @@ def parse_date(cls, value: datetime.date | str) -> datetime.date: return value for fmt in ("%Y-%m-%d", "%Y/%m/%d"): try: - return datetime.datetime.strptime(value, fmt).date() + return ( + datetime.datetime.strptime(value, fmt) + .replace(tzinfo=datetime.timezone.utc) + .date() + ) except ValueError: continue raise ValueError("Invalid date format: expected YYYY-MM-DD or YYYY/MM/DD") @@ -29,7 +33,9 @@ class DateRange(BaseModel): @field_validator("end") @classmethod def check_end_after_start( - cls, end_value: datetime.datetime, info: core_schema.ValidationInfo + cls, + end_value: datetime.datetime, + info: core_schema.ValidationInfo, ) -> datetime.datetime: start_value = info.data.get("start") if start_value and end_value <= start_value: diff --git a/application/positionmanager/src/positionmanager/clients.py b/application/positionmanager/src/positionmanager/clients.py index cf4ff79f8..696d02169 100644 --- a/application/positionmanager/src/positionmanager/clients.py +++ b/application/positionmanager/src/positionmanager/clients.py @@ -12,6 +12,9 @@ class AlpacaClient: def __init__( self, + *, + api_key: str | None = "", + api_secret: str | None = "", api_key: str | None = None, api_secret: str | None = None, paper: bool = True, @@ -89,7 +92,7 @@ def get_data( pl.col("timestamp") .str.slice(0, 10) .str.strptime(pl.Date, "%Y-%m-%d") - .alias("date") + .alias("date"), ) data = ( diff --git a/application/positionmanager/src/positionmanager/main.py b/application/positionmanager/src/positionmanager/main.py index 263f60a28..22f6bd1c6 100644 --- a/application/positionmanager/src/positionmanager/main.py +++ b/application/positionmanager/src/positionmanager/main.py @@ -1,6 +1,7 @@ from fastapi import FastAPI, HTTPException +import requests import os -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import polars as pl from typing import Dict, Any from .models import Money, DateRange, PredictionPayload @@ -8,6 +9,9 @@ from .portfolio import PortfolioOptimizer from prometheus_fastapi_instrumentator import Instrumentator +from alpaca.common.rest import APIError +from pydantic import ValidationError + trading_days_per_year = 252 @@ -38,21 +42,21 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: try: cash_balance = alpaca_client.get_cash_balance() - except Exception as e: + except (requests.RequestException, APIError, ValidationError) as e: raise HTTPException( status_code=500, detail=f"Error getting cash balance: {str(e)}", ) from e date_range = DateRange( - start=datetime.now() - timedelta(days=trading_days_per_year), - end=datetime.now(), + start=datetime.now(tz=timezone.utc) - timedelta(days=trading_days_per_year), + end=datetime.now(tz=timezone.utc), ) try: historical_data = data_client.get_data(date_range=date_range) - except Exception as e: + except (requests.RequestException, APIError, ValidationError) as e: raise HTTPException( status_code=500, detail=f"Error getting historical data: {str(e)}", @@ -65,7 +69,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: predictions=payload.predictions, ) - except Exception as e: + except (requests.RequestException, APIError, ValidationError) as e: raise HTTPException( status_code=500, detail=f"Error optimizing portfolio: {str(e)}", @@ -77,7 +81,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: continue latest_prices = historical_data.filter(pl.col(ticker).is_not_null()).select( - ticker + ticker, ) if latest_prices.is_empty(): executed_trades.append( @@ -85,13 +89,13 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: "ticker": ticker, "status": "error", "error": "No recent price available", - } + }, ) continue latest_price = latest_prices.tail(1)[0, 0] notional_amount = Money.from_float( - latest_price * share_count * 0.95 + latest_price * share_count * 0.95, ) # 5% buffer try: @@ -103,10 +107,10 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: "share_count": share_count, "notional_amount": float(notional_amount), "status": "success", - } + }, ) - except Exception as e: + except (requests.RequestException, APIError, ValidationError) as e: executed_trades.append( { "ticker": ticker, @@ -114,7 +118,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: "notional_amount": float(notional_amount), "status": "error", "error": str(e), - } + }, ) final_cash_balance = alpaca_client.get_cash_balance() @@ -140,7 +144,7 @@ def delete_positions() -> Dict[str, Any]: try: result = alpaca_client.clear_positions() - except Exception as e: + except (requests.RequestException, APIError, ValidationError) as e: raise HTTPException(status_code=500, detail=str(e)) from e cash_balance = alpaca_client.get_cash_balance() diff --git a/application/positionmanager/src/positionmanager/models.py b/application/positionmanager/src/positionmanager/models.py index 5bfe85584..cbfbc8cb9 100644 --- a/application/positionmanager/src/positionmanager/models.py +++ b/application/positionmanager/src/positionmanager/models.py @@ -8,7 +8,8 @@ class Money(BaseModel): amount: Decimal = Field( - ..., description="Monetary value in USD with 2 decimal places" + ..., + description="Monetary value in USD with 2 decimal places", ) @field_validator("amount", check_fields=True) @@ -42,7 +43,9 @@ class DateRange(BaseModel): @field_validator("end") @classmethod def check_end_after_start( - cls, end_value: datetime, info: core_schema.ValidationInfo + cls, + end_value: datetime, + info: core_schema.ValidationInfo, ) -> datetime: start_value = info.data.get("start") if start_value and end_value <= start_value: diff --git a/application/positionmanager/src/positionmanager/portfolio.py b/application/positionmanager/src/positionmanager/portfolio.py index 6903f2e87..b0f6eadc7 100644 --- a/application/positionmanager/src/positionmanager/portfolio.py +++ b/application/positionmanager/src/positionmanager/portfolio.py @@ -40,7 +40,9 @@ def get_optimized_portfolio( long_only_weight_bounds = (0, 0.2) # 20% max weight per asset efficient_frontier = EfficientFrontier( - mu, S, weight_bounds=long_only_weight_bounds + mu, + S, + weight_bounds=long_only_weight_bounds, ) efficient_frontier.max_sharpe(risk_free_rate=0.02) # 2% risk-free rate diff --git a/application/positionmanager/tests/test_positionmanager_main.py b/application/positionmanager/tests/test_positionmanager_main.py index 1db0b3715..357bf4cf0 100644 --- a/application/positionmanager/tests/test_positionmanager_main.py +++ b/application/positionmanager/tests/test_positionmanager_main.py @@ -1,4 +1,6 @@ from fastapi.testclient import TestClient +from fastapi import HTTPException + import unittest from unittest.mock import patch, MagicMock import polars as pl @@ -41,7 +43,7 @@ def test_create_position_success( mock_data_instance = MagicMock(spec=DataClient) mock_historical_data = pl.DataFrame( - {"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]} + {"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]}, ) mock_data_instance.get_data.return_value = mock_historical_data MockDataClient.return_value = mock_data_instance @@ -81,7 +83,10 @@ def test_create_position_success( @patch("application.positionmanager.src.positionmanager.main.AlpacaClient") def test_create_position_alpaca_error(self, MockAlpacaClient: MagicMock) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) - mock_alpaca_instance.get_cash_balance.side_effect = Exception("API error") + mock_alpaca_instance.get_cash_balance.side_effect = HTTPException( + status_code=500, + detail="Error getting cash balance", + ) MockAlpacaClient.return_value = mock_alpaca_instance payload = {"predictions": {"AAPL": 0.8}} @@ -121,13 +126,16 @@ def test_delete_positions_error( MockAlpacaClient: MagicMock, ) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) - mock_alpaca_instance.clear_positions.side_effect = Exception("API error") + mock_alpaca_instance.clear_positions.side_effect = HTTPException( + status_code=500, + detail="Error getting cash balance", + ) MockAlpacaClient.return_value = mock_alpaca_instance response = client.delete("/positions") assert response.status_code == 500 - assert "API error" in response.json()["detail"] + assert "Error" in response.json()["detail"] MockAlpacaClient.assert_called_once() mock_alpaca_instance.clear_positions.assert_called_once() diff --git a/infrastructure/cloud_run.py b/infrastructure/cloud_run.py index 5daecd9a6..efe3517e3 100644 --- a/infrastructure/cloud_run.py +++ b/infrastructure/cloud_run.py @@ -40,7 +40,8 @@ value=buckets.production_data_bucket.name, ), cloudrun.ServiceTemplateSpecContainerEnvArgs( - name="DUCKDB_ACCESS_KEY", value=duckdb_access_key + name="DUCKDB_ACCESS_KEY", + value=duckdb_access_key, ), cloudrun.ServiceTemplateSpecContainerEnvArgs( name="DUCKDB_SECRET", @@ -58,7 +59,7 @@ value=duckdb_secret, ), ], - ) + ), ], ), ), @@ -70,7 +71,7 @@ push_config=pubsub.SubscriptionPushConfigArgs( push_endpoint=datamanager_service.statuses[0].url, oidc_token=pubsub.SubscriptionPushConfigOidcTokenArgs( - service_account_email=project.platform_service_account.email + service_account_email=project.platform_service_account.email, ), ), ) diff --git a/infrastructure/images.py b/infrastructure/images.py index dfd2e1e2d..4c2190810 100644 --- a/infrastructure/images.py +++ b/infrastructure/images.py @@ -1,5 +1,5 @@ import os -from datetime import datetime +from datetime import datetime, timezone from glob import glob import pulumi import pulumi_docker_build as docker_build @@ -12,7 +12,10 @@ dockerfile_paths = glob(os.path.join("..", "application", "*", "Dockerfile")) dockerfile_paths = [os.path.relpath(dockerfile) for dockerfile in dockerfile_paths] -tags = ["latest", datetime.utcnow().strftime("%Y%m%d")] +tags = [ + "latest", + datetime.now(tz=timezone.utc).strftime("%Y%m%d"), +] images = {} for dockerfile in dockerfile_paths: @@ -36,7 +39,7 @@ address="docker.io", username=dockerhub_username, password=dockerhub_password, - ) + ), ], ) diff --git a/infrastructure/monitoring.py b/infrastructure/monitoring.py index bf5cbdd81..5ff0cc8b8 100644 --- a/infrastructure/monitoring.py +++ b/infrastructure/monitoring.py @@ -23,14 +23,14 @@ cloudrun.ServiceTemplateSpecContainerVolumeMountArgs( name="prometheus-config", mount_path="/etc/prometheus", - ) + ), ], ports=[ cloudrun.ServiceTemplateSpecContainerPortArgs( - container_port=9090 - ) + container_port=9090, + ), ], - ) + ), ], volumes=[ cloudrun.ServiceTemplateSpecVolumeArgs( @@ -41,12 +41,12 @@ cloudrun.ServiceTemplateSpecVolumeSecretItemArgs( path="prometheus.yaml", version=prometheus_config_version.version, - ) + ), ], ), - ) + ), ], - ) + ), ), ) @@ -61,11 +61,11 @@ image="grafana/grafana:latest", ports=[ cloudrun.ServiceTemplateSpecContainerPortArgs( - container_port=3000 - ) + container_port=3000, + ), ], - ) + ), ], - ) + ), ), ) diff --git a/infrastructure/project.py b/infrastructure/project.py index 39b9c8bf6..04ad148c5 100644 --- a/infrastructure/project.py +++ b/infrastructure/project.py @@ -55,6 +55,6 @@ project=PROJECT, role="roles/pubsub.subscriber", member=platform_service_account.email.apply( - lambda e: f"serviceAccount:{e}" + lambda e: f"serviceAccount:{e}", ), # ty: ignore[missing-argument] ) diff --git a/pyproject.toml b/pyproject.toml index ada299f6e..fe54c9568 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,13 +67,24 @@ output = "coverage_output/.python_coverage.xml" [tool.ruff.lint] select = [ + "A", # flake8 builtins "ANN", # type annotations "ASYNC", + "B", # bugbear + "COM", # commas + "C4", # comprehensions + "BLE", # no blind exceptions + "DTZ", # datetimes "ERA", # dead code "FAST", # fastapi + "FBT", # boolean traps "S", # bandit (security) "YTT" # flake8 ] +ignore = [ + "COM812", +] + [tool.ruff.lint.per-file-ignores] "**/tests/**/*.py" = ["S101"] "**/features/steps/**/*.py" = ["S101"]