Add initial inference data filtering function#621

Merged
forstmeier merged 5 commits into master from
08-22-add_initial_inference_data_filtering_function
Aug 24, 2025

Conversation


@forstmeier forstmeier commented Aug 22, 2025

Overview

Changes

  • add "preprocessing" inference filtering logic

Comments

I'll manually test this one and these minimum limits are likely going to be updated. UPDATE: this has been manually tested; it trims the training data from 5,829 tickers to 1,453.

Summary by CodeRabbit

  • New Features

    • Configurable equity filtering that returns per‑ticker aggregates and new calendar features with standardized index types.
    • Added dataset validation schema for time-series inputs.
  • Bug Fixes

    • Standardized timestamps to UTC and ensured inputs are not mutated; enforced consistent integer types.
  • Tests

    • Added comprehensive unit tests for the filtering behavior and edge cases.
  • Chores

    • Enabled UV-based packaging for the portfolio manager.


coderabbitai Bot commented Aug 22, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds UV packaging config to pyproject.toml. Updates preprocess.filter_equity_bars to accept configurable thresholds and adds tests. Enhances libraries/python/src/internal/tft_dataset.py to clone and validate inputs with a new dataset_schema, switch timestamps to UTC, generate calendar features, and cast time_idx and new features to Int64.

Changes

Cohort / File(s) Summary
Packaging setup
applications/portfoliomanager/pyproject.toml
Adds [tool.uv] with package = true and src = ["src"]; retains [tool.uv.sources] with internal = { workspace = true }.
Preprocessing utility + tests
applications/portfoliomanager/src/portfoliomanager/preprocess.py, applications/portfoliomanager/tests/test_preprocess.py
Updates filter_equity_bars signature to filter_equity_bars(data: pl.DataFrame, minimum_average_close_price: float = 10.0, minimum_average_volume: float = 1_000_000.0) -> pl.DataFrame; function clones input, groups by ticker, computes per-ticker avg_close_price and avg_volume, filters by thresholds, and returns results. Adds comprehensive pytest coverage for success, boundary, empty, single-row, mixed-ticker, and immutability cases.
TFT dataset validation & feature engineering
libraries/python/src/internal/tft_dataset.py
Adds a module-level dataset_schema (pandera.polars DataFrameSchema) and validates input via dataset_schema.validate(data); clones input to avoid mutation; changes timestamp fill/handling to UTC; computes calendar features (day_of_week, day_of_month, day_of_year, month, year) and casts them and time_idx to Int64; exposes dataset_schema as a public symbol.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant C as Caller
  participant P as preprocess.filter_equity_bars
  C->>P: data (pl.DataFrame)
  activate P
  P->>P: clone input dataframe
  P->>P: group_by("ticker") → agg(avg_close_price, avg_volume)
  P->>P: filter(avg_close_price > minimum_average_close_price and avg_volume > minimum_average_volume)
  P-->>C: filtered per-ticker DataFrame
  deactivate P
sequenceDiagram
  autonumber
  participant U as Upstream Loader
  participant T as TFTDataset.__init__
  participant S as dataset_schema
  participant FE as FeatureEngineering
  participant SC as Scaler

  U->>T: raw data
  activate T
  T->>T: clone data (avoid mutation)
  T->>S: dataset_schema.validate(data)
  S-->>T: validated data
  rect rgba(220,235,255,0.5)
    note right of FE: timezone → UTC\ncompute day_of_week, day_of_month,\nday_of_year, month, year\ncast calendar features & time_idx → Int64
    T->>FE: apply timezone & compute features
    FE-->>T: augmented/cast dataset
  end
  T->>SC: scaling/processing on validated data
  SC-->>T: scaled dataset
  T-->>U: initialized TFTDataset
  deactivate T

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45–70 minutes

Poem

A rabbit taps keys with a soft little grin,
Cloning data lightly, letting checks begin.
UTC moons align, calendar features sprout,
Filters hop through tickers — some in, some out.
Tests clap their paws: packaging purrs, all set to win. 🐇✨



forstmeier commented Aug 22, 2025


Copilot AI left a comment


Pull Request Overview

This PR adds initial data filtering functionality and refactors the TFT dataset processing. It introduces a new filter_equity_bars function to filter financial data based on average close price and volume thresholds, along with schema validation using pandera.

  • Adds filter_equity_bars function with configurable price and volume thresholds
  • Introduces data schema validation using pandera for equity bar data
  • Refactors TFT dataset timezone handling and type casting for consistency

Reviewed Changes

Copilot reviewed 4 out of 5 changed files in this pull request and generated 2 comments.

File Description
libraries/python/src/internal/tft_dataset.py Adds schema validation, changes timezone to UTC, improves type consistency with Int64 casting
applications/portfoliomanager/src/portfoliomanager/preprocess.py Implements new filtering function for equity data based on price and volume thresholds
applications/portfoliomanager/tests/test_preprocess.py Comprehensive test suite covering various filtering scenarios and edge cases
applications/portfoliomanager/pyproject.toml Updates project configuration to support package structure


Comment thread applications/portfoliomanager/src/portfoliomanager/preprocess.py Outdated
Comment thread applications/portfoliomanager/src/portfoliomanager/preprocess.py Outdated

graphite-app Bot commented Aug 22, 2025

Graphite Automations

"Assign author to pull request" took an action on this PR • (08/22/25)

1 assignee was added to this PR based on John Forstmeier's automation.

Comment thread libraries/python/src/internal/tft_dataset.py Outdated
chrisaddy
chrisaddy previously approved these changes Aug 23, 2025
Base automatically changed from 08-22-temporarily_rename_tft_resources_for_grepability to master August 24, 2025 02:00
@forstmeier forstmeier dismissed chrisaddy’s stale review August 24, 2025 02:00

The base branch was changed.

… github.com:pocketsizefund/pocketsizefund into 08-22-add_initial_inference_data_filtering_function

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)

7-9: Make thresholds configurable via parameters (repeat of earlier feedback)

Hardcoded thresholds make experimentation cumbersome. Thread them as optional parameters with sensible defaults to preserve current behavior.

-def filter_equity_bars(data: pl.DataFrame) -> pl.DataFrame:
+def filter_equity_bars(
+    data: pl.DataFrame,
+    minimum_average_close_price: float = 10.0,
+    minimum_average_volume: float = 1_000_000.0,
+    *,
+    strict: bool = True,
+) -> pl.DataFrame:
     data = data.clone()
 
-    minimum_average_close_price = 10.0
-    minimum_average_volume = 1_000_000.0
+    price_col = pl.col("avg_close_price")
+    vol_col = pl.col("avg_volume")
+    price_pred = price_col > minimum_average_close_price if strict else price_col >= minimum_average_close_price
+    vol_pred = vol_col > minimum_average_volume if strict else vol_col >= minimum_average_volume
 
     return (
         data.group_by("ticker")
         .agg(
             avg_close_price=pl.col("close_price").mean(),
             avg_volume=pl.col("volume").mean(),
         )
-        .filter(
-            (pl.col("avg_close_price") > minimum_average_close_price)
-            & (pl.col("avg_volume") > minimum_average_volume)
-        )
+        .filter(price_pred & vol_pred)
     )
🧹 Nitpick comments (2)
applications/portfoliomanager/pyproject.toml (1)

8-11: Packaging enabled: good step; ensure direct runtime deps are declared

Enabling [tool.uv].package = true means portfoliomanager can be built/installed independently. Since portfoliomanager.preprocess imports polars, declare polars as a direct dependency here instead of relying on transitive installs via internal to prevent import errors for downstream consumers.

Apply:

 [project]
 name = "portfoliomanager"
 version = "0.1.0"
 description = "Portfolio prediction and construction service"
-requires-python = "==3.12.10"
-dependencies = ["internal"]
+requires-python = ">=3.12,<3.13"
+dependencies = [
+  "internal",
+  "polars",  # direct runtime dependency used by portfoliomanager.preprocess
+]
applications/portfoliomanager/tests/test_preprocess.py (1)

1-4: Minor nit: prefer DataFrame.height for row counts

len(df) works, but df.height reads clearer and avoids constructing an iterator in some backends. Optional.

Example:

-assert len(result) == 1
+assert result.height == 1
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 987a519 and 4cd84be.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (4)
  • applications/portfoliomanager/pyproject.toml (1 hunks)
  • applications/portfoliomanager/src/portfoliomanager/preprocess.py (1 hunks)
  • applications/portfoliomanager/tests/test_preprocess.py (1 hunks)
  • libraries/python/src/internal/tft_dataset.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)
applications/datamanager/src/datamanager/alpaca_client.py (1)
  • fetch_latest_data (64-170)
libraries/python/src/internal/tft_dataset.py (2)
libraries/python/src/internal/dataset.py (1)
  • TemporalFusionTransformerDataset (28-321)
libraries/python/tests/test_equity_bar.py (9)
  • test_equity_bar_schema_type_coercion (180-198)
  • test_equity_bar_schema_valid_data (7-22)
  • test_equity_bar_schema_zero_timestamp (79-94)
  • test_equity_bar_schema_negative_timestamp (61-76)
  • test_equity_bar_schema_multiple_rows (237-253)
  • test_equity_bar_schema_ticker_uppercase_passes (43-58)
  • test_equity_bar_schema_negative_volume (144-159)
  • test_equity_bar_schema_ticker_lowercase_fails (25-40)
  • test_equity_bar_schema_missing_required_column (201-216)
applications/portfoliomanager/tests/test_preprocess.py (1)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)
  • filter_equity_bars (4-20)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Run Python quality checks
  • GitHub Check: Run Python tests
🔇 Additional comments (9)
applications/portfoliomanager/src/portfoliomanager/preprocess.py (1)

4-20: Solid, side-effect-free filtering; aligns with tests

Cloning the input for immutability and using per-ticker aggregates with strict “greater-than” thresholds is clear and easy to reason about. The returned schema (ticker, avg_close_price, avg_volume) matches the tests.

libraries/python/src/internal/tft_dataset.py (5)

30-31: Good call to clone the input

Defensive copy avoids mutating upstream dataframes during feature engineering.


147-154: Explicit Int64 casts for calendar features improve dtype stability

Casting calendar columns to Int64 ensures deterministic dtypes across platforms and Polars versions.


165-166: Schema validation placed pre-scaling: good

Validating full, engineered-frame dtypes before scaling/mapping will surface data issues early and keeps the scaler’s inputs consistent.


140-143: Timezone change to UTC requires downstream validation

The recent switch in tft_dataset.py from using America/New_York–derived midnight to UTC for timestamp normalization may shift epoch seconds by up to 5 hours (depending on DST), which can impact any downstream logic or consumers built around NY-local dates. Please review the following areas where America/New_York–based timestamps are explicitly used and confirm that their behavior remains correct:

• libraries/python/src/internal/cloud_event.py (lines 16–20, 33–37): event “date” fields generated with ZoneInfo("America/New_York").
• libraries/python/src/internal/dates.py (lines 9–13, 21–25): default_factory and parser .replace(tzinfo=ZoneInfo("America/New_York")).
• libraries/python/tests/test_dates.py (lines 11–15): tests mocking datetime.now(tz=ZoneInfo("America/New_York")).
• applications/models/src/models/combine_tft_data_sources.py (lines 40–44): output filename timestamp uses NY timezone.
• applications/models/src/models/train_tft_model.py (lines 20–30, 44–53, 87–95, 100–108, 123–131, 141–142): logging runtimes and run names all use ZoneInfo("America/New_York").
• applications/models/src/models/get_alpaca_equity_bars.py (lines 61–63): end = datetime.now(tz=ZoneInfo("America/New_York")).
• applications/datamanager/src/datamanager/alpaca_client.py (lines 102–114): converts naive timestamps to UTC then compares dates in NY to determine bar grouping.
• applications/datamanager/tests/test_alpaca_client.py (lines 36–44, 72–76, 130–138): tests assume bar timestamps at NY 16:00 local.

Double-check each consumer’s use of epoch seconds or date comparisons to ensure no unintended shifts occur now that tft_dataset uses UTC normalization.


156-163: time_idx bump to Int64 is safe — no Int32 consumers found

I’ve verified that:

  • The only definitions and uses of time_idx are within tft_dataset.py, and it’s now consistently cast to pl.Int64.
  • The Pandera schema defines "time_idx": pa.Column(int, required=True), which maps to Python’s int (i.e. NumPy/Pandas int64) rather than a 32-bit type.

No other code paths serialize or consume time_idx as a 32-bit integer. You can safely keep the Int64 change.

applications/portfoliomanager/tests/test_preprocess.py (3)

6-25: Tests hit the intended grouping/aggregation behavior

Covers strict “greater-than” semantics and validates both aggregate values and output schema. Nice.


73-114: Boundary and near-boundary coverage is thorough

Exact-threshold exclusion and “just above” inclusion guard against off-by-one mistakes.


197-214: Immutability test is valuable

Confirms the clone pattern is honored and protects callers from accidental mutation.

Comment thread libraries/python/src/internal/tft_dataset.py Outdated

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
libraries/python/src/internal/tft_dataset.py (2)

103-121: Bug: weekday/weekend holiday mapping is inverted for null rows

The logic sets is_holiday=True on weekdays (Mon–Fri) and False on weekends for rows where is_holiday is null. That’s the opposite of typical market calendars. This will mislabel most business days as holidays and weekends as non-holidays, skewing categorical signals and training targets.

Proposed fix (invert the branches):

-                .then(True)  # noqa: FBT003
+                .then(False)  # weekdays are not holidays
 ...
-                .then(False)  # noqa: FBT003
+                .then(True)   # weekends are holidays

If you intend “holiday” to mean “non-trading day,” this inversion is required. If you intend “holiday” to specifically mean exchange-observed holidays (subset of weekdays), consider driving this from an exchange calendar (e.g., pandas-market-calendars) rather than heuristics.


126-145: Avoid filling OHLC/VWAP with zeros; prefer forward-fill (per ticker) and zero volume only on non-trading days

Filling prices with 0.0 creates non-physical values and teaches the model that prices collapse to zero on missing days. Forward-filling prices for non-trading days is a better default; volume can be zero on those days.

Example refactor (preserves your null-timestamp fallback; forward-fills per ticker after sorting):

-        data = data.with_columns(
-            [
-                pl.col("open_price").fill_null(0.0),
-                pl.col("high_price").fill_null(0.0),
-                pl.col("low_price").fill_null(0.0),
-                pl.col("close_price").fill_null(0.0),
-                pl.col("volume").fill_null(0.0),
-                pl.col("volume_weighted_average_price").fill_null(0.0),
-                pl.col("sector").fill_null("Not Available"),
-                pl.col("industry").fill_null("Not Available"),
-                pl.col("ticker").fill_null("UNKNOWN"),
-                pl.col("timestamp").fill_null(
-                    pl.col("date")
-                    .cast(pl.Datetime)
-                    .dt.replace_time_zone("UTC")
-                    .cast(pl.Int64)
-                    .floordiv(1000)
-                ),
-            ]
-        )
+        data = data.sort(["ticker", "date"]).with_columns(
+            [
+                # forward-fill prices within each ticker
+                pl.col("open_price").fill_null(strategy="forward").over("ticker"),
+                pl.col("high_price").fill_null(strategy="forward").over("ticker"),
+                pl.col("low_price").fill_null(strategy="forward").over("ticker"),
+                pl.col("close_price").fill_null(strategy="forward").over("ticker"),
+                # zero volume on non-trading days (still nulls after join)
+                pl.col("volume").fill_null(0.0),
+                pl.col("volume_weighted_average_price").fill_null(
+                    strategy="forward"
+                ).over("ticker"),
+                pl.col("sector").fill_null("Not Available"),
+                pl.col("industry").fill_null("Not Available"),
+                pl.col("ticker").fill_null("UNKNOWN"),
+                pl.col("timestamp").fill_null(
+                    pl.col("date")
+                    .cast(pl.Datetime)
+                    .dt.replace_time_zone("UTC")
+                    .dt.epoch("ms")
+                ),
+            ]
+        )

This prevents artificial price shocks while keeping volume semantics reasonable.

🧹 Nitpick comments (4)
libraries/python/src/internal/tft_dataset.py (4)

140-143: Use dt.epoch('ms') for clarity and unit safety when deriving fallback timestamps

Casting to Int64 and dividing by 1,000 relies on the engine’s default time unit. dt.epoch('ms') makes the intent explicit and robust.

-                    .cast(pl.Int64)
-                    .floordiv(1000)
+                    .dt.epoch("ms")

165-165: Good: validate with schema before scaling

Validating prior to scaling prevents dirty data from contaminating the scaler. Consider strict=True (see next comment) if you want to forbid extra columns creeping in.


330-382: Tighten schema: set strict=True and align with upstream equity schema to avoid drift

  • Add strict=True to reject unexpected columns.
  • Optional: add ordered=False explicitly for clarity (defaults are fine).
  • Ensure this schema stays in sync with any upstream equity_bar_schema to avoid silent divergence in constraints.

Proposed minimal change:

-dataset_schema = pa.DataFrameSchema(
-    {
+dataset_schema = pa.DataFrameSchema(
+    {
         # columns...
-    }
-)
+    },
+    strict=True,
+)

If upstream uses different dtypes (e.g., volume as float), decide on one canonical schema and convert at the ingestion boundary.


86-101: Scope of backfilling: avoid creating rows outside each ticker’s active date range

Using global min/max dates and a cross join will create rows for dates where a ticker didn’t exist/list—later filled by forward-fill/zeros—introducing long tails of synthetic data. Prefer per-ticker date ranges:

  • Compute per-ticker min/max dates.
  • Join each ticker to its own date range.

If you’d like, I can draft a per-ticker date_range construction using Polars’ group-wise operations.

Also applies to: 97-101, 123-123

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 4cd84be and b9d2f20.

📒 Files selected for processing (2)
  • applications/portfoliomanager/src/portfoliomanager/preprocess.py (1 hunks)
  • libraries/python/src/internal/tft_dataset.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • applications/portfoliomanager/src/portfoliomanager/preprocess.py
🧰 Additional context used
🧬 Code graph analysis (1)
libraries/python/src/internal/tft_dataset.py (2)
libraries/python/src/internal/dataset.py (3)
  • TemporalFusionTransformerDataset (28-321)
  • __init__ (29-188)
  • Scaler (10-25)
libraries/python/tests/test_dataset.py (1)
  • test_dataset_load_data (5-45)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Run Python tests
🔇 Additional comments (4)
libraries/python/src/internal/tft_dataset.py (4)

1-3: Imports look good; Polars + Pandera(polars) + date type are correctly wired

Nothing to change here; the date import is needed for the schema and pandera.polars is the right module.


30-31: Nice: cloning input avoids side effects

Cloning data up front prevents subtle caller mutations. Good call.


147-154: Confirm Zero-Based day_of_week Representation

I’ve audited all occurrences of dt.weekday() and the day_of_week column in the repo—both reside in tft_dataset.py:

  • libraries/python/src/internal/tft_dataset.py
    • Lines 103–107: computes a temporary weekday via pl.col("date").dt.weekday().alias("temporary_weekday")
    • Lines 147–152: assigns day_of_week = pl.col("date").dt.weekday().alias("day_of_week")

No offsets (+1) or alternative conventions were found elsewhere. If your downstream embeddings or persisted models expect Monday = 1…Sunday = 7, please adjust these to dt.weekday() + 1. Otherwise, consider adding a brief doc comment (in the function docstring or module README) to clarify that day_of_week is zero-based, ensuring consistency across training and inference.


156-163: time_idx indexing is acceptable as-is

After reviewing the downstream usage of time_idx—sorting by it in the batching logic (around line 295) and declaring it in the schema (around line 380)—there are no existing consumers expecting a 0-based index. If you prefer 0-based numbering for array-style consistency, you can adjust the rank:

- .rank("dense")
+ .rank("dense") - 1

However, since no parts of the code assume 0-based indexing, the current 1-based implementation is fine.

@forstmeier forstmeier merged commit 5df48e6 into master Aug 24, 2025
4 checks passed
@forstmeier forstmeier deleted the 08-22-add_initial_inference_data_filtering_function branch August 24, 2025 02:24