Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Caution: Review failed. The pull request is closed.

Walkthrough

Monorepo reorganization: moved service sources to

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Operator as Operator
    participant FlyteWF as train_tft_model (Flyte)
    participant Reader as read_local_data
    participant Trainer as train_model
    participant Validator as validate_model
    participant Saver as save_model
    Operator->>FlyteWF: trigger train_tft_model()
    FlyteWF->>Reader: read_local_data("training_data.csv")
    Reader-->>FlyteWF: TemporalFusionTransformerDataset
    FlyteWF->>Trainer: train_model(dataset, wandb_run)
    Trainer-->>FlyteWF: trained model
    FlyteWF->>Validator: validate_model(dataset, model)
    Validator-->>FlyteWF: validation metrics
    FlyteWF->>Saver: save_model(model)
    Saver-->>FlyteWF: artifact path
    FlyteWF-->>Operator: workflow complete (artifact path)
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs
Pull Request Overview
This PR implements a significant restructuring of the Python codebase, migrating from a distributed application architecture to a more consolidated library-based approach centered around machine learning model development.
- Consolidates all service packages into a single `applications/models` package focused on TFT (Temporal Fusion Transformer) model development
- Creates a shared `internal` library package containing neural network components, data structures, and utility functions
- Removes the existing prediction engine, position manager, and data manager services along with their dependencies
Reviewed Changes
Copilot reviewed 77 out of 82 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| pyproject.toml | Updates workspace configuration to point to new applications and libraries structure |
| libraries/python/ | Creates new shared internal library with neural network components, data models, and utilities |
| applications/models/ | New consolidated package for model training workflows and data fetching scripts |
| applications/datamanager/, applications/portfoliomanager/ | New simplified service packages depending on internal library |
| workflows/, cli/, application/ | Removes old workflow definitions, CLI tools, and distributed service packages |
Actionable comments posted: 21
♻️ Duplicate comments (1)
libraries/python/src/internal/tft_model.py (1)
**237-241:** Same shape mismatch issue in `validate` method

The `validate` method has the same shape mismatch issue as the `train` method when calling `quantile_loss`.
Apply the same fix as suggested for the train method.
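For readers without the diff at hand, here is a minimal, framework-free sketch of the pinball (quantile) loss and the shape contract at issue. The `(batch, time)` targets versus `(batch, time, quantile)` predictions layout is an assumption inferred from the review comments, not code from the repo.

```python
def quantile_loss(predictions, targets, quantiles=(0.1, 0.5, 0.9)):
    """Pinball loss averaged over batch, time, and quantiles.

    predictions: [batch][time][quantile] floats
    targets:     [batch][time] floats -- one fewer axis than predictions,
                 which is exactly the kind of shape mismatch flagged above.
    """
    total, count = 0.0, 0
    for b, target_seq in enumerate(targets):
        for t, target in enumerate(target_seq):
            for qi, q in enumerate(quantiles):
                error = target - predictions[b][t][qi]
                # Under-prediction is weighted by q, over-prediction by (1 - q)
                total += max(q * error, (q - 1) * error)
                count += 1
    return total / count

# One batch element, two timesteps, all three quantiles predicting 1.0
preds = [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
targets = [[2.0, 0.0]]
print(quantile_loss(preds, targets))  # 0.5
```

The asymmetric weighting is what makes each output column track a different quantile of the target distribution.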
🧹 Nitpick comments (27)
.gitignore (1)
**19-21:** Avoid globally ignoring all CSV files

A blanket `*.csv` ignore can hide legitimate source/test fixtures, especially with the new internal dataset utilities. Prefer scoping ignores to known artifact locations or remove this entry.

Apply one of the following:

Option A — remove the global ignore:

```diff
-*.csv
```

Option B — keep but add an explicit note and scope later (preferred if unsure):

```diff
-*.csv
+# Avoid globally ignoring CSVs; scope to artifact directories as needed, e.g.:
+# data/**/*.csv
```

Confirm whether any CSVs are expected to be committed (e.g., small sample datasets for tests). If yes, drop the global ignore to prevent accidental omissions.
.claude/settings.local.json (1)
**14-14:** Pytest allowlist broadened significantly; consider narrowing scope

Allowing any pytest invocation increases the blast radius. If the intent is to run tests across the repo, at least scope to your test roots to reduce risk.

Replace with a narrower pattern:

```diff
-    "Bash(python -m pytest:*)",
+    "Bash(python -m pytest libraries/python/tests:*)",
+    "Bash(python -m pytest applications/*/tests:*)",
```

Confirm whether other ad-hoc pytest invocations are required (e.g., targeting subdirs or markers). If so, we can add explicit patterns for those cases rather than a global wildcard.
applications/portfoliomanager/pyproject.toml (1)
**1-7:** Relax Python version constraint

Pinning to an exact patch version (`==3.12.10`) is brittle and will break on minor patch upgrades in CI images. Use a compatible range.

```diff
-requires-python = "==3.12.10"
+requires-python = ">=3.12,<3.13"
```

Ensure other workspace projects use consistent Python constraints to avoid environment resolution conflicts.
.github/workflows/launch_infrastructure.yaml (1)
**2-8:** Cron runs in UTC; update the comment to avoid EST/EDT confusion

GitHub Actions schedules are UTC. The current comment says "8:00 AM EST," which is only accurate outside DST; during EDT this will run at 9:00 AM local.

Update the comment for clarity:

```diff
-name: Launch infrastructure
 @@
-  - cron: '0 13 * * 1,2,3,4,5' # launch at 8:00 AM EST
+  - cron: '0 13 * * 1,2,3,4,5' # 13:00 UTC (08:00 EST / 09:00 EDT)
```

If you need a fixed local-time launch regardless of DST, consider moving the trigger to an external scheduler or adjusting cron seasonally.
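The UTC-to-local mapping above can be double-checked with the standard library alone. A quick sketch using `zoneinfo` (available since Python 3.9); the two dates are arbitrary weekdays chosen to fall in standard time and daylight time respectively:

```python
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

eastern = ZoneInfo("America/New_York")

# 13:00 UTC on an EST (standard time) Monday and an EDT (daylight time) Monday
winter = datetime(2025, 1, 6, 13, 0, tzinfo=timezone.utc).astimezone(eastern)
summer = datetime(2025, 7, 7, 13, 0, tzinfo=timezone.utc).astimezone(eastern)

print(winter.strftime("%H:%M %Z"))  # 08:00 EST
print(summer.strftime("%H:%M %Z"))  # 09:00 EDT
```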
applications/models/src/models/get_alpaca_calendar.py (1)
**43-43:** Consider making the output path configurable

The hardcoded output path "calendar.csv" could cause issues in different environments or when running multiple instances.

```diff
+    output_path = os.getenv("CALENDAR_OUTPUT_PATH", "calendar.csv")
-    calendar_content.write_csv("calendar.csv")
+    calendar_content.write_csv(output_path)
```

And update the log message on line 45:

```diff
-    logger.info("Calendar data has been written to calendar.csv")
+    logger.info(f"Calendar data has been written to {output_path}")
```

infrastructure/environment_variables.py (1)
**8-14:** Consider a consistent naming convention for environment variables

There's an inconsistency in the naming pattern: most variables use underscores (e.g., `ALPACA_API_KEY_ID`), while the AWS region uses a colon separator (`aws:region`). Consider using a consistent naming pattern. If `aws:region` is a Pulumi-specific convention, document it with a comment:

```diff
+# Note: aws:region follows Pulumi's config naming convention
 aws_region = configuration.get("aws:region") or "us-east-1"
```

applications/models/src/models/get_alpaca_equity_bars.py (2)
**95-100:** Optimize timestamp conversion using vectorized operations

The current implementation uses `map_elements` with a lambda function, which is inefficient for large datasets. Use Polars' built-in datetime operations instead.

```diff
 equity_bars_data = equity_bars_data.with_columns(
-    (
-        pl.col("timestamp").map_elements(
-            lambda x: int(datetime.fromisoformat(x).timestamp() * 1000)
-        )
-    ).alias("timestamp")
+    (pl.col("timestamp").str.to_datetime().dt.epoch(time_unit="ms")).alias("timestamp")
 )
```
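For context, the scalar conversion the lambda performs is ISO 8601 string → Unix epoch milliseconds; the vectorized expression computes the same value per row without a Python-level callback. A pure-Python sketch of the scalar path (the sample timestamp is illustrative and carries an explicit UTC offset — offset-free strings would be interpreted in local time, one more reason to prefer an explicit vectorized parse):

```python
from datetime import datetime

def to_epoch_ms(iso_timestamp: str) -> int:
    # Same arithmetic as the reviewed lambda: parse ISO 8601, then
    # convert the POSIX timestamp (seconds) to integer milliseconds.
    return int(datetime.fromisoformat(iso_timestamp).timestamp() * 1000)

print(to_epoch_ms("2023-03-15T00:00:00+00:00"))  # 1678838400000
```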
**112-115:** Consider memory-efficient concatenation for large datasets

Loading all CSV files into memory simultaneously could cause memory issues with many tickers. Consider using lazy evaluation or streaming.

```diff
 if saved_files:
-    all_bars = pl.concat([pl.read_csv(fp) for fp in saved_files])
+    # Use lazy evaluation for memory efficiency
+    all_bars = pl.concat([pl.scan_csv(fp) for fp in saved_files]).collect()
     all_bars.write_csv("equity_bars_combined.csv")
     logger.info("Finished saving combined equity bars.")
```

libraries/python/src/internal/lstm_network.py (1)
**13-16:** Consider making the minimum layer count configurable or removing the restriction

The hardcoded requirement for at least 3 layers seems arbitrary and could limit flexibility. If there's a specific architectural reason for this constraint, it should be documented.

Either document the reason for the constraint:

```diff
 minimum_layer_count = 3
 if layer_count < minimum_layer_count:
+    # TFT architecture requires at least 3 LSTM layers for proper feature extraction
     message = f"Layer count must be at least {minimum_layer_count}"
     raise ValueError(message)
```

Or consider removing the restriction entirely if it's not architecturally necessary:

```diff
-minimum_layer_count = 3
-if layer_count < minimum_layer_count:
-    message = f"Layer count must be at least {minimum_layer_count}"
-    raise ValueError(message)
+if layer_count < 1:
+    message = "Layer count must be at least 1"
+    raise ValueError(message)
```
**12-17:** Consider adding a test that asserts the new minimum `layer_count` constraint

All test instantiations were updated to use `layer_count=3`, which matches the new class requirement (minimum 3). To prevent regressions, add one negative test that verifies `layer_count < 3` raises a ValueError with the expected message.

Example pytest-style test (add near the other tests):

```python
import pytest

def test_lstm_min_layer_count_enforced() -> None:
    with pytest.raises(ValueError, match="Layer count must be at least 3"):
        LSTM(input_size=8, hidden_size=16, layer_count=1, dropout_rate=0.0)
```

If you prefer unittest style, I can provide that version too.
Also applies to: 32-37, 50-55, 65-70, 80-85, 94-99
**42-46:** Nit: clarify the intent of `expected_hidden_state`

The name suggests a tensor, but you're asserting the tuple length (`(hidden_state, cell_state)` → 2). Consider a more explicit name like `expected_state_tuple_len` for readability.

```diff
-    expected_hidden_state = 2
+    expected_state_tuple_len = 2
     ...
-    assert len(hidden_state) == expected_hidden_state
+    assert len(hidden_state) == expected_state_tuple_len
```

libraries/python/src/internal/summaries.py (1)
**1-6:** Optional: use a strongly-typed date instead of `str`

If you want validation and type-safety, consider using `datetime.date`. Pydantic will parse ISO 8601 strings and serialize back to ISO by default, preserving the behavior while preventing invalid dates.

```diff
+from datetime import date
 from pydantic import BaseModel

 class BarsSummary(BaseModel):
-    date: str
+    date: date
     count: int
```

Alternatively, keep `str` but enforce a pattern:

```python
from pydantic import BaseModel, Field

class BarsSummary(BaseModel):
    date: str = Field(pattern=r"^\d{4}-\d{2}-\d{2}$")
    count: int
```

applications/datamanager/pyproject.toml (1)
**5-5:** Optional: relax the Python requirement to a minor version range

Pinning to an exact patch (`==3.12.10`) can cause unnecessary churn across machines/CI. Consider a minor-range pin unless you rely on a patch-specific fix.

```diff
-requires-python = "==3.12.10"
+requires-python = ">=3.12,<3.13"
```

If you need the exact patch for reproducibility, feel free to keep as-is.
applications/models/pyproject.toml (2)
**6-19:** Remove duplicate "internal" dependency entry

There are two "internal" entries in the dependencies list. Deduplicate to keep things tidy and avoid confusion.

```diff
 dependencies = [
     "internal",
     "boto3>=1.38.23",
     "botocore>=1.38.23",
     "requests>=2.31.0",
     "pyarrow>=20.0.0",
     "polygon-api-client>=1.14.6",
-    "internal",
     "flytekit>=1.16.1",
     "polars>=1.29.0",
     "loguru>=0.7.3",
     "pydantic>=2.8.2",
     "wandb>=0.21.1",
 ]
```
**5-5:** Optional: align Python requirement with a minor-range pin

Unless a specific patch is required, consider a minor-range pin for flexibility across environments.

```diff
-requires-python = "==3.12.10"
+requires-python = ">=3.12,<3.13"
```

libraries/python/src/internal/dates.py (1)
**34-55:** Consider adding validation for DateRange fields

The `DateRange` class lacks field validators similar to those in the `Date` class. If string inputs are expected for the `start` and `end` fields, consider adding `@field_validator` decorators to ensure consistent date parsing behavior.

```diff
 class DateRange(BaseModel):
-    start: datetime.date
-    end: datetime.date
+    start: datetime.date
+    end: datetime.date
+
+    @field_validator("start", "end", mode="before")
+    @classmethod
+    def parse_dates(cls, value: datetime.date | str) -> datetime.date:
+        if isinstance(value, datetime.date):
+            return value
+        for fmt in ("%Y-%m-%d", "%Y/%m/%d"):
+            try:
+                return (
+                    datetime.datetime.strptime(value, fmt)
+                    .replace(tzinfo=ZoneInfo("America/New_York"))
+                    .date()
+                )
+            except ValueError:
+                continue
+        msg = "Invalid date format: expected YYYY-MM-DD or YYYY/MM/DD"
+        raise ValueError(msg)
```
**9-10:** Consider aligning default quantiles with common practice

The default quantiles `[0.1, 0.5, 0.9]` differ from the original implementation, which used `(0.25, 0.5, 0.75)`. Consider whether this change is intentional or if it should match common quantile regression practices.
**19-19:** Consider optimizing tensor creation

Creating a new `Tensor([quantile])` in each loop iteration could be optimized by creating the quantile tensors once outside the loop.

```diff
+    quantile_tensors = [Tensor([q]) for q in quantiles]
     for index, quantile in enumerate(quantiles):
         error = targets.sub(predictions[:, :, index])
-        quantile_tensor = Tensor([quantile])
+        quantile_tensor = quantile_tensors[index]
```

libraries/python/src/internal/cloud_event.py (1)
**15-15:** Consider validation for event metadata

The event metadata is joined without validation. Consider adding validation to ensure the metadata doesn't contain characters that could break the event type format (e.g., empty strings, special characters).

```diff
 def create_cloud_event_success(
     application_name: str,
     event_metadata: list[str],
     data: dict,
 ) -> CloudEvent:
+    if not event_metadata or any(not item.strip() for item in event_metadata):
+        raise ValueError("Event metadata cannot be empty or contain empty strings")
     return CloudEvent(
```

Also applies to: 32-32
libraries/python/tests/test_dates.py (1)
**29-32:** Consider adding a test for slash-format parsing

While this test validates that the Date model accepts a date object created from ISO format, it doesn't actually test the string parsing capability of the Date model. Based on the implementation in `internal/dates.py`, the Date model supports both dash and slash formats.

Consider adding a test that directly passes a string to test the parsing:

```diff
 def test_date_string_dash_format() -> None:
-    date_instance = Date(date=datetime.date.fromisoformat("2023-03-15"))
+    date_instance = Date(date="2023-03-15")

     assert date_instance.date == datetime.date(2023, 3, 15)
+
+
+def test_date_string_slash_format() -> None:
+    date_instance = Date(date="2023/03/15")
+
+    assert date_instance.date == datetime.date(2023, 3, 15)
```

pyproject.toml (1)
**17-32:** Consider adjusting the rootdir configuration

The `--rootdir=/tests` configuration appears unusual, as it sets an absolute path that may not exist in all environments. This could cause issues when running tests locally.

Consider using a relative path or removing this option:

```diff
 addopts = [
     "--verbose",
     "--tb=short",
     "--strict-markers",
     "--color=yes",
-    "--rootdir=/tests",
 ]
```

applications/models/src/models/train_tft_model.py (1)
**141-160:** Remove unnecessary type ignore comments

The type ignore comments suggest there might be type mismatches between Flyte's task return types and the expected parameter types. These should be properly resolved rather than suppressed.
Consider properly typing the Flyte tasks or using proper type annotations to avoid the need for type ignores. If Flyte requires these ignores due to its internal typing, consider adding a comment explaining why they're necessary.
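As a generic, framework-free illustration (all names below are hypothetical stand-ins, not the repo's actual Flyte tasks), precise return annotations usually let a type checker verify each hand-off without suppressions:

```python
from dataclasses import dataclass

@dataclass
class Dataset:
    rows: int

@dataclass
class Model:
    trained_on: int

def read_local_data(path: str) -> Dataset:
    # Hypothetical stand-in for a data-reading task
    return Dataset(rows=100)

def train_model(dataset: Dataset) -> Model:
    # With explicit return annotations on each stage, the checker can
    # verify that one stage's output matches the next stage's parameter,
    # so no `# type: ignore` is needed at the call site.
    return Model(trained_on=dataset.rows)

model = train_model(read_local_data("training_data.csv"))
print(model.trained_on)  # 100
```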
libraries/python/src/internal/tft_model.py (1)
**140-140:** Unnecessary `Tensor` wrapping of already-Tensor objects

Multiple lines unnecessarily wrap Tensor objects with `Tensor()`, which is redundant since the multiplication and addition operations already return Tensors.

Apply this diff to remove the redundant wrapping:

```diff
-        encoder_input = Tensor(encoder_input * encoder_weights)
+        encoder_input = encoder_input * encoder_weights
-        decoder_input = Tensor(decoder_input * decoder_weights)
+        decoder_input = decoder_input * decoder_weights
-        encoder_output = Tensor(encoder_output + encoder_static_context)
+        encoder_output = encoder_output + encoder_static_context
-        decoder_output = Tensor(decoder_output + decoder_static_context)
+        decoder_output = decoder_output + decoder_static_context
         attended_output, _ = self.self_attention.forward(
-            Tensor(sequence + expanded_static_context),
+            sequence + expanded_static_context,
```

Also applies to: 144-144, 156-156, 160-160, 169-169
libraries/python/src/internal/dataset.py (4)
**15-19:** Replace magic number with a named constant for better maintainability

The hard-coded value `1e-8` should be defined as a class constant for better maintainability and clarity.

```diff
 class Scaler:
+    EPSILON = 1e-8  # Small value to avoid division by zero
+
     def __init__(self) -> None:
         pass

     def fit(self, data: pl.DataFrame) -> None:
         self.means = data.mean()
         self.standard_deviations = data.std()
         self.standard_deviations = self.standard_deviations.select(
-            pl.all().replace(0, 1e-8)
+            pl.all().replace(0, self.EPSILON)
         )  # avoid division by zero
```
**104-104:** Define magic number as a named constant

The magic number `4` representing Friday should be defined as a constant for better code readability.

```diff
-    friday_number = 4
+    FRIDAY = 4  # datetime.weekday() value for Friday

     # set is_holiday value for missing weekdays
     data = (
         data.with_columns(
             pl.col("datetime").dt.weekday().alias("temporary_weekday")
         )
         .with_columns(
             pl.when(
                 pl.col("is_holiday").is_null()
-                & (pl.col("temporary_weekday") <= friday_number)
+                & (pl.col("temporary_weekday") <= FRIDAY)
             )
             .then(True)  # noqa: FBT003
             .when(
                 pl.col("is_holiday").is_null()
-                & (pl.col("temporary_weekday") > friday_number)
+                & (pl.col("temporary_weekday") > FRIDAY)
             )
```
**210-212:** Improve the validation split check

The set-membership check only rejects the exact endpoint values 0.0 and 1.0, so out-of-range inputs (e.g., -0.2 or 1.5) pass through even though the error message promises an exclusive (0.0, 1.0) range. Use a range check instead:

```diff
-        if validation_split in {0.0, 1.0}:
-            message = "Validation split must be between 0.0 and 1.0 (exclusive)."
+        if not 0.0 < validation_split < 1.0:
+            message = "Validation split must be between 0.0 and 1.0 (exclusive)."
             raise ValueError(message)
```
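To illustrate the difference with a standalone sketch (not the repo's code): the set check lets out-of-range values through, while the chained comparison rejects them as well as the endpoints.

```python
def is_valid_split_set_check(validation_split: float) -> bool:
    # Original style: only rejects exactly 0.0 and 1.0
    return validation_split not in {0.0, 1.0}

def is_valid_split_range_check(validation_split: float) -> bool:
    # Suggested style: rejects endpoints AND anything outside (0, 1)
    return 0.0 < validation_split < 1.0

print(is_valid_split_set_check(1.5))    # True  (bug: out-of-range value accepted)
print(is_valid_split_range_check(1.5))  # False (correctly rejected)
print(is_valid_split_range_check(0.8))  # True
```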
**292-326:** Consider memory efficiency for large datasets

The current implementation loads all batches into memory at once. For large datasets with many tickers and long sequences, this could cause memory issues.
Consider implementing a generator-based approach for memory efficiency:
```python
def get_batch_generator(
    self,
    data_type: str = "train",
    validation_split: float = 0.8,
    input_length: int = 35,
    output_length: int = 7,
    batch_size: int = 32,
) -> Generator[dict[str, Tensor], None, None]:
    """Yield batches one at a time to reduce memory usage."""
    # ... validation and data preparation code ...
    batch_buffer = []
    for ticker in self.batch_data["ticker"].unique():
        # ... existing ticker processing ...
        for i in range(len(ticker_data) - input_length - output_length + 1):
            # ... create batch dict ...
            batch_buffer.append(batch)
            if len(batch_buffer) >= batch_size:
                # Stack tensors for the batch
                yield self._stack_batch(batch_buffer)
                batch_buffer = []
    # Yield remaining samples
    if batch_buffer:
        yield self._stack_batch(batch_buffer)
```
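The partial sketch above elides the repo-specific pieces; as a standalone illustration of the same buffering pattern on plain lists (names and windowing are illustrative, with `_stack_batch` replaced by yielding the raw window list):

```python
from typing import Iterator

def sliding_window_batches(
    series: list[float],
    input_length: int = 3,
    output_length: int = 1,
    batch_size: int = 2,
) -> Iterator[list[tuple[list[float], list[float]]]]:
    """Yield batches of (input window, output window) pairs one at a time."""
    buffer: list[tuple[list[float], list[float]]] = []
    for i in range(len(series) - input_length - output_length + 1):
        inputs = series[i : i + input_length]
        outputs = series[i + input_length : i + input_length + output_length]
        buffer.append((inputs, outputs))
        if len(buffer) >= batch_size:
            yield buffer  # only one batch is ever held in memory
            buffer = []
    if buffer:  # flush the final partial batch
        yield buffer

batches = list(sliding_window_batches([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
print(len(batches))   # 2
print(batches[0][0])  # ([1.0, 2.0, 3.0], [4.0])
```

Because the generator yields as it goes, a caller that trains on each batch in turn never materializes the full set of windows.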
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (78)
- .claude/settings.local.json (1 hunks)
- .github/workflows/launch_infrastructure.yaml (1 hunks)
- .github/workflows/teardown_infrastructure.yaml (1 hunks)
- .gitignore (2 hunks)
- .mise.toml (0 hunks)
- Dockerfile.tests (1 hunks)
- README.md (0 hunks)
- application/datamanager/.dockerignore (0 hunks)
- application/datamanager/Dockerfile (0 hunks)
- application/datamanager/Dockerfile.test (0 hunks)
- application/datamanager/compose.yaml (0 hunks)
- application/datamanager/features/environment.py (0 hunks)
- application/datamanager/features/equity_bars.feature (0 hunks)
- application/datamanager/features/health.feature (0 hunks)
- application/datamanager/features/steps/equity_bars_steps.py (0 hunks)
- application/datamanager/features/steps/health_steps.py (0 hunks)
- application/datamanager/mise.toml (0 hunks)
- application/datamanager/pyproject.toml (0 hunks)
- application/datamanager/src/datamanager/clients.py (0 hunks)
- application/datamanager/src/datamanager/main.py (0 hunks)
- application/datamanager/tests/test_datamanager_main.py (0 hunks)
- application/datamanager/tests/test_datamanager_models.py (0 hunks)
- application/positionmanager/Dockerfile (0 hunks)
- application/positionmanager/pyproject.toml (0 hunks)
- application/positionmanager/src/positionmanager/__init__.py (0 hunks)
- application/positionmanager/src/positionmanager/clients.py (0 hunks)
- application/positionmanager/src/positionmanager/main.py (0 hunks)
- application/positionmanager/src/positionmanager/portfolio.py (0 hunks)
- application/positionmanager/tests/test_positionmanager_main.py (0 hunks)
- application/predictionengine/Dockerfile (0 hunks)
- application/predictionengine/compose.yaml (0 hunks)
- application/predictionengine/pyproject.toml (0 hunks)
- application/predictionengine/src/predictionengine/dataset.py (0 hunks)
- application/predictionengine/src/predictionengine/gated_residual_network.py (0 hunks)
- application/predictionengine/src/predictionengine/loss_function.py (0 hunks)
- application/predictionengine/src/predictionengine/main.py (0 hunks)
- application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py (0 hunks)
- application/predictionengine/src/predictionengine/multi_head_self_attention.py (0 hunks)
- application/predictionengine/src/predictionengine/post_processor.py (0 hunks)
- application/predictionengine/src/predictionengine/ticker_embedding.py (0 hunks)
- application/predictionengine/tests/test_dataset.py (0 hunks)
- application/predictionengine/tests/test_gated_residual_network.py (0 hunks)
- application/predictionengine/tests/test_post_processor.py (0 hunks)
- application/predictionengine/tests/test_predictionengine_main.py (0 hunks)
- application/predictionengine/tests/test_ticker_embedding.py (0 hunks)
- applications/datamanager/pyproject.toml (1 hunks)
- applications/models/pyproject.toml (1 hunks)
- applications/models/src/models/get_alpaca_calendar.py (1 hunks)
- applications/models/src/models/get_alpaca_equity_bars.py (1 hunks)
- applications/models/src/models/train_tft_model.py (1 hunks)
- applications/portfoliomanager/pyproject.toml (1 hunks)
- cli/datamanager.py (0 hunks)
- cli/pyproject.toml (0 hunks)
- infrastructure/environment_variables.py (1 hunks)
- infrastructure/images.py (1 hunks)
- libraries/python/pyproject.toml (1 hunks)
- libraries/python/src/internal/cloud_event.py (1 hunks)
- libraries/python/src/internal/dataset.py (1 hunks)
- libraries/python/src/internal/dates.py (2 hunks)
- libraries/python/src/internal/equity_bar.py (1 hunks)
- libraries/python/src/internal/loss_functions.py (1 hunks)
- libraries/python/src/internal/lstm_network.py (2 hunks)
- libraries/python/src/internal/mhsa_network.py (1 hunks)
- libraries/python/src/internal/money.py (0 hunks)
- libraries/python/src/internal/summaries.py (1 hunks)
- libraries/python/src/internal/tft_model.py (1 hunks)
- libraries/python/src/internal/variable_selection_network.py (1 hunks)
- libraries/python/tests/test_dataset.py (1 hunks)
- libraries/python/tests/test_dates.py (1 hunks)
- libraries/python/tests/test_equity_bar.py (1 hunks)
- libraries/python/tests/test_loss_functions.py (4 hunks)
- libraries/python/tests/test_lstm_network.py (5 hunks)
- libraries/python/tests/test_mhsa_network.py (5 hunks)
- libraries/python/tests/test_variable_selection_network.py (1 hunks)
- pyproject.toml (1 hunks)
- workflows/fetch_data.py (0 hunks)
- workflows/pyproject.toml (0 hunks)
- workflows/train_predictionengine.py (0 hunks)
💤 Files with no reviewable changes (46)
- application/datamanager/Dockerfile.test
- application/predictionengine/tests/test_ticker_embedding.py
- application/predictionengine/compose.yaml
- application/positionmanager/Dockerfile
- workflows/pyproject.toml
- application/positionmanager/pyproject.toml
- application/predictionengine/src/predictionengine/gated_residual_network.py
- application/datamanager/features/equity_bars.feature
- cli/pyproject.toml
- application/datamanager/src/datamanager/clients.py
- application/predictionengine/tests/test_predictionengine_main.py
- application/datamanager/features/health.feature
- application/datamanager/pyproject.toml
- cli/datamanager.py
- application/datamanager/features/environment.py
- application/datamanager/.dockerignore
- application/positionmanager/tests/test_positionmanager_main.py
- application/positionmanager/src/positionmanager/portfolio.py
- application/predictionengine/src/predictionengine/ticker_embedding.py
- application/positionmanager/src/positionmanager/clients.py
- application/predictionengine/pyproject.toml
- README.md
- application/predictionengine/Dockerfile
- application/predictionengine/src/predictionengine/post_processor.py
- application/datamanager/compose.yaml
- application/datamanager/Dockerfile
- application/positionmanager/src/positionmanager/__init__.py
- application/predictionengine/src/predictionengine/multi_head_self_attention.py
- workflows/fetch_data.py
- application/predictionengine/tests/test_dataset.py
- application/datamanager/tests/test_datamanager_main.py
- application/predictionengine/tests/test_post_processor.py
- application/predictionengine/tests/test_gated_residual_network.py
- application/predictionengine/src/predictionengine/loss_function.py
- application/datamanager/features/steps/health_steps.py
- application/datamanager/mise.toml
- .mise.toml
- workflows/train_predictionengine.py
- application/datamanager/tests/test_datamanager_models.py
- application/datamanager/features/steps/equity_bars_steps.py
- libraries/python/src/internal/money.py
- application/datamanager/src/datamanager/main.py
- application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py
- application/positionmanager/src/positionmanager/main.py
- application/predictionengine/src/predictionengine/dataset.py
- application/predictionengine/src/predictionengine/main.py
🧰 Additional context used
🧬 Code Graph Analysis (20)
libraries/python/src/internal/summaries.py (1)
- application/datamanager/tests/test_datamanager_models.py (9): TestBarsSummary (76-106), TestModelIntegration (109-127), test_summary_date_to_bars_summary (110-117), test_bars_summary_from_dict (101-106), test_bars_summary_initialization (77-81), test_bars_summary_zero_count (83-87), test_bars_summary_json_serialization (95-99), test_bars_summary_negative_count (89-93), test_multiple_model_validation (119-127)

libraries/python/tests/test_variable_selection_network.py (1)
- libraries/python/src/internal/variable_selection_network.py (2): VariableSelectionNetwork (5-21), forward (17-21)

applications/models/src/models/get_alpaca_equity_bars.py (3)
- application/datamanager/src/datamanager/main.py (2): get_equity_bars (149-202), fetch_equity_bars (206-269)
- workflows/train_predictionengine.py (1): fetch_data (22-67)
- application/datamanager/src/datamanager/clients.py (1): get_all_equity_bars (12-18)

libraries/python/src/internal/cloud_event.py (1)
- application/positionmanager/src/positionmanager/main.py (1): create_cloud_event_error (269-279)

libraries/python/tests/test_dataset.py (1)
- libraries/python/src/internal/dataset.py (3): TemporalFusionTransformerDataset (28-326), get_dimensions (240-248), get_batches (250-326)

libraries/python/src/internal/equity_bar.py (3)
- application/datamanager/src/datamanager/models.py (1): BarsSummary (52-54)
- application/datamanager/src/datamanager/main.py (1): get_equity_bars (149-202)
- workflows/train_predictionengine.py (1): fetch_data (22-67)

applications/models/src/models/get_alpaca_calendar.py (1)
- application/positionmanager/src/positionmanager/main.py (1): open_position (117-236)

libraries/python/tests/test_equity_bar.py (1)
- libraries/python/src/internal/equity_bar.py (1): EquityBar (6-49)

libraries/python/src/internal/loss_functions.py (3)
- application/predictionengine/src/predictionengine/loss_function.py (1): quantile_loss (8-29)
- application/predictionengine/tests/test_loss_function.py (7): test_quantile_loss_multiple_samples (24-32), test_quantile_loss_basic (13-21), test_quantile_loss_different_quantiles (45-53), test_quantile_loss_shapes (56-63), test_quantile_loss_perfect_prediction (35-42), test_quantile_loss_shape_mismatch (66-72), test_quantile_loss_invalid_quantiles (75-81)
- application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py (1): validate (148-161)

libraries/python/src/internal/mhsa_network.py (2)
- application/predictionengine/src/predictionengine/multi_head_self_attention.py (3): MultiHeadSelfAttention (8-66), __init__ (9-30), forward (32-66)
- application/predictionengine/tests/test_multi_head_self_attention.py (6): test_multi_head_attention_different_heads (34-45), test_multi_head_attention_forward (20-31), test_multi_head_attention_single_sequence (48-54), test_multi_head_attention_batch_processing (67-75), test_multi_head_attention_initialization (11-17), test_multi_head_attention_longer_sequences (57-64)

libraries/python/tests/test_loss_functions.py (3)
- libraries/python/src/internal/loss_functions.py (1): quantile_loss (4-28)
- application/predictionengine/src/predictionengine/loss_function.py (1): quantile_loss (8-29)
- application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py (1): validate (148-161)

libraries/python/tests/test_lstm_network.py (3)
- libraries/python/src/internal/lstm_network.py (1): LSTM (5-85)
- application/predictionengine/src/predictionengine/long_short_term_memory.py (2): LongShortTermMemory (5-69), forward (22-69)
- application/predictionengine/tests/test_long_short_term_memory.py (7): test_lstm_consistency (81-92), test_lstm_single_timestep (70-78), test_lstm_multiple_layers (58-67), Expected (19-22), test_lstm_initialization (14-28), test_lstm_forward (31-43), test_lstm_different_sequence_lengths (46-55)

libraries/python/tests/test_dates.py (1)
- libraries/python/src/internal/dates.py (3): Date (8-31), DateRange (34-55), to_object (51-55)

libraries/python/tests/test_mhsa_network.py (1)
- libraries/python/src/internal/mhsa_network.py (1): MultiHeadSelfAttentionNetwork (6-64)

applications/models/src/models/train_tft_model.py (2)
- libraries/python/src/internal/dataset.py (3): TemporalFusionTransformerDataset (28-326), get_dimensions (240-248), get_batches (250-326)
- libraries/python/src/internal/tft_model.py (5): Parameters (19-33), TemporalFusionTransformer (37-271), train (192-225), validate (227-246), save (248-253)

infrastructure/environment_variables.py (1)
- infrastructure/services.py (1): create_service_environment_variables (9-12)

libraries/python/src/internal/dates.py (2)
- application/datamanager/tests/test_datamanager_models.py (5): test_summary_date_to_bars_summary (110-117), TestSummaryDate (14-44), TestBarsSummary (76-106), test_bars_summary_json_serialization (95-99), test_bars_summary_zero_count (83-87)
- application/datamanager/tests/test_datamanager_main.py (1): TestDataManagerModels (20-32)

libraries/python/src/internal/tft_model.py (4)
- libraries/python/src/internal/loss_functions.py (1): quantile_loss (4-28)
- libraries/python/src/internal/lstm_network.py (2): LSTM (5-85), forward (32-85)
- libraries/python/src/internal/mhsa_network.py (2): MultiHeadSelfAttentionNetwork (6-64), forward (30-64)
- libraries/python/src/internal/variable_selection_network.py (2): VariableSelectionNetwork (5-21), forward (17-21)

libraries/python/src/internal/lstm_network.py (2)
- application/predictionengine/src/predictionengine/long_short_term_memory.py (3): LongShortTermMemory (5-69), forward (22-69), __init__ (6-20)
- application/predictionengine/tests/test_long_short_term_memory.py (6): test_lstm_multiple_layers (58-67), test_lstm_consistency (81-92), test_lstm_forward (31-43), test_lstm_single_timestep (70-78), test_lstm_initialization (14-28), test_lstm_different_sequence_lengths (46-55)

libraries/python/src/internal/dataset.py (1)
- application/predictionengine/src/predictionengine/dataset.py (3): DataSet (121-318), _scale_data (224-244), _compute_scalers (211-222)
🪛 GitHub Check: Run Python quality checks
applications/models/src/models/train_tft_model.py
[failure] 1-10: Ruff (I001)
applications/models/src/models/train_tft_model.py:1:1: I001 Import block is un-sorted or un-formatted
🪛 GitHub Actions: Code quality check
applications/models/src/models/train_tft_model.py
[error] 1-1: Ruff check failed (command: 'ruff check --output-format=github .'): I001 Import block is un-sorted or un-formatted.
🔇 Additional comments (86)
.gitignore (1)
6-10: Good adds for common Python artifacts

Ignoring ruff/pytest caches, local venvs, and a dedicated coverage_output/ dir is sensible and will reduce noise in diffs.
.github/workflows/launch_infrastructure.yaml (1)
2-8: Rename looks consistent with repo restructuring

Workflow and job rename to "infrastructure" aligns with the new structure.
.github/workflows/teardown_infrastructure.yaml (2)
2-2: LGTM! Consistent naming update

The renaming from "application" to "infrastructure" is consistent and aligns with the PR's restructuring objectives.
Also applies to: 7-8
4-5: Consider the implications of unconditional teardown

The removal of the conditional gate means the workflow will run on every scheduled trigger. Ensure this is intentional and won't cause issues with other workflows or deployments.
Can you confirm that running teardown unconditionally at 6:00 PM EST every weekday is the intended behavior? Previously, it seems there was a conditional check that has been removed.
infrastructure/environment_variables.py (1)
17-32: No None-risk: Config.require_secret returns a pulumi.Output and raises if missing

Pulumi's Config.require_secret(...) returns a pulumi.Output[T] (e.g., Output[str]) and raises ConfigMissingError when the key is absent — it does not return None. The current code's assumption that those secret values are non-None is therefore valid. If any of these should be optional, use configuration.get_secret(...) or configuration.get(...) and provide an explicit fallback.
Relevant locations:
- infrastructure/environment_variables.py — ALPACA_, DATA_BUCKET_NAME, POLYGON_API_KEY, DUCKDB_ use configuration.require_secret and are safe.
- infrastructure/main.py — other require_secret usages (DockerHub creds, IAM ARNs, AWS_S3_DATA_BUCKET_NAME) — consistent with required configuration.
libraries/python/src/internal/lstm_network.py (2)
39-49: LGTM! Improved state management

The change to per-layer state lists is a good improvement that provides better flexibility for stateful LSTM operations.
25-30: Verified — tinygrad LSTMCell constructor is compatible

tinygrad.nn.LSTMCell has the signature `LSTMCell(input_size: int, hidden_size: int, bias: bool = True)`, so calling it with named arguments `input_size=...` and `hidden_size=...` is valid.

- Location: libraries/python/src/internal/lstm_network.py (LSTMCell construction at ~lines 25-29)
- Relevant snippet:

```python
self.layers.append(
    LSTMCell(
        input_size=input_size,
        hidden_size=self.hidden_size,
    )
)
```

libraries/python/tests/test_lstm_network.py (1)

4-4: Import path update aligns with the new internal package layout

Switching to `from internal.lstm_network import LSTM` matches the repo restructuring and the shared internal library. Good change.

libraries/python/src/internal/summaries.py (1)
1-6: BarsSummary relocation verified — no remaining definitions or imports
- libraries/python/src/internal/summaries.py — class BarsSummary defined (lines 4–6)
- No other occurrences of "BarsSummary" found elsewhere in the repo
applications/datamanager/pyproject.toml (1)
1-9: Workspace wiring to internal looks correct

Declaring `dependencies = ["internal"]` with `[tool.uv.sources] internal = { workspace = true }` is aligned with the new shared library approach.

libraries/python/pyproject.toml (2)
14-16: Packaging configuration looks sane for a src-layout library

`[tool.uv] package = true` with `src = ["src"]` should expose `internal` cleanly.
6-12: Minimum versions exist on PyPI — compatibility still needs verification

I checked PyPI for each exact version; all were found:
- pydantic 2.8.2 — FOUND
- cloudevents 1.12.0 — FOUND
- tinygrad 0.10.3 — FOUND
- numpy 2.2.6 — FOUND
- polars 1.29.0 — FOUND
Location to review:
- libraries/python/pyproject.toml (lines 6–12)
Recommendation (short): these versions exist, but please run your dependency resolver in a clean environment or CI (pip install / poetry/pip-compile / lockfile update) to catch any transitive incompatibilities before merging.
applications/models/pyproject.toml (1)
6-19: Verify required versions exist for all external dependencies

Before merging, confirm that the specified minimum versions are available and compatible (especially `pyarrow>=20.0.0`, `polygon-api-client>=1.14.6`, and `flytekit>=1.16.1`).

libraries/python/tests/test_variable_selection_network.py (11)
1-6: LGTM! Well-structured test setup.

The imports, RNG setup, and overall test module structure are appropriate for testing the VariableSelectionNetwork implementation. The use of PCG64 for deterministic testing is a good practice.

9-24: LGTM! Comprehensive initialization test.

The test properly validates the network's initialization by checking both the existence of required attributes and the correct weight matrix shapes for both layers.

26-42: LGTM! Good forward pass validation.

The test correctly verifies the basic forward pass functionality and output shape preservation. The use of random inputs with deterministic RNG ensures reproducible test results.

44-59: LGTM! Important output range validation.

This test correctly validates that the sigmoid activation ensures outputs are within the expected [0, 1] range, which is crucial for variable selection weights.

61-77: LGTM! Good batch size flexibility test.

Testing different batch sizes ensures the network can handle varying input dimensions correctly, which is important for practical usage scenarios.

79-100: LGTM! Comprehensive dimension testing.

The test covers various input/hidden dimension combinations, including edge cases like hidden_size < input_dimension, ensuring robustness across different network configurations.

102-118: LGTM! Good edge case testing.

Testing with zero inputs is important for validating network behavior with boundary conditions. The output range validation ensures the sigmoid activation works correctly even with zero inputs.

120-138: LGTM! Positive input validation.

Testing with positive-only inputs helps validate network behavior across different input distributions while maintaining output range constraints.

140-158: LGTM! Negative input validation.

Testing with negative-only inputs complements the positive input test and ensures the network handles the full range of possible input values correctly.

160-176: LGTM! Essential determinism test.

This test validates that the network produces consistent outputs for identical inputs, which is crucial for reproducible results in production environments.

178-194: LGTM! Good minimal dimension test.

Testing with single input/output dimension ensures the network works correctly even at the smallest meaningful scale, validating the implementation's robustness at edge cases.
Dockerfile.tests (2)
20-28: LGTM! Correct path updates for repository restructuring.

The COPY commands have been properly updated to reflect the new repository structure:

- applications/datamanager/ (was application/datamanager/)
- applications/portfoliomanager/ (was application/positionmanager/)
- applications/models/ (was application/predictionengine/)
- libraries/python/ (was workflows/)

These changes align with the PR's objective of restructuring the codebase.
32-34: LGTM! Bulk directory copies for new structure.

Adding bulk copies for the `applications/` and `libraries/` directories ensures all necessary code is available in the container for the new repository structure.

libraries/python/src/internal/dates.py (2)
51-55: LGTM! Well-implemented utility method.

The `to_object` method provides a clean way to convert the DateRange to a dictionary with ISO-formatted date strings. This appears to replace functionality that was previously in the removed `BarsSummary` class.

8-8: Rename verified — no remaining `SummaryDate` references

Search of the repository found no occurrences of `SummaryDate`. The new class is defined at:

- libraries/python/src/internal/dates.py: `class Date(BaseModel)`

And usages in the repo are updated to `Date` (e.g. libraries/python/tests/test_dates.py).

libraries/python/tests/test_mhsa_network.py (2)
1-1: LGTM! Correct import path update.

The import has been properly updated to use the new internal module structure, changing from the old `application.predictionengine` path to `internal.mhsa_network`.

11-11: LGTM! Consistent class name updates.

All test instantiations have been correctly updated from `MultiHeadSelfAttention` to `MultiHeadSelfAttentionNetwork`, maintaining consistency with the new class name and using the proper constructor parameters.

Also applies to: 18-18, 34-36, 46-46, 55-55, 65-65
libraries/python/src/internal/variable_selection_network.py (3)
1-3: LGTM! Clean imports.

The imports are minimal and appropriate for the functionality being implemented.

5-16: LGTM! Well-structured network initialization.

The constructor properly initializes two linear layers with the correct dimensions for a variable selection network. The architecture follows the typical pattern of mapping from input dimension to hidden size and back to input dimension for feature weighting.

17-21: LGTM! Correct forward pass implementation.

The forward method implements the expected variable selection network behavior:
- Linear transformation to hidden space
- ReLU activation for non-linearity
- Linear transformation back to input dimension
- Sigmoid activation to produce weights in [0, 1] range
This is appropriate for variable selection where the output represents feature importance weights.
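The four steps above can be sketched in NumPy; the real network uses tinygrad Linear layers, so the explicit weight arrays below are an illustrative stand-in, not the project's code:

```python
import numpy as np


def vsn_forward(x, w1, b1, w2, b2):
    """Variable selection forward pass: linear -> ReLU -> linear -> sigmoid."""
    hidden = np.maximum(x @ w1 + b1, 0.0)                 # linear + ReLU
    weights = 1.0 / (1.0 + np.exp(-(hidden @ w2 + b2)))   # linear + sigmoid
    return weights  # shape (batch, input_dim), values in (0, 1)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))                    # batch of 4, input_dim of 3
w1, b1 = rng.normal(size=(3, 8)), np.zeros(8)  # input_dim -> hidden_size
w2, b2 = rng.normal(size=(8, 3)), np.zeros(3)  # hidden_size -> input_dim
out = vsn_forward(x, w1, b1, w2, b2)
print(out.shape, bool((out >= 0).all() and (out <= 1).all()))  # → (4, 3) True
```

The final sigmoid is what keeps every selection weight in [0, 1], which is exactly what the range tests earlier in this review assert.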
libraries/python/tests/test_loss_functions.py (7)
3-3: LGTM! Import path correctly updated.

The import path has been successfully updated from the old prediction engine module to the new internal loss functions module.

11-13: LGTM! Predictions tensor shape updated correctly.

The predictions tensor now has the expected 3D shape `[batch_size, output_size, len(quantiles)]` with the last dimension matching the number of quantiles, which aligns with the new quantile loss function signature.

22-24: LGTM! Shape changes consistent across all test cases.

All test cases have been consistently updated to use 3D predictions tensors and quantiles as lists rather than tuples.

32-35: LGTM! Perfect prediction test updated correctly.

The test correctly uses the new 3D tensor format where all quantiles predict the same value (2.0), which should result in zero loss for a perfect prediction.

42-45: LGTM! Different quantiles test properly updated.

The test correctly creates a 5-quantile prediction tensor with shape `[1, 1, 5]` matching the 5-element quantiles list.

54-57: LGTM! Shape test updated for new tensor dimensions.

The test now correctly generates 3D predictions with shape `(batch_size, 1, 3)` matching the 3-quantile format expected by the new loss function.

64-66: LGTM! Invalid quantiles test updated correctly.

The test properly uses the new 3D tensor format and list-based quantiles while maintaining the validation logic for invalid quantile values.
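The pinball loss these tests exercise can be sketched in NumPy; the production code operates on tinygrad tensors, so the function below is an illustrative stand-in with the same shapes:

```python
import numpy as np


def quantile_loss(predictions: np.ndarray, targets: np.ndarray, quantiles: list[float]) -> float:
    """Pinball loss averaged over the batch and normalized by the number of quantiles.

    predictions: (batch, output_size, len(quantiles)); targets: (batch, output_size).
    """
    total = 0.0
    for index, quantile in enumerate(quantiles):
        error = targets - predictions[:, :, index]
        # Asymmetric penalty: q * error when under-predicting, (q - 1) * error otherwise.
        total += np.mean(np.maximum(quantile * error, (quantile - 1) * error))
    return float(total / len(quantiles))


# A perfect prediction yields zero loss across all quantiles.
preds = np.full((1, 1, 3), 2.0)
targs = np.full((1, 1), 2.0)
print(quantile_loss(preds, targs, [0.1, 0.5, 0.9]))  # → 0.0
```

This mirrors the "perfect prediction gives zero loss" case checked in the tests above.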
libraries/python/src/internal/loss_functions.py (4)
1-1: LGTM! Appropriate import for the tensor library.

The import is correct for the TinyGrad tensor operations used in the function.

4-8: LGTM! Function signature well-documented.

The function signature is clear with proper type hints and inline comments describing the expected tensor shapes.

12-14: LGTM! Input validation preserved.

The quantile range validation logic is correctly preserved from the original implementation.

16-28: LGTM! Quantile loss computation is mathematically correct.

The implementation correctly computes the quantile loss using the standard formula:

- For each quantile, it computes `error = targets - predictions[:, :, index]`
- Applies the asymmetric loss: `max(quantile * error, (quantile - 1) * error)`
- Uses `Tensor.where` for the conditional logic and averages across the batch

The final division by `len(quantiles)` ensures the loss is normalized across quantiles.

libraries/python/tests/test_equity_bar.py (12)
1-5: LGTM! Imports are appropriate.

All necessary imports are present for testing the EquityBar model with proper Pydantic validation.

8-35: LGTM! Comprehensive valid creation test.

The test thoroughly validates all fields of a successfully created EquityBar instance with realistic financial data.

37-50: LGTM! Ticker normalization test is effective.

The test correctly verifies that lowercase tickers are automatically converted to uppercase by the validator.

52-65: LGTM! Whitespace handling test is thorough.

The test confirms that leading and trailing whitespace is properly stripped from ticker symbols.

67-81: LGTM! Empty ticker validation test is correct.

The test properly verifies that empty ticker strings raise a ValidationError with the expected message.

83-97: LGTM! Whitespace-only ticker validation test is comprehensive.

The test ensures that ticker strings containing only whitespace are treated as empty and properly rejected.

99-113: LGTM! Negative price validation test is effective.

The test correctly verifies that negative prices trigger validation errors with the expected message.

115-131: LGTM! Zero price acceptance test is appropriate.

The test confirms that zero prices are allowed, which is reasonable for financial data (e.g., delisted stocks or special situations).

133-146: LGTM! ISO format timestamp test is correct.

The test validates that timestamp parsing from ISO format strings works as expected.

148-168: LGTM! Comprehensive price field validation test.

The test systematically validates that all price fields (open, high, low, close) properly reject negative values.

170-185: LGTM! Large volume test demonstrates scalability.

The test with volume `10**12` ensures the model can handle large trading volumes that occur in real markets.

187-203: LGTM! Special ticker symbols test covers real-world cases.

The test includes realistic special ticker symbols like "BRK.B" and "BF-B" that exist in actual markets, ensuring the model handles these edge cases properly.
libraries/python/tests/test_dataset.py (4)
1-2: LGTM! Imports are correct.

Proper imports for Polars DataFrame handling and the internal dataset module.

5-47: LGTM! Dataset loading test is comprehensive.

The test creates a realistic DataFrame with all required columns (timestamp, OHLCV data, ticker, sector, industry, is_holiday) and validates that the dataset initialization exposes the expected public attributes.

49-79: LGTM! Dimensions test validates feature configuration.

The test correctly verifies that the dataset exposes all expected dimension keys for encoder/decoder categorical/continuous features and static features, which is essential for TFT model configuration.

81-135: LGTM! Batch generation test validates tensor shapes.

The test thoroughly validates:
- Batch generation returns a list structure
- Expected number of batches (1 for the single ticker)
- Presence of all required batch components
- Correct tensor shapes matching the specified input/output lengths
This ensures the dataset properly formats data for TFT model consumption.
libraries/python/src/internal/cloud_event.py (4)
1-4: LGTM! Imports are appropriate.

Proper imports for datetime handling, timezone support, and CloudEvent creation.

7-21: LGTM! Success event factory function is well-designed.

The function provides a clean interface for creating success CloudEvents with:

- Parameterized application name and event metadata
- Automatic timestamp generation in the correct timezone
- Flexible data payload merging

24-38: LGTM! Error event factory function mirrors success pattern.

The error function follows the same pattern as the success function but with a specific error payload structure, providing consistency across the API.

18-18: LGTM! Consistent timezone usage.

Both functions consistently use "America/New_York" timezone and ISO format for timestamps, ensuring standardized event timestamps across the system.
Also applies to: 35-35
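The factory pattern described in these comments can be sketched with a plain dict standing in for the CloudEvent type; the helper name and field layout below are illustrative, not the library's API:

```python
from datetime import datetime
from zoneinfo import ZoneInfo


def create_success_event(application_name: str, event_type: str, data: dict) -> dict:
    """Build a success-event envelope with a New York ISO timestamp."""
    return {
        "source": application_name,
        "type": event_type,
        "data": {
            "status": "success",
            "timestamp": datetime.now(tz=ZoneInfo("America/New_York")).isoformat(),
            **data,  # merge the caller's payload into the standard envelope
        },
    }


event = create_success_event("datamanager", "bars.fetched", {"count": 3})
print(event["data"]["status"], event["data"]["count"])  # → success 3
```

An error-event factory would mirror this shape with an error payload in place of the success status, which is the symmetry the review notes above.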
libraries/python/tests/test_dates.py (10)
1-8: LGTM!

The imports are appropriately organized and all necessary modules are included for the test coverage.

10-19: LGTM!

The test correctly validates the default date behavior using the mocked datetime to ensure deterministic testing.

22-26: LGTM!

The test appropriately validates explicit date assignment functionality.

35-40: LGTM!

The test correctly validates the JSON encoding format, confirming that dates are serialized in the expected "YYYY/MM/DD" format.

43-50: LGTM!

The test appropriately validates that DateRange accepts valid date ranges.

53-60: LGTM!

The test correctly validates that DateRange rejects invalid date ranges where the end date is before the start date.

63-69: LGTM!

The test appropriately validates that DateRange rejects ranges where start and end dates are identical.

72-79: LGTM!

The test correctly validates that DateRange accepts minimal valid ranges where dates are one day apart.

82-94: LGTM!

The test correctly validates the `to_object()` method's output format with ISO date strings.

97-109: LGTM!

The test appropriately validates the `to_object()` method with leap year dates, ensuring correct handling of edge cases.

libraries/python/src/internal/equity_bar.py (5)
1-4: LGTM!

The imports are appropriate for the Pydantic model implementation.

6-16: LGTM!

The model fields are well-defined with clear descriptions. Using `float` for volume is appropriate as it can accommodate fractional shares in modern trading systems.

18-24: LGTM!

The ticker validation correctly handles empty strings, trims whitespace, and normalizes to uppercase for consistency.

26-38: Consider allowing zero prices for specific scenarios

The current validation allows zero prices, which may be appropriate for certain edge cases (e.g., stock halts, data gaps). However, you should verify this aligns with your business requirements.

Based on the test coverage and workflows that use this model, zero prices are intentionally allowed. This is appropriate for handling edge cases in equity data.

40-49: LGTM!

The timestamp validation correctly handles both date objects and ISO-formatted strings with appropriate error messages.
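The ticker normalization rules discussed above (strip whitespace, reject empty, uppercase) can be sketched as a standalone function; the actual model wires this through a Pydantic field validator, and the helper name here is illustrative:

```python
def validate_ticker(value: str) -> str:
    """Normalize a ticker: strip whitespace, reject empty, uppercase."""
    cleaned = value.strip()
    if not cleaned:
        raise ValueError("Ticker must not be empty")
    return cleaned.upper()


print(validate_ticker("  brk.b  "))  # → BRK.B
```

Note that whitespace-only input becomes empty after stripping and is rejected, which matches the whitespace-only test case covered earlier in this review.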
pyproject.toml (5)
1-9: LGTM!

The workspace configuration successfully transitions to a scalable glob-based pattern that simplifies dependency management.

11-12: LGTM!

The simplified workspace member pattern using globs is more maintainable than explicit listings.

34-43: LGTM!

The coverage configuration is well-structured with appropriate exclusions and XML output for CI integration.

45-98: LGTM!

The comprehensive linting rules provide thorough code quality checks while appropriately ignoring certain rules in test files.

104-111: LGTM!

The type checking configuration appropriately balances strictness with practicality.
libraries/python/src/internal/mhsa_network.py (2)
1-4: LGTM!

The imports are appropriate for the multi-head self-attention implementation using tinygrad.

6-28: LGTM!

The initialization correctly validates the embedding dimension divisibility and sets up all necessary projection layers with appropriate scaling.
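The divisibility check mentioned here is the standard multi-head attention constraint: each head attends over an equal slice of the embedding. A minimal sketch (the helper name is illustrative):

```python
def head_dim(embedding_dim: int, head_count: int) -> int:
    """Each attention head gets embedding_dim // head_count features."""
    if embedding_dim % head_count != 0:
        raise ValueError("embedding_dim must be divisible by head_count")
    return embedding_dim // head_count


print(head_dim(64, 8))  # → 8
```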
Actionable comments posted: 2
♻️ Duplicate comments (5)
applications/models/src/models/get_alpaca_calendar.py (1)
12-20: Good: explicit env-var validation with clear error message

This addresses the previous review's concern and prevents cryptic failures when credentials are missing. Logging before raising is also helpful.
applications/models/src/models/train_tft_model.py (4)
1-11: Fix import ordering according to project Ruff configuration.

The imports are not properly ordered according to the project's configuration. The static analysis tool is flagging this as a failure.

Apply this fix to sort the imports correctly:

```diff
 from datetime import datetime
 from zoneinfo import ZoneInfo

 import polars as pl
+import wandb
 from flytekit import task, workflow
 from internal.dataset import TemporalFusionTransformerDataset
 from internal.tft_model import Parameters, TemporalFusionTransformer
 from loguru import logger
+from wandb import Run
-
-import wandb
-from wandb import Run
```
135-142: Add error handling for WandB initialization.

The WandB initialization could fail due to network issues, missing API keys, or other configuration problems, which would crash the entire training pipeline. This is especially important since this is a Flyte workflow that should be robust.

Add error handling for WandB initialization:

```diff
+    try:
         if wandb.run is not None:
             wandb.finish()  # close active run if it exists

         wandb_run = wandb.init(
             project="Pocket Size Fund",
             config=configuration,
             name=f"tft-model-run-{datetime.now(tz=ZoneInfo('America/New_York')).strftime('%Y-%m-%d_%H-%M-%S')}",
         )
+    except Exception as e:
+        logger.warning(f"Failed to initialize WandB: {e}. Continuing without logging.")
+        wandb_run = None
```

Then update the train_model task call to handle the case where wandb_run might be None:

```diff
     model = train_model(
         dataset=dataset,  # type: ignore[assignment]
         validation_split=configuration["validation_split"],
         epoch_count=configuration["epoch_count"],
         learning_rate=configuration["learning_rate"],
-        wandb_run=wandb_run,
+        wandb_run=wandb_run,  # May be None if WandB init failed
     )
```
82-83: Fix incorrect handling of losses data structure.

The code assumes `losses` is iterable like a list, but based on the TFT model implementation in libraries/python/src/internal/tft_model.py, the `train` method returns `{"losses": list_of_losses}`. Iterating over that dictionary yields its keys, so the loop would log the string "losses" instead of the actual loss values.

Apply this fix to properly handle the losses dictionary:

```diff
-    for loss in losses:
-        wandb_run.log({"loss": loss})
+    for loss in losses["losses"]:
+        wandb_run.log({"loss": loss})
```
82-85: Guard WandB logging against initialization failures.

The code assumes `wandb_run` is always valid, but if WandB initialization fails (as addressed in the previous comment), this will cause a runtime error.

Add a guard to check if WandB is available before logging:

```diff
-    for loss in losses["losses"]:
-        wandb_run.log({"loss": loss})
-
-    wandb_run.finish()
+    if wandb_run is not None:
+        for loss in losses["losses"]:
+            wandb_run.log({"loss": loss})
+        wandb_run.finish()
```
🧹 Nitpick comments (6)
applications/models/src/models/get_alpaca_calendar.py (5)
21-25: Parse ALPACA_PAPER more robustly (accept 1/0, yes/no, on/off, etc.)

The current `lower() == "true"` check works but is brittle. Consider a small helper to handle common truthy/falsey values.

Apply this change where the client is created:

```diff
-    paper=os.getenv("ALPACA_PAPER", "true").lower() == "true",
+    paper=env_bool("ALPACA_PAPER", default=True),
```

Add this helper near the imports:

```python
def env_bool(name: str, default: bool = True) -> bool:
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "t", "yes", "y", "on"}
```
41-43: Consider catching SDK-specific exceptions and adding minimal retry

Catching Exception and re-raising is acceptable, but you'll get better diagnostics and resiliency by handling known Alpaca/HTTP exceptions explicitly and retrying transient failures (e.g., network hiccups, 429 rate limits).
- Catch SDK exceptions (e.g., API/HTTP errors) explicitly.
- Add a small backoff retry (e.g., 2–3 attempts with exponential backoff).
- Keep the broad except as a final fallback if you prefer.
I can provide a concrete snippet tailored to the Alpaca SDK you’re using if you confirm the exception class names exposed by your version.
45-52: Preserve native date/time types into Polars for richer typing

Polars can handle Python date/time types directly. Avoiding str() here preserves richer types and can help downstream operations.

```diff
-    {
-        "date": str(calendar.date),
-        "open": str(calendar.open),
-        "close": str(calendar.close),
-    }
+    {
+        "date": calendar.date,
+        "open": calendar.open,
+        "close": calendar.close,
+    }
```
56-58: Make output path configurable and ensure directory exists

Defaults are fine, but allowing an env var and ensuring parent directories exist improves UX in CI and scripts.

```diff
-    calendar_content.write_csv("calendar.csv")
-
-    logger.info("Calendar data has been written to calendar.csv")
+    output_path = os.getenv("ALPACA_CALENDAR_CSV", "calendar.csv")
+    try:
+        from pathlib import Path
+        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+    except Exception:
+        # Non-fatal: if path has no parent or creation fails, continue and let write_csv raise
+        pass
+    calendar_content.write_csv(output_path)
+    logger.info(f"Calendar data has been written to {output_path}")
```
12-58: Optional: Factor logic into small functions for reuse and testability

Right now the flow is script-only. Extracting functions (e.g., build_client, fetch_calendar, to_dataframe, write_csv) will make it easier to unit test and reuse in other tools.

Example structure (illustrative):

```python
def build_client(api_key: str, secret_key: str, paper: bool) -> TradingClient: ...
def fetch_calendar(client: TradingClient, start: date, end: date) -> list[Calendar]: ...
def to_dataframe(calendars: list[Calendar]) -> pl.DataFrame: ...
def write_csv(df: pl.DataFrame, path: str) -> None: ...
```

If you want, I can send a concrete refactor patch aligned with your repository style.
applications/models/src/models/train_tft_model.py (1)
144-146: Consider making the training data path configurable.

The hardcoded path "applications/models/src/models/training_data.csv" makes the workflow less flexible and harder to test with different datasets.

Consider adding the filepath to the configuration or as a workflow parameter:

```diff
 configuration = {
     "architecture": "TFT",
     "learning_rate": 0.02,
     "epoch_count": 10,
     "validation_split": 0.8,
+    "training_data_path": "applications/models/src/models/training_data.csv",
 }
```

Then use it in the workflow:

```diff
 dataset = read_local_data(
-    filepath="applications/models/src/models/training_data.csv"
+    filepath=configuration["training_data_path"]
 )  # type: ignore[assignment]
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (6)
- applications/models/pyproject.toml (1 hunks)
- applications/models/src/models/get_alpaca_calendar.py (1 hunks)
- applications/models/src/models/get_alpaca_equity_bars.py (1 hunks)
- applications/models/src/models/train_tft_model.py (1 hunks)
- libraries/python/src/internal/dataset.py (1 hunks)
- libraries/python/src/internal/tft_model.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
- applications/models/pyproject.toml
- applications/models/src/models/get_alpaca_equity_bars.py
- libraries/python/src/internal/tft_model.py
- libraries/python/src/internal/dataset.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
applications/models/src/models/train_tft_model.py (2)
libraries/python/src/internal/dataset.py (3)
- `TemporalFusionTransformerDataset` (28-326)
- `get_dimensions` (240-248)
- `get_batches` (250-326)

libraries/python/src/internal/tft_model.py (5)

- `Parameters` (19-33)
- `TemporalFusionTransformer` (37-273)
- `train` (192-227)
- `validate` (229-248)
- `save` (250-255)
🪛 GitHub Check: Run Python quality checks
applications/models/src/models/train_tft_model.py
[failure] 1-11: Ruff (I001)
applications/models/src/models/train_tft_model.py:1:1: I001 Import block is un-sorted or un-formatted
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Run Python tests
🔇 Additional comments (3)
applications/models/src/models/train_tft_model.py (3)
23-33: LGTM on timing implementation.

The timing calculation correctly uses `total_seconds()`, which captures the full elapsed time including days and fractional seconds.

67-72: LGTM on batch generation.

The `data_type="train"` parameter is correct and matches the expected API from the dataset implementation. The parameter passing is well-structured.

103-108: LGTM on validation batch generation.

The `data_type="validate"` parameter is correct and matches the expected dataset API. Good use of model parameters for consistent windowing.
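A quick illustration of why total_seconds() is the right call when timing training runs: timedelta.seconds only covers the sub-day remainder, while total_seconds() includes the days component:

```python
from datetime import timedelta

# A duration spanning more than a day: .seconds only counts the sub-day part,
# while total_seconds() captures the full elapsed time.
elapsed = timedelta(days=1, seconds=30)
print(elapsed.seconds, elapsed.total_seconds())  # → 30 86430.0
```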
@chrisaddy feel free to take a look at this whenever.
Actionable comments posted: 1
♻️ Duplicate comments (1)
applications/models/src/models/get_alpaca_equity_bars.py (1)
22-29: Good addition: explicit env var validation (resolves earlier feedback).

Validating API credentials up front is the right call and addresses the earlier review request.
🧹 Nitpick comments (6)
applications/models/src/models/get_alpaca_equity_bars.py (6)
30-40: Deduplicate ALPACA_PAPER parsing and reuse for both clients.

Parse the flag once to avoid drift and repeated env lookups.

```diff
 alpaca_trading_client = TradingClient(
     api_key=api_key,
     secret_key=secret_key,
-    paper=os.getenv("ALPACA_PAPER", "true").lower() == "true",
+    paper=is_paper,
 )

 alpaca_data_client = StockHistoricalDataClient(
     api_key=api_key,
     secret_key=secret_key,
-    sandbox=os.getenv("ALPACA_PAPER", "true").lower() == "true",
+    sandbox=is_paper,
 )
```

Add just above the TradingClient construction:

```diff
+is_paper = os.getenv("ALPACA_PAPER", "true").lower() == "true"
```
46-51: Prefer enum-based filters over magic strings for asset attributes.

Using a string for attributes is brittle; favor the enum to catch typos at type-check time and avoid case sensitivity surprises.

```diff
-    GetAssetsRequest(
-        status=AssetStatus.ACTIVE,
-        asset_class=AssetClass.US_EQUITY,
-        attributes="has_options",
-    )
+    GetAssetsRequest(
+        status=AssetStatus.ACTIVE,
+        asset_class=AssetClass.US_EQUITY,
+        attributes=[AssetAttributes.HAS_OPTIONS],
+    )
```

You'll need to import the enum:

```diff
-from alpaca.trading.enums import AssetClass, AssetStatus
+from alpaca.trading.enums import AssetClass, AssetStatus, AssetAttributes
```

Note: Please verify the exact enum name/value in the installed alpaca-py version (some versions expose AssetAttribute vs AssetAttributes, and member names can vary).
78-84: Use enum members instead of constructing enums from strings.

This avoids case-sensitivity bugs and gives IDE/type-checker help.

```diff
-    timeframe=TimeFrame(
-        amount=1,
-        unit=TimeFrameUnit("Day"),
-    ),
-    adjustment=Adjustment("all"),
-    feed=DataFeed("iex"),
+    timeframe=TimeFrame(amount=1, unit=TimeFrameUnit.Day),
+    adjustment=Adjustment.ALL,
+    feed=DataFeed.IEX,
```
114-118: Vectorize timestamp conversion with Polars dt API for speed and clarity.

Avoid Python lambdas over rows; Polars' native ops are faster and cleaner.

```diff
-    equity_bars_data = equity_bars_data.with_columns(
-        (
-            pl.col("timestamp").map_elements(lambda x: int(x.timestamp() * 1000))
-        ).alias("timestamp")
-    )
+    equity_bars_data = equity_bars_data.with_columns(
+        pl.col("timestamp").dt.epoch(time_unit="ms").alias("timestamp")
+    )
```
130-133: Optional: stream CSV concatenation to reduce memory spikes.

For many tickers/files, concat via scan_csv and collect with streaming to limit memory.

```diff
-    all_bars = pl.concat([pl.read_csv(fp) for fp in saved_files])
-    all_bars.write_csv("equity_bars_combined.csv")
+    all_bars = pl.concat([pl.scan_csv(fp) for fp in saved_files]).collect(streaming=True)
+    all_bars.write_csv("equity_bars_combined.csv")
```
60-61: Sanity-check end timestamp for daily bars.

Daily bars typically include the last fully settled day. Using "now" in America/New_York is fine, but if the script runs during market hours you may get a partial current-day bar depending on feed behavior. Consider truncating end to the previous market close or to the start of today to avoid partials.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
applications/models/src/models/get_alpaca_equity_bars.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
applications/models/src/models/get_alpaca_equity_bars.py (3)
- application/datamanager/src/datamanager/main.py (2)
  - get_equity_bars (149-202)
  - fetch_equity_bars (206-269)
- application/datamanager/tests/test_datamanager_main.py (2)
  - test_fetch_equity_bars_success (199-226)
  - TestFetchEquityBarsEndpoint (196-261)
- workflows/train_predictionengine.py (1)
  - fetch_data (22-67)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Run Python tests
```python
if len(equity_bars.dict()) == 0:
    logger.info(f"No equity bars found for {ticker}.")

    time.sleep(rate_limit_sleep)

    continue

equity_bars_data = pl.DataFrame(equity_bars[ticker])
```
Bug: empty result detection is incorrect; fix empty-bar handling and DataFrame conversion.
len(equity_bars.dict()) counts keys on the model, not the number of bars for the ticker, so it can falsely pass when there are zero bars. Also, constructing a Polars DataFrame directly from Pydantic model instances often doesn’t yield intended columns; convert to plain dict rows first.
Apply this diff to safely extract bars and handle the empty case, then build a DataFrame from dicts:
```diff
-if len(equity_bars.dict()) == 0:
-    logger.info(f"No equity bars found for {ticker}.")
-
-    time.sleep(rate_limit_sleep)
-
-    continue
-
-equity_bars_data = pl.DataFrame(equity_bars[ticker])
+# Safely get bars for this symbol
+try:
+    bars_list = equity_bars[ticker]
+except Exception:
+    bars_list = []
+
+if not bars_list:
+    logger.info(f"No equity bars found for {ticker}.")
+    time.sleep(rate_limit_sleep)
+    continue
+
+# Convert model instances to plain dict rows for Polars
+rows = [
+    b.model_dump() if hasattr(b, "model_dump") else b.dict()
+    for b in bars_list
+]
+equity_bars_data = pl.DataFrame(rows)
```
🤖 Prompt for AI Agents
In applications/models/src/models/get_alpaca_equity_bars.py around lines 95–103,
the code checks len(equity_bars.dict()) which counts model keys not the number
of bars for the ticker and then constructs a Polars DataFrame directly from
Pydantic model instances; fix by first extracting the list of bars for the
specific ticker (e.g., bars = equity_bars.dict().get(ticker, []) or
equity_bars[ticker] if that yields a list), check emptiness with if not bars:
logger.info(...); time.sleep(...); continue, and then build the DataFrame from
plain dict rows like pl.DataFrame([bar.dict() if hasattr(bar, "dict") else
dict(bar) for bar in bars]) so columns are created correctly.
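The b.model_dump() / b.dict() branch in the suggested fix is a Pydantic v2/v1 compatibility shim; a standalone sketch of the pattern with a stand-in class (no Pydantic or Polars required):

```python
class Bar:
    # Stand-in for a Pydantic v2 bar model; only the dump method matters here.
    def __init__(self, open_, close):
        self.open_ = open_
        self.close = close

    def model_dump(self):
        return {"open": self.open_, "close": self.close}

bars = [Bar(1.0, 2.0), Bar(2.0, 3.0)]

# Pydantic v2 models expose model_dump(); v1 models expose dict() instead,
# so the hasattr check works against either version.
rows = [b.model_dump() if hasattr(b, "model_dump") else b.dict() for b in bars]
print(rows)  # -> [{'open': 1.0, 'close': 2.0}, {'open': 2.0, 'close': 3.0}]
```

Plain dict rows like these give Polars unambiguous column names, which constructing a DataFrame directly from model instances does not guarantee.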
Overview
Changes
- applications/ service packages
- applications/models package
- internal package
Big structural rebuild. This is a fresh pull request because the other one got cluttered with reviews.
Summary by CodeRabbit
New Features
Refactor
Chores
Tests
Documentation