From ad4be481f5c71b53d91472aef25572cfc3d5a0ad Mon Sep 17 00:00:00 2001 From: chrisaddy Date: Wed, 28 May 2025 15:29:47 -0400 Subject: [PATCH 1/4] add type annotations where missing --- .../datamanager/features/environment.py | 3 ++- .../features/steps/equity_bars_steps.py | 17 ++++++------ .../features/steps/health_steps.py | 4 ++- .../datamanager/src/datamanager/main.py | 26 ++++++++++--------- .../datamanager/src/datamanager/models.py | 7 +++-- .../src/positionmanager/clients.py | 2 +- .../src/positionmanager/main.py | 2 +- .../src/positionmanager/models.py | 14 ++++++---- .../src/positionmanager/portfolio.py | 2 +- .../tests/test_positionmanager_main.py | 18 ++++++++----- pyproject.toml | 12 +++++++++ uv.lock | 2 ++ 12 files changed, 71 insertions(+), 38 deletions(-) diff --git a/application/datamanager/features/environment.py b/application/datamanager/features/environment.py index 8ca6c9f2a..8c84c4cef 100644 --- a/application/datamanager/features/environment.py +++ b/application/datamanager/features/environment.py @@ -1,5 +1,6 @@ import os +from behave.runner import Context -def before_all(context): +def before_all(context: Context) -> None: context.base_url = os.environ.get("BASE_URL", "http://datamanager:8080") diff --git a/application/datamanager/features/steps/equity_bars_steps.py b/application/datamanager/features/steps/equity_bars_steps.py index 5ea39b6f6..2fff7bc33 100644 --- a/application/datamanager/features/steps/equity_bars_steps.py +++ b/application/datamanager/features/steps/equity_bars_steps.py @@ -6,29 +6,30 @@ import requests from behave import given, when, then +from behave.runner import Context @given("I have date ranges") -def step_impl_date_ranges(context): - for row in context.table: +def step_impl_date_ranges(context: Context) -> None: + for row in context.table: # ty: ignore context.start_date = row["start_date"] context.end_date = row["end_date"] @given("the datamanager API is running") -def step_impl_api_url(context): +def step_impl_api_url(context: Context) -> None: context.api_url = context.base_url @when('I send a POST request to "{endpoint}" for date range') -def step_impl_post_request(context, endpoint): +def step_impl_post_request(context: Context, endpoint: str) -> None: url = f"{context.api_url}{endpoint}" response = requests.post(url, json={"date": context.start_date}) context.response = response @when('I send a GET request to "{endpoint}" for date range') -def step_imp_get_request(context, endpoint): +def step_imp_get_request(context: Context, endpoint: str) -> None: url = f"{context.api_url}{endpoint}" response = requests.get( url, @@ -38,14 +39,14 @@ def step_imp_get_request(context, endpoint): @then("the response status code should be {status_code}") -def step_impl_response_status_code(context, status_code): +def step_impl_response_status_code(context: Context, status_code: str) -> None: assert context.response.status_code == int(status_code), ( f"Expected status code {status_code}, got {context.response.status_code}" ) @when('I send a DELETE request to "{endpoint}" for date "{date_str}"') -def step_impl(context, endpoint, date_str): +def step_impl(context: Context, endpoint: str, date_str: str) -> None: url = f"{context.api_url}{endpoint}" response = requests.delete(url, json={"date": date_str}) context.response = response @@ -53,7 +54,7 @@ def step_impl(context, endpoint, date_str): @then('the equity bars data for "{date_str}" should be deleted') -def step_impl_equity_bars(context, date_str): +def step_impl_equity_bars(context: Context, date_str: str) -> None: if os.environ.get("GCP_GCS_BUCKET"): assert True, "GCS bucket deletion check would go here" else: diff --git a/application/datamanager/features/steps/health_steps.py b/application/datamanager/features/steps/health_steps.py index 8772c8060..1f1660343 100644 --- a/application/datamanager/features/steps/health_steps.py +++ b/application/datamanager/features/steps/health_steps.py @@ -1,8 +1,10 @@ from behave import when +from behave.runner import Context + import requests @when('I send a GET request to "{endpoint}"') -def step_impl(context, endpoint): +def step_impl(context: Context, endpoint: str) -> None: url = f"{context.api_url}{endpoint}" context.response = requests.get(url) diff --git a/application/datamanager/src/datamanager/main.py b/application/datamanager/src/datamanager/main.py index 07922a025..f0694aedd 100644 --- a/application/datamanager/src/datamanager/main.py +++ b/application/datamanager/src/datamanager/main.py @@ -1,19 +1,21 @@ -import traceback -import pyarrow import os -from prometheus_fastapi_instrumentator import Instrumentator - -import duckdb -from google.api_core import exceptions -from google.cloud import storage +import traceback from contextlib import asynccontextmanager from datetime import date +from typing import AsyncGenerator + +import duckdb import httpx import polars as pl -from fastapi import FastAPI, Request, Response, status, HTTPException +import pyarrow +from fastapi import FastAPI, HTTPException, Request, Response, status +from google.api_core import exceptions +from google.cloud import storage # type: ignore +from loguru import logger +from prometheus_fastapi_instrumentator import Instrumentator + from .config import Settings from .models import BarsSummary, SummaryDate -from loguru import logger def bars_query(*, bucket: str, start_date: date, end_date: date) -> str: @@ -37,7 +39,7 @@ def bars_query(*, bucket: str, start_date: date, end_date: date) -> str: @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI) -> AsyncGenerator[None]: app.state.settings = Settings() app.state.bucket = storage.Client(os.getenv("GCP_PROJECT")).bucket( app.state.settings.gcp.bucket.name @@ -67,7 +69,7 @@ async def lifespan(app: FastAPI): @application.get("/health") -async def health_check(): +async def health_check() -> Response: return Response(status_code=status.HTTP_200_OK) @@ -151,7 +153,7 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars @application.delete("/equity-bars") -async def delete_equity_bars(request: Request, summary_date: SummaryDate): +async def delete_equity_bars(request: Request, summary_date: SummaryDate) -> Response: bucket = request.app.state.bucket year = summary_date.date.year month = summary_date.date.month diff --git a/application/datamanager/src/datamanager/models.py b/application/datamanager/src/datamanager/models.py index 6649f2ed2..efd255525 100644 --- a/application/datamanager/src/datamanager/models.py +++ b/application/datamanager/src/datamanager/models.py @@ -1,5 +1,6 @@ import datetime from pydantic import BaseModel, Field, field_validator +from pydantic_core import core_schema class SummaryDate(BaseModel): @@ -8,7 +9,7 @@ class SummaryDate(BaseModel): ) @field_validator("date", mode="before") - def parse_date(cls, value): + def parse_date(cls, value: datetime.date | str) -> datetime.date: if isinstance(value, datetime.date): return value for fmt in ("%Y-%m-%d", "%Y/%m/%d"): @@ -27,7 +28,9 @@ class DateRange(BaseModel): @field_validator("end") @classmethod - def check_end_after_start(cls, end_value, info): + def check_end_after_start( + cls, end_value: datetime.datetime, info: core_schema.ValidationInfo + ) -> datetime.datetime: start_value = info.data.get("start") if start_value and end_value <= start_value: raise ValueError("End date must be after start date.") diff --git a/application/positionmanager/src/positionmanager/clients.py b/application/positionmanager/src/positionmanager/clients.py index babf50b21..196c32902 100644 --- a/application/positionmanager/src/positionmanager/clients.py +++ b/application/positionmanager/src/positionmanager/clients.py @@ -54,7 +54,7 @@ def clear_positions(self) -> Dict[str, Any]: class DataClient: - def __init__(self, datamanager_base_url: str | None): + def __init__(self, datamanager_base_url: str | None) -> None: self.datamanager_base_url = datamanager_base_url def get_data( diff --git a/application/positionmanager/src/positionmanager/main.py b/application/positionmanager/src/positionmanager/main.py index af0a7ccc5..990656e49 100644 --- a/application/positionmanager/src/positionmanager/main.py +++ b/application/positionmanager/src/positionmanager/main.py @@ -18,7 +18,7 @@ @application.get("/health") -def get_health(): +def get_health() -> dict[str, str]: return {"status": "healthy"} diff --git a/application/positionmanager/src/positionmanager/models.py b/application/positionmanager/src/positionmanager/models.py index cf806ae6e..5bfe85584 100644 --- a/application/positionmanager/src/positionmanager/models.py +++ b/application/positionmanager/src/positionmanager/models.py @@ -1,7 +1,9 @@ -from pydantic import BaseModel, Field, field_validator -from typing import Dict, Any -from decimal import Decimal, ROUND_HALF_UP from datetime import datetime +from decimal import ROUND_HALF_UP, Decimal +from typing import Any, Dict + +from pydantic import BaseModel, Field, field_validator +from pydantic_core import core_schema class Money(BaseModel): @@ -10,7 +12,7 @@ class Money(BaseModel): ) @field_validator("amount", check_fields=True) - def validate_amount(cls, v): + def validate_amount(cls, v: str | Decimal) -> Decimal: if not isinstance(v, Decimal): v = Decimal(str(v)) @@ -39,7 +41,9 @@ class DateRange(BaseModel): @field_validator("end") @classmethod - def check_end_after_start(cls, end_value, info): + def check_end_after_start( + cls, end_value: datetime, info: core_schema.ValidationInfo + ) -> datetime: start_value = info.data.get("start") if start_value and end_value <= start_value: raise ValueError("End date must be after start date.") diff --git a/application/positionmanager/src/positionmanager/portfolio.py b/application/positionmanager/src/positionmanager/portfolio.py index 9a8de058f..6903f2e87 100644 --- a/application/positionmanager/src/positionmanager/portfolio.py +++ b/application/positionmanager/src/positionmanager/portfolio.py @@ -12,7 +12,7 @@ def __init__( self, minimum_portfolio_tickers: int = 5, maximum_portfolio_tickers: int = 20, - ): + ) -> None: self.minimum_portfolio_tickers = minimum_portfolio_tickers self.maximum_portfolio_tickers = maximum_portfolio_tickers diff --git a/application/positionmanager/tests/test_positionmanager_main.py b/application/positionmanager/tests/test_positionmanager_main.py index 68e9f86bf..1db0b3715 100644 --- a/application/positionmanager/tests/test_positionmanager_main.py +++ b/application/positionmanager/tests/test_positionmanager_main.py @@ -14,7 +14,7 @@ client = TestClient(application) -def test_health_check(): +def test_health_check() -> None: response = client.get("/health") assert response.status_code == 200 assert response.json() == {"status": "healthy"} @@ -25,8 +25,11 @@ class TestPositionsEndpoint(unittest.TestCase): @patch("application.positionmanager.src.positionmanager.main.DataClient") @patch("application.positionmanager.src.positionmanager.main.PortfolioOptimizer") def test_create_position_success( - self, MockPortfolioOptimizer, MockDataClient, MockAlpacaClient - ): + self, + MockPortfolioOptimizer: MagicMock, + MockDataClient: MagicMock, + MockAlpacaClient: MagicMock, + ) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) mock_cash_balance = Money(amount=Decimal("100000.00")) mock_alpaca_instance.get_cash_balance.return_value = mock_cash_balance @@ -76,7 +79,7 @@ def test_create_position_success( assert mock_alpaca_instance.place_notional_order.call_count == 2 @patch("application.positionmanager.src.positionmanager.main.AlpacaClient") - def test_create_position_alpaca_error(self, MockAlpacaClient): + def test_create_position_alpaca_error(self, MockAlpacaClient: MagicMock) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) mock_alpaca_instance.get_cash_balance.side_effect = Exception("API error") MockAlpacaClient.return_value = mock_alpaca_instance @@ -91,7 +94,7 @@ def test_create_position_alpaca_error(self, MockAlpacaClient): mock_alpaca_instance.get_cash_balance.assert_called_once() @patch("application.positionmanager.src.positionmanager.main.AlpacaClient") - def test_delete_positions_success(self, MockAlpacaClient): + def test_delete_positions_success(self, MockAlpacaClient: MagicMock) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) mock_alpaca_instance.clear_positions.return_value = { "status": "success", @@ -113,7 +116,10 @@ def test_delete_positions_success(self, MockAlpacaClient): mock_alpaca_instance.get_cash_balance.assert_called_once() @patch("application.positionmanager.src.positionmanager.main.AlpacaClient") - def test_delete_positions_error(self, MockAlpacaClient): + def test_delete_positions_error( + self, + MockAlpacaClient: MagicMock, + ) -> None: mock_alpaca_instance = MagicMock(spec=AlpacaClient) mock_alpaca_instance.clear_positions.side_effect = Exception("API error") MockAlpacaClient.return_value = mock_alpaca_instance diff --git a/pyproject.toml b/pyproject.toml index 69f7605be..33c091010 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "httpx>=0.28.1", "pulumi-docker-build>=0.0.12", "pulumi-gcp>=8.32.0", + "requests>=2.32.3", ] [tool.uv.workspace] @@ -65,5 +66,16 @@ skip_covered = true [tool.coverage.xml] output = ".python_coverage.xml" +[tool.ruff] +select = [ + "ANN" +] + [tool.ty.rules] unresolved-import = "ignore" +invalid-return-type = "error" +invalid-argument-type = "error" +unresolved-reference = "error" + +[tool.pyright] +reportMissingImports = "none" diff --git a/uv.lock b/uv.lock index 21e14097a..414684ecf 100644 --- a/uv.lock +++ b/uv.lock @@ -1475,6 +1475,7 @@ dependencies = [ { name = "httpx" }, { name = "pulumi-docker-build" }, { name = "pulumi-gcp" }, + { name = "requests" }, ] [package.dev-dependencies] @@ -1500,6 +1501,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.1" }, { name = "pulumi-docker-build", specifier = ">=0.0.12" }, { name = "pulumi-gcp", specifier = ">=8.32.0" }, + { name = "requests", specifier = ">=2.32.3" }, ] [package.metadata.requires-dev] From 25fa6b040466e1573b91f2caa208fc02625d8e4e Mon Sep 17 00:00:00 2001 From: chrisaddy Date: Wed, 28 May 2025 15:30:33 -0400 Subject: [PATCH 2/4] add ANN type annotations and fix for ruff --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 33c091010..ac952198a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ output = ".python_coverage.xml" [tool.ruff] select = [ "ANN" + "ERA" ] [tool.ty.rules] @@ -76,6 +77,3 @@ unresolved-import = "ignore" invalid-return-type = "error" invalid-argument-type = "error" unresolved-reference = "error" - -[tool.pyright] -reportMissingImports = "none" From c72de2cea8398aef41e4534ea24cdda30ff2d4ac Mon Sep 17 00:00:00 2001 From: chrisaddy Date: Wed, 28 May 2025 15:45:57 -0400 Subject: [PATCH 3/4] ruff fixes From 5375d15a14087427d5b0492b283fb997e313220f Mon Sep 17 00:00:00 2001 From: chrisaddy Date: Wed, 28 May 2025 15:33:27 -0400 Subject: [PATCH 4/4] add fastapi linting fix bandit issues fixing ruff issues --- .../datamanager/features/steps/equity_bars_steps.py | 5 +++-- .../datamanager/features/steps/health_steps.py | 2 +- application/datamanager/src/datamanager/main.py | 6 +++--- pyproject.toml | 13 ++++++++++--- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/application/datamanager/features/steps/equity_bars_steps.py b/application/datamanager/features/steps/equity_bars_steps.py index 2fff7bc33..55b4f8acb 100644 --- a/application/datamanager/features/steps/equity_bars_steps.py +++ b/application/datamanager/features/steps/equity_bars_steps.py @@ -24,7 +24,7 @@ def step_impl_api_url(context: Context) -> None: @when('I send a POST request to "{endpoint}" for date range') def step_impl_post_request(context: Context, endpoint: str) -> None: url = f"{context.api_url}{endpoint}" - response = requests.post(url, json={"date": context.start_date}) + response = requests.post(url, json={"date": context.start_date}, timeout=30) context.response = response @@ -34,6 +34,7 @@ def step_imp_get_request(context: Context, endpoint: str) -> None: response = requests.get( url, params={"start_date": context.start_date, "end_date": context.end_date}, + timeout=30, ) context.response = response @@ -48,7 +49,7 @@ def step_impl_response_status_code(context: Context, status_code: str) -> None: @when('I send a DELETE request to "{endpoint}" for date "{date_str}"') def step_impl(context: Context, endpoint: str, date_str: str) -> None: url = f"{context.api_url}{endpoint}" - response = requests.delete(url, json={"date": date_str}) + response = requests.delete(url, json={"date": date_str}, timeout=30) context.response = response context.test_date = date_str diff --git a/application/datamanager/features/steps/health_steps.py b/application/datamanager/features/steps/health_steps.py index 1f1660343..446b07522 100644 --- a/application/datamanager/features/steps/health_steps.py +++ b/application/datamanager/features/steps/health_steps.py @@ -7,4 +7,4 @@ @when('I send a GET request to "{endpoint}"') def step_impl(context: Context, endpoint: str) -> None: url = f"{context.api_url}{endpoint}" - context.response = requests.get(url) + context.response = requests.get(url, timeout=30) diff --git a/application/datamanager/src/datamanager/main.py b/application/datamanager/src/datamanager/main.py index f0694aedd..e6754b861 100644 --- a/application/datamanager/src/datamanager/main.py +++ b/application/datamanager/src/datamanager/main.py @@ -21,7 +21,7 @@ def bars_query(*, bucket: str, start_date: date, end_date: date) -> str: path_pattern = f"gs://{bucket}/equity/bars/*/*/*/*" - return f""" + return f""" SELECT * FROM read_parquet( '{path_pattern}', @@ -35,7 +35,7 @@ def bars_query(*, bucket: str, start_date: date, end_date: date) -> str: (year < {end_date.year} OR (year = {end_date.year} AND month < {end_date.month}) OR (year = {end_date.year} AND month = {end_date.month} AND day <= {end_date.day})) - """ + """ # noqa: S608 @asynccontextmanager @@ -111,7 +111,7 @@ async def get_equity_bars( return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) -@application.post("/equity-bars", response_model=BarsSummary) +@application.post("/equity-bars") async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> BarsSummary: polygon = request.app.state.settings.polygon bucket = request.app.state.settings.gcp.bucket diff --git a/pyproject.toml b/pyproject.toml index ac952198a..467cc3050 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,11 +66,18 @@ skip_covered = true [tool.coverage.xml] output = ".python_coverage.xml" -[tool.ruff] +[tool.ruff.lint] select = [ - "ANN" - "ERA" + "ANN", # type annotations + "ASYNC", + "ERA", # dead code + "FAST", # fastapi + "S", # bandit (security) + "YTT" # flake8 ] +[tool.ruff.lint.per-file-ignores] +"**/tests/**/*.py" = ["S101"] +"**/features/steps/**/*.py" = ["S101"] [tool.ty.rules] unresolved-import = "ignore"