-
Couldn't load subscription status.
- Fork 5
add streaming OLS implementation with duckdb arrow interop #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,10 @@ | ||
| """ | ||
| .. include:: ../README.md | ||
| """ | ||
| from .estimators import ( | ||
| DuckRegression, | ||
| DuckMundlak, | ||
| DuckMundlakEventStudy, | ||
| ) | ||
| from .regularized import DuckRidge | ||
| from .stream import StreamingRegression |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,150 @@ | ||||||||||
| """ | ||||||||||
| Streaming regression leveraging DuckDB's native Arrow IPC support. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| import numpy as np | ||||||||||
| import duckdb | ||||||||||
| from dataclasses import dataclass | ||||||||||
| from typing import Optional, Iterator | ||||||||||
|
|
||||||||||
|
|
||||||||||
| @dataclass | ||||||||||
| class RegressionStats: | ||||||||||
| """Sufficient statistics for streaming regression.""" | ||||||||||
|
|
||||||||||
| XtX: Optional[np.ndarray] = None | ||||||||||
| Xty: Optional[np.ndarray] = None | ||||||||||
| yty: float = 0.0 | ||||||||||
| n: int = 0 | ||||||||||
| k: Optional[int] = None | ||||||||||
|
|
||||||||||
| def update(self, X: np.ndarray, y: np.ndarray) -> "RegressionStats": | ||||||||||
| """Update statistics with new batch.""" | ||||||||||
| n_batch, k_batch = X.shape | ||||||||||
|
|
||||||||||
| if self.XtX is None: | ||||||||||
| self.k = k_batch | ||||||||||
| self.XtX = np.zeros((k_batch, k_batch)) | ||||||||||
| self.Xty = np.zeros(k_batch) | ||||||||||
|
|
||||||||||
| self.XtX += X.T @ X | ||||||||||
| self.Xty += X.T @ y | ||||||||||
| self.yty += y @ y | ||||||||||
| self.n += n_batch | ||||||||||
| return self | ||||||||||
|
|
||||||||||
| def solve_ols(self) -> np.ndarray: | ||||||||||
| """Compute OLS coefficients.""" | ||||||||||
| if self.n < self.k: | ||||||||||
| return None | ||||||||||
| self.check_condition() | ||||||||||
| return np.linalg.solve(self.XtX, self.Xty) | ||||||||||
|
|
||||||||||
| def solve_ridge(self, lambda_: float = 1.0) -> np.ndarray: | ||||||||||
| """Compute Ridge coefficients.""" | ||||||||||
| if self.XtX is None: | ||||||||||
| return None | ||||||||||
| XtX_reg = self.XtX + lambda_ * np.eye(self.k) | ||||||||||
| return np.linalg.solve(XtX_reg, self.Xty) | ||||||||||
|
|
||||||||||
| def check_condition(self, threshold: float = 1e10): | ||||||||||
| """Check the condition number of the XtX matrix.""" | ||||||||||
| if self.XtX is None: | ||||||||||
| return None | ||||||||||
| cond = np.linalg.cond(self.XtX) | ||||||||||
| if cond > threshold: | ||||||||||
| import warnings | ||||||||||
apoorvalal marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
|
||||||||||
| warnings.warn( | ||||||||||
| f"High condition number: {cond:.2e}. Consider using Ridge regression." | ||||||||||
| ) | ||||||||||
| return cond | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class DuckDBArrowStream: | ||||||||||
| """ | ||||||||||
| Stream data from DuckDB using native Arrow IPC support. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| def __init__( | ||||||||||
| self, | ||||||||||
| conn: duckdb.DuckDBPyConnection, | ||||||||||
| query: str, | ||||||||||
| chunk_size: int = 10000, | ||||||||||
| feature_cols: list[str] = None, | ||||||||||
| target_col: str = None, | ||||||||||
|
Comment on lines
+74
to
+75
|
||||||||||
| feature_cols: list[str] = None, | |
| target_col: str = None, | |
| feature_cols: Optional[list[str]] = None, | |
| target_col: Optional[str] = None, |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,3 +4,4 @@ tqdm | |
| duckdb | ||
| numba | ||
| pdoc | ||
| pyarrow | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| import duckdb | ||
| import numpy as np | ||
| import pytest | ||
| from duckreg.stream import StreamingRegression | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def duckdb_conn(): | ||
| """Create an in-memory DuckDB connection.""" | ||
| conn = duckdb.connect(':memory:') | ||
| yield conn | ||
| conn.close() | ||
|
|
||
|
|
||
| def test_streaming_regression(duckdb_conn): | ||
| """Test streaming regression with a simple example.""" | ||
| # Create sample data | ||
| duckdb_conn.execute(""" | ||
| CREATE TABLE regression_data AS | ||
| WITH features AS ( | ||
| SELECT | ||
| random() as x0, | ||
| random() as x1, | ||
| random() as x2 | ||
| FROM generate_series(1, 100000) t(i) | ||
| ) | ||
| SELECT | ||
| x0, | ||
| x1, | ||
| x2, | ||
| 2.0*x0 - 1.5*x1 + 0.8*x2 + 0.1*random() as y | ||
| FROM features | ||
| """) | ||
|
|
||
| # Perform streaming regression | ||
| stream_reg = StreamingRegression.from_table(duckdb_conn, "regression_data") | ||
| stream_reg.fit(feature_cols=["x0", "x1", "x2"], target_col="y") | ||
| beta = stream_reg.solve_ols() | ||
|
|
||
| # Check the results | ||
| true_beta = np.array([2.0, -1.5, 0.8]) | ||
| assert np.allclose(beta, true_beta, atol=0.1) | ||
|
|
||
| # Check that the condition number warning is raised | ||
| with pytest.warns(UserWarning, match='High condition number'): | ||
| stream_reg.stats.XtX = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) | ||
| stream_reg.stats.check_condition() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The condition `self.n < self.k` checks if there are fewer observations than features, but `self.k` could be None if no data has been processed yet. This will raise a TypeError when comparing int with NoneType.