Integrate predictionengine into the workflows #588
Conversation
Warning: This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite. This stack of pull requests is managed by Graphite.

Graphite automation "Assign author to pull request" added 1 assignee to this PR (06/05/25).
Pull Request Overview
Integrate the prediction engine into the broader workflows by updating type hints, standardizing randomness in tests, and cleaning up import and exception patterns.
- Migrated all `typing.*` imports to built-in generic types (e.g. `dict[str, ...]`, `list[...]`).
- Replaced `np.random.randn` with `np.random.default_rng().standard_normal` in tests for consistency.
- Refactored exception `raise` statements to assign `message` variables (enabling lint suppression) and reorganized imports.
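The randomness migration in the second bullet swaps NumPy's legacy global-state API for the newer `Generator` API. A minimal sketch of the before/after (shapes chosen arbitrarily for illustration):

```python
import numpy as np

# Legacy global-state API (what the tests used before):
legacy = np.random.randn(3, 4)

# Generator API (what this PR migrates to); passing a seed to
# default_rng makes test data reproducible across runs.
rng = np.random.default_rng(seed=42)
modern = rng.standard_normal((3, 4))

print(legacy.shape, modern.shape)
```

Note the signature difference: `randn` takes dimensions as separate positional arguments, while `standard_normal` takes a single shape tuple.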
Reviewed Changes
Copilot reviewed 24 out of 24 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| application/predictionengine/tests/test_multi_head_self_attention.py | Switched to `default_rng` and added `# noqa` comments on assertions |
| application/predictionengine/src/predictionengine/multi_head_self_attention.py | Updated type hints to built-in generics; refactored `cast` usage |
| application/predictionengine/src/predictionengine/post_processor.py | Migrated `typing` generics to built-ins; standardized exception messages |
| application/predictionengine/src/predictionengine/miniature_temporal_fusion_transformer.py | Renamed parameters and updated method signatures; reorganized imports |
| application/predictionengine/src/predictionengine/main.py | Introduced `SEQUENCE_LENGTH`; parameterized hard-coded values; updated route decorator |
Comments suppressed due to low confidence (1)
application/predictionengine/src/predictionengine/main.py:94
- Removing `response_model=PredictionResponse` from this route decorator may lead to missing automatic response validation and documentation. Re-add `response_model=PredictionResponse` if you rely on FastAPI's schema enforcement.

```python
@application.post("/create-predictions")
```
```diff
-attention_weights: Tensor = cast(Tensor, attention_scores).softmax(axis=-1)
+attention_weights: Tensor = cast("Tensor", attention_scores).softmax(axis=-1)
```
`typing.cast` expects a type object rather than a string literal. Update to `cast(Tensor, attention_scores)` to ensure the intended type cast for static analysis.

Suggested change:

```diff
-attention_weights: Tensor = cast("Tensor", attention_scores).softmax(axis=-1)
+attention_weights: Tensor = cast(Tensor, attention_scores).softmax(axis=-1)
```
```diff
 error = cast("Tensor", y_true - y_pred)
 for quantile in quantiles:
-    quantile_error = cast(Tensor, quantile * error)
-    quantile_minus_one_error = cast(Tensor, (quantile - 1) * error)
+    quantile_error = cast("Tensor", quantile * error)
+    quantile_minus_one_error = cast("Tensor", (quantile - 1) * error)
```
The first argument to `typing.cast` must be a type, not a string. Change to `cast(Tensor, y_true - y_pred)` for correct type annotation.
```diff
 residual = features

-gated_output = cast(Tensor, gate_state * output_state + residual)
+gated_output = cast("Tensor", gate_state * output_state + residual)
```
`typing.cast` should be passed the `Tensor` class, not a string. Use `cast(Tensor, gate_state * output_state + residual)` so type checkers recognize the cast.

Suggested change:

```diff
-gated_output = cast("Tensor", gate_state * output_state + residual)
+gated_output = cast(Tensor, gate_state * output_state + residual)
```
```diff
-except Exception as e:
+except Exception as e:  # noqa: BLE001
     logger.warning(f"Failed to load model weights: {e}")
     logger.warning(f"Failed to load model weights: {e}")
```
[nitpick] This `logger.warning` is duplicated immediately after the newly added one. Consider removing one to avoid redundant log entries.

Suggested change:

```diff
-logger.warning(f"Failed to load model weights: {e}")
```
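To illustrate the de-duplicated handler, here is a minimal self-contained sketch (`load_model_weights` is a hypothetical stand-in for the real loader) in which a load failure is logged exactly once and does not propagate:

```python
import logging

logger = logging.getLogger(__name__)


def load_model_weights(path: str) -> None:
    # Hypothetical loader; always fails here to exercise the handler.
    raise FileNotFoundError(path)


def initialize(path: str) -> None:
    try:
        load_model_weights(path)
    except Exception as e:  # noqa: BLE001 - startup should survive any load failure
        # Emit a single warning; a second identical call here would
        # produce redundant log entries for every failure.
        logger.warning(f"Failed to load model weights: {e}")


initialize("weights.pt")
```

The broad `except Exception` (suppressed via `# noqa: BLE001`) is deliberate: any failure mode during weight loading should degrade to a warning rather than crash startup.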
Merged commit 49b392f into 06-02-fix_inter-service_communication_patterns.
Overview

Changes

- `predictionengine` resource integration into the model training Flyte workflow

Comments
There are a bunch of linting fixes in this PR, so I'm sure there will be some merge conflicts and rebasing in the future. Also, I haven't tested this yet.