Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions application/datamanager/src/datamanager/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def bars_query(*, bucket: str, start_date: date, end_date: date) -> str:
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
app.state.settings = Settings()
app.state.bucket = storage.Client(os.getenv("GCP_PROJECT")).bucket(
app.state.settings.gcp.bucket.name
app.state.settings.gcp.bucket.name,
)

DUCKDB_ACCESS_KEY = os.getenv("DUCKDB_ACCESS_KEY")
Expand Down Expand Up @@ -81,12 +81,16 @@ async def health_check() -> Response:

@application.get("/equity-bars")
async def get_equity_bars(
request: Request, start_date: date, end_date: date
request: Request,
start_date: date,
end_date: date,
) -> Response:
settings: Settings = request.app.state.settings

query = bars_query(
bucket=settings.gcp.bucket.name, start_date=start_date, end_date=end_date
bucket=settings.gcp.bucket.name,
start_date=start_date,
end_date=end_date,
)

try:
Expand Down Expand Up @@ -150,9 +154,10 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars
pl.from_epoch("t", time_unit="ms").dt.year().alias("year"),
pl.from_epoch("t", time_unit="ms").dt.month().alias("month"),
pl.from_epoch("t", time_unit="ms").dt.day().alias("day"),
]
],
).write_parquet(
bucket.daily_bars_path, partition_by=["year", "month", "day"]
bucket.daily_bars_path,
partition_by=["year", "month", "day"],
)
except (
requests.RequestException,
Expand All @@ -166,7 +171,7 @@ async def fetch_equity_bars(request: Request, summary_date: SummaryDate) -> Bars
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to write data",
)
) from e
return BarsSummary(date=summary_date.date.strftime("%Y-%m-%d"), count=count)


Expand Down
12 changes: 9 additions & 3 deletions application/datamanager/src/datamanager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class SummaryDate(BaseModel):
date: datetime.date = Field(
default_factory=lambda: datetime.datetime.utcnow().date()
default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc).date(),
)

@field_validator("date", mode="before")
Expand All @@ -14,7 +14,11 @@ def parse_date(cls, value: datetime.date | str) -> datetime.date:
return value
for fmt in ("%Y-%m-%d", "%Y/%m/%d"):
try:
return datetime.datetime.strptime(value, fmt).date()
return (
datetime.datetime.strptime(value, fmt)
.replace(tzinfo=datetime.timezone.utc)
.date()
)
except ValueError:
continue
raise ValueError("Invalid date format: expected YYYY-MM-DD or YYYY/MM/DD")
Expand All @@ -29,7 +33,9 @@ class DateRange(BaseModel):
@field_validator("end")
@classmethod
def check_end_after_start(
cls, end_value: datetime.datetime, info: core_schema.ValidationInfo
cls,
end_value: datetime.datetime,
info: core_schema.ValidationInfo,
) -> datetime.datetime:
start_value = info.data.get("start")
if start_value and end_value <= start_value:
Expand Down
2 changes: 1 addition & 1 deletion application/positionmanager/src/positionmanager/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_data(
pl.col("timestamp")
.str.slice(0, 10)
.str.strptime(pl.Date, "%Y-%m-%d")
.alias("date")
.alias("date"),
)

data = (
Expand Down
16 changes: 8 additions & 8 deletions application/positionmanager/src/positionmanager/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import FastAPI, HTTPException
import requests
import os
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
import polars as pl
from typing import Dict, Any

Expand Down Expand Up @@ -51,8 +51,8 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
) from e

date_range = DateRange(
start=datetime.now() - timedelta(days=trading_days_per_year),
end=datetime.now(),
start=datetime.now(tz=timezone.utc) - timedelta(days=trading_days_per_year),
end=datetime.now(tz=timezone.utc),
)

try:
Expand Down Expand Up @@ -83,21 +83,21 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
continue

latest_prices = historical_data.filter(pl.col(ticker).is_not_null()).select(
ticker
ticker,
)
if latest_prices.is_empty():
executed_trades.append(
{
"ticker": ticker,
"status": "error",
"error": "No recent price available",
}
},
)
continue
latest_price = latest_prices.tail(1)[0, 0]

notional_amount = Money.from_float(
latest_price * share_count * 0.95
latest_price * share_count * 0.95,
) # 5% buffer

try:
Expand All @@ -109,7 +109,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
"share_count": share_count,
"notional_amount": float(notional_amount),
"status": "success",
}
},
)

except (requests.RequestException, APIError, ValidationError) as e:
Expand All @@ -120,7 +120,7 @@ def create_position(payload: PredictionPayload) -> Dict[str, Any]:
"notional_amount": float(notional_amount),
"status": "error",
"error": str(e),
}
},
)

final_cash_balance = alpaca_client.get_cash_balance()
Expand Down
7 changes: 5 additions & 2 deletions application/positionmanager/src/positionmanager/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

class Money(BaseModel):
amount: Decimal = Field(
..., description="Monetary value in USD with 2 decimal places"
...,
description="Monetary value in USD with 2 decimal places",
)

@field_validator("amount", check_fields=True)
Expand Down Expand Up @@ -42,7 +43,9 @@ class DateRange(BaseModel):
@field_validator("end")
@classmethod
def check_end_after_start(
cls, end_value: datetime, info: core_schema.ValidationInfo
cls,
end_value: datetime,
info: core_schema.ValidationInfo,
) -> datetime:
start_value = info.data.get("start")
if start_value and end_value <= start_value:
Expand Down
4 changes: 3 additions & 1 deletion application/positionmanager/src/positionmanager/portfolio.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def get_optimized_portfolio(

long_only_weight_bounds = (0, 0.2) # 20% max weight per asset
efficient_frontier = EfficientFrontier(
mu, S, weight_bounds=long_only_weight_bounds
mu,
S,
weight_bounds=long_only_weight_bounds,
)

efficient_frontier.max_sharpe(risk_free_rate=0.02) # 2% risk-free rate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_create_position_success(

mock_data_instance = MagicMock(spec=DataClient)
mock_historical_data = pl.DataFrame(
{"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]}
{"date": ["2025-05-01"], "AAPL": [150.00], "MSFT": [250.00]},
)
mock_data_instance.get_data.return_value = mock_historical_data
MockDataClient.return_value = mock_data_instance
Expand Down Expand Up @@ -84,7 +84,8 @@ def test_create_position_success(
def test_create_position_alpaca_error(self, MockAlpacaClient: MagicMock) -> None:
mock_alpaca_instance = MagicMock(spec=AlpacaClient)
mock_alpaca_instance.get_cash_balance.side_effect = HTTPException(
status_code=500, detail="Error getting cash balance"
status_code=500,
detail="Error getting cash balance",
)
MockAlpacaClient.return_value = mock_alpaca_instance

Expand Down Expand Up @@ -126,7 +127,8 @@ def test_delete_positions_error(
) -> None:
mock_alpaca_instance = MagicMock(spec=AlpacaClient)
mock_alpaca_instance.clear_positions.side_effect = HTTPException(
status_code=500, detail="Error getting cash balance"
status_code=500,
detail="Error getting cash balance",
)
MockAlpacaClient.return_value = mock_alpaca_instance

Expand Down
7 changes: 4 additions & 3 deletions infrastructure/cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@
value=buckets.production_data_bucket.name,
),
cloudrun.ServiceTemplateSpecContainerEnvArgs(
name="DUCKDB_ACCESS_KEY", value=duckdb_access_key
name="DUCKDB_ACCESS_KEY",
value=duckdb_access_key,
),
cloudrun.ServiceTemplateSpecContainerEnvArgs(
name="DUCKDB_SECRET",
value=duckdb_secret,
),
],
)
),
],
),
),
Expand All @@ -59,7 +60,7 @@
push_config=pubsub.SubscriptionPushConfigArgs(
push_endpoint=service.statuses[0].url,
oidc_token=pubsub.SubscriptionPushConfigOidcTokenArgs(
service_account_email=project.platform_service_account.email
service_account_email=project.platform_service_account.email,
),
),
)
Expand Down
9 changes: 6 additions & 3 deletions infrastructure/images.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from datetime import datetime
from datetime import datetime, timezone
from glob import glob
import pulumi
import pulumi_docker_build as docker_build
Expand All @@ -12,7 +12,10 @@
dockerfile_paths = glob(os.path.join("..", "application", "*", "Dockerfile"))
dockerfile_paths = [os.path.relpath(dockerfile) for dockerfile in dockerfile_paths]

tags = ["latest", datetime.utcnow().strftime("%Y%m%d")]
tags = [
"latest",
datetime.now(tz=timezone.utc).strftime("%Y%m%d"),
]

images = {}
for dockerfile in dockerfile_paths:
Expand All @@ -36,7 +39,7 @@
address="docker.io",
username=dockerhub_username,
password=dockerhub_password,
)
),
],
)

Expand Down
22 changes: 11 additions & 11 deletions infrastructure/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
cloudrun.ServiceTemplateSpecContainerVolumeMountArgs(
name="prometheus-config",
mount_path="/etc/prometheus",
)
),
],
ports=[
cloudrun.ServiceTemplateSpecContainerPortArgs(
container_port=9090
)
container_port=9090,
),
],
)
),
],
volumes=[
cloudrun.ServiceTemplateSpecVolumeArgs(
Expand All @@ -41,12 +41,12 @@
cloudrun.ServiceTemplateSpecVolumeSecretItemArgs(
path="prometheus.yaml",
version=prometheus_config_version.version,
)
),
],
),
)
),
],
)
),
),
)

Expand All @@ -61,11 +61,11 @@
image="grafana/grafana:latest",
ports=[
cloudrun.ServiceTemplateSpecContainerPortArgs(
container_port=3000
)
container_port=3000,
),
],
)
),
],
)
),
),
)
2 changes: 1 addition & 1 deletion infrastructure/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@
project=PROJECT,
role="roles/pubsub.subscriber",
member=platform_service_account.email.apply(
lambda e: f"serviceAccount:{e}"
lambda e: f"serviceAccount:{e}",
), # ty: ignore[missing-argument]
)
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,24 @@ output = ".python_coverage.xml"

[tool.ruff.lint]
select = [
"A", # flake8 builtins
"ANN", # type annotations
"ASYNC",
"B", # bugbear
"COM", # commas
"C4", # comprehensions
"BLE", # no blind exceptions
"DTZ", # datetimes
"ERA", # dead code
"FAST", # fastapi
"FBT", # boolean traps
"S", # bandit (security)
"YTT" # flake8
]
ignore = [
"COM812",
]

[tool.ruff.lint.per-file-ignores]
"**/tests/**/*.py" = ["S101"]
"**/features/steps/**/*.py" = ["S101"]
Expand Down