diff --git a/application/datamanager/src/datamanager/main.py b/application/datamanager/src/datamanager/main.py index dbaa6787e..779a6809d 100644 --- a/application/datamanager/src/datamanager/main.py +++ b/application/datamanager/src/datamanager/main.py @@ -48,7 +48,7 @@ def bars_query(*, bucket: str, start_date: date, end_date: date) -> str: async def lifespan(app: FastAPI) -> AsyncGenerator[None]: app.state.settings = Settings() app.state.bucket = storage.Client(os.getenv("GCP_PROJECT")).bucket( - app.state.settings.gcp.bucket.name + app.state.settings.gcp.bucket.name, ) DUCKDB_ACCESS_KEY = os.getenv("DUCKDB_ACCESS_KEY") @@ -81,12 +81,16 @@ async def health_check() -> Response: @application.get("/equity-bars") async def get_equity_bars( - request: Request, start_date: date, end_date: date + request: Request, + start_date: date, + end_date: date, ) -> Response: settings: Settings = request.app.state.settings query = bars_query( - bucket=settings.gcp.bucket.name, start_date=start_date, end_date=end_date + bucket=settings.gcp.bucket.name, + start_date=start_date, + end_date=end_date, ) try: @@ -150,9 +154,10 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars pl.from_epoch("t", time_unit="ms").dt.year().alias("year"), pl.from_epoch("t", time_unit="ms").dt.month().alias("month"), pl.from_epoch("t", time_unit="ms").dt.day().alias("day"), - ] + ], ).write_parquet( - bucket.daily_bars_path, partition_by=["year", "month", "day"] + bucket.daily_bars_path, + partition_by=["year", "month", "day"], ) except ( requests.RequestException, @@ -166,7 +171,7 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to write data", - ) + ) from e return BarsSummary(date=summary_date.date.strftime("%Y-%m-%d"), count=count) diff --git a/application/datamanager/src/datamanager/models.py b/application/datamanager/src/datamanager/models.py index efd255525..36f1261ed 100644 --- a/application/datamanager/src/datamanager/models.py +++ b/application/datamanager/src/datamanager/models.py @@ -5,7 +5,7 @@ class SummaryDate(BaseModel): date: datetime.date = Field( - default_factory=lambda: datetime.datetime.utcnow().date() + default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc).date(), ) @field_validator("date", mode="before") @@ -14,7 +14,11 @@ def parse_date(cls, value: datetime.date | str) -> datetime.date: return value for fmt in ("%Y-%m-%d", "%Y/%m/%d"): try: - return datetime.datetime.strptime(value, fmt).date() + return ( + datetime.datetime.strptime(value, fmt) + .replace(tzinfo=datetime.timezone.utc) + .date() + ) except ValueError: continue raise ValueError("Invalid date format: expected YYYY-MM-DD or YYYY/MM/DD") @@ -29,7 +33,9 @@ class DateRange(BaseModel): @field_validator("end") @classmethod def check_end_after_start( - cls, end_value: datetime.datetime, info: core_schema.ValidationInfo + cls, + end_value: datetime.datetime, + info: core_schema.ValidationInfo, ) -> datetime.datetime: start_value = info.data.get("start") if start_value and end_value <= start_value: diff --git a/application/positionmanager/src/positionmanager/clients.py b/application/positionmanager/src/positionmanager/clients.py index c4e95bf41..595bc1e22 100644 --- a/application/positionmanager/src/positionmanager/clients.py +++ b/application/positionmanager/src/positionmanager/clients.py @@ -85,7 +85,7 @@ def get_data( pl.col("timestamp") .str.slice(0, 10) .str.strptime(pl.Date, "%Y-%m-%d") - .alias("date") + .alias("date"), ) data = ( diff --git a/application/positionmanager/src/positionmanager/main.py b/application/positionmanager/src/positionmanager/main.py index cba438523..9cb789873 100644 --- a/application/positionmanager/src/positionmanager/main.py +++ b/application/positionmanager/src/positionmanager/main.py @@ -1,7 +1,7 @@ from fastapi import FastAPI, HTTPException import requests import os -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import polars as pl from typing import Dict, Any @@ -51,8 +51,8 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: ) from e date_range = DateRange( - start=datetime.now() - timedelta(days=trading_days_per_year), - end=datetime.now(), + start=datetime.now(tz=timezone.utc) - timedelta(days=trading_days_per_year), + end=datetime.now(tz=timezone.utc), ) try: @@ -83,7 +83,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: continue latest_prices = historical_data.filter(pl.col(ticker).is_not_null()).select( - ticker + ticker, ) if latest_prices.is_empty(): executed_trades.append( @@ -91,13 +91,13 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: "ticker": ticker, "status": "error", "error": "No recent price available", - } + }, ) continue latest_price = latest_prices.tail(1)[0, 0] notional_amount = Money.from_float( - latest_price * share_count * 0.95 + latest_price * share_count * 0.95, ) # 5% buffer try: @@ -109,7 +109,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: "share_count": share_count, "notional_amount": float(notional_amount), "status": "success", - } + }, ) except (requests.RequestException, APIError, ValidationError) as e: @@ -120,7 +120,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]: "notional_amount": float(notional_amount), "status": "error", "error": str(e), - } + }, ) final_cash_balance = alpaca_client.get_cash_balance() diff --git a/application/positionmanager/src/positionmanager/models.py b/application/positionmanager/src/positionmanager/models.py index 5bfe85584..cbfbc8cb9 100644 --- a/application/positionmanager/src/positionmanager/models.py +++ b/application/positionmanager/src/positionmanager/models.py @@ -8,7 +8,8 @@ class Money(BaseModel): amount: Decimal = Field( - ..., description="Monetary value in USD with 2 decimal places" + ..., + description="Monetary value in USD with 2 decimal places", ) @field_validator("amount", check_fields=True) @@ -42,7 +43,9 @@ class DateRange(BaseModel): @field_validator("end") @classmethod def check_end_after_start( - cls, end_value: datetime, info: core_schema.ValidationInfo + cls, + end_value: datetime, + info: core_schema.ValidationInfo, ) -> datetime: start_value = info.data.get("start") if start_value and end_value <= start_value: diff --git a/application/positionmanager/src/positionmanager/portfolio.py b/application/positionmanager/src/positionmanager/portfolio.py index 6903f2e87..b0f6eadc7 100644 --- a/application/positionmanager/src/positionmanager/portfolio.py +++ b/application/positionmanager/src/positionmanager/portfolio.py @@ -40,7 +40,9 @@ def get_optimized_portfolio( long_only_weight_bounds = (0, 0.2) # 20% max weight per asset efficient_frontier = EfficientFrontier( - mu, S, weight_bounds=long_only_weight_bounds + mu, + S, + weight_bounds=long_only_weight_bounds, ) efficient_frontier.max_sharpe(risk_free_rate=0.02) # 2% risk-free rate diff --git a/application/positionmanager/tests/test_positionmanager_main.py b/application/positionmanager/tests/test_positionmanager_main.py index 9de8f0991..357bf4cf0 100644 --- a/application/positionmanager/tests/test_positionmanager_main.py +++ b/application/positionmanager/tests/test_positionmanager_main.py @@ -43,7 +43,7 @@ def test_create_position_success( mock_data_instance = MagicMock(spec=DataClient) mock_historical_data = pl.DataFrame( - {"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]} + {"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]}, ) mock_data_instance.get_data.return_value = mock_historical_data MockDataClient.return_value = mock_data_instance @@ -84,7 +84,8 @@ def test_create_position_success( def test_create_position_alpaca_error(self, MockAlpacaClient: MagicMock) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) mock_alpaca_instance.get_cash_balance.side_effect = HTTPException( - status_code=500, detail="Error getting cash balance" + status_code=500, + detail="Error getting cash balance", ) MockAlpacaClient.return_value = mock_alpaca_instance @@ -126,7 +127,8 @@ def test_delete_positions_error( ) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) mock_alpaca_instance.clear_positions.side_effect = HTTPException( - status_code=500, detail="Error getting cash balance" + status_code=500, + detail="Error getting cash balance", ) MockAlpacaClient.return_value = mock_alpaca_instance diff --git a/infrastructure/cloud_run.py b/infrastructure/cloud_run.py index 58d8c4894..0ebc65d94 100644 --- a/infrastructure/cloud_run.py +++ b/infrastructure/cloud_run.py @@ -40,14 +40,15 @@ value=buckets.production_data_bucket.name, ), cloudrun.ServiceTemplateSpecContainerEnvArgs( - name="DUCKDB_ACCESS_KEY", value=duckdb_access_key + name="DUCKDB_ACCESS_KEY", + value=duckdb_access_key, ), cloudrun.ServiceTemplateSpecContainerEnvArgs( name="DUCKDB_SECRET", value=duckdb_secret, ), ], - ) + ), ], ), ), @@ -59,7 +60,7 @@ push_config=pubsub.SubscriptionPushConfigArgs( push_endpoint=service.statuses[0].url, oidc_token=pubsub.SubscriptionPushConfigOidcTokenArgs( - service_account_email=project.platform_service_account.email + service_account_email=project.platform_service_account.email, ), ), ) diff --git a/infrastructure/images.py b/infrastructure/images.py index f29039add..de60cd0de 100644 --- a/infrastructure/images.py +++ b/infrastructure/images.py @@ -1,5 +1,5 @@ import os -from datetime import datetime +from datetime import datetime, timezone from glob import glob import pulumi import pulumi_docker_build as docker_build @@ -12,7 +12,10 @@ dockerfile_paths = glob(os.path.join("..", "application", "*", "Dockerfile")) dockerfile_paths = [os.path.relpath(dockerfile) for dockerfile in dockerfile_paths] -tags = ["latest", datetime.utcnow().strftime("%Y%m%d")] +tags = [ + "latest", + datetime.now(tz=timezone.utc).strftime("%Y%m%d"), +] images = {} for dockerfile in dockerfile_paths: @@ -36,7 +39,7 @@ address="docker.io", username=dockerhub_username, password=dockerhub_password, - ) + ), ], ) diff --git a/infrastructure/monitoring.py b/infrastructure/monitoring.py index bf5cbdd81..5ff0cc8b8 100644 --- a/infrastructure/monitoring.py +++ b/infrastructure/monitoring.py @@ -23,14 +23,14 @@ cloudrun.ServiceTemplateSpecContainerVolumeMountArgs( name="prometheus-config", mount_path="/etc/prometheus", - ) + ), ], ports=[ cloudrun.ServiceTemplateSpecContainerPortArgs( - container_port=9090 - ) + container_port=9090, + ), ], - ) + ), ], volumes=[ cloudrun.ServiceTemplateSpecVolumeArgs( @@ -41,12 +41,12 @@ cloudrun.ServiceTemplateSpecVolumeSecretItemArgs( path="prometheus.yaml", version=prometheus_config_version.version, - ) + ), ], ), - ) + ), ], - ) + ), ), ) @@ -61,11 +61,11 @@ image="grafana/grafana:latest", ports=[ cloudrun.ServiceTemplateSpecContainerPortArgs( - container_port=3000 - ) + container_port=3000, + ), ], - ) + ), ], - ) + ), ), ) diff --git a/infrastructure/project.py b/infrastructure/project.py index 39b9c8bf6..04ad148c5 100644 --- a/infrastructure/project.py +++ b/infrastructure/project.py @@ -55,6 +55,6 @@ project=PROJECT, role="roles/pubsub.subscriber", member=platform_service_account.email.apply( - lambda e: f"serviceAccount:{e}" + lambda e: f"serviceAccount:{e}", ), # ty: ignore[missing-argument] ) diff --git a/pyproject.toml b/pyproject.toml index e09fdd47d..362b3a34b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,15 +68,24 @@ output = ".python_coverage.xml" [tool.ruff.lint] select = [ + "A", # flake8 builtins "ANN", # type annotations "ASYNC", + "B", # bugbear + "COM", # commas + "C4", # comprehensions "BLE", # no blind exceptions + "DTZ", # datetimes "ERA", # dead code "FAST", # fastapi "FBT", # boolean traps "S", # bandit (security) "YTT" # flake8 ] +ignore = [ + "COM812", +] + [tool.ruff.lint.per-file-ignores] "**/tests/**/*.py" = ["S101"] "**/features/steps/**/*.py" = ["S101"]