From 5419a5756b0c579559c0052a810aeaf06de8d5dd Mon Sep 17 00:00:00 2001
From: Apoorva Lal
Date: Wed, 3 Sep 2025 20:56:57 -0700
Subject: [PATCH 1/2] add streaming OLS implementation with duckdb arrow interop

---
 duckreg/__init__.py  |   7 ++
 duckreg/stream.py    | 167 +++++++++++++++++++++++++++++++++++++++++++
 tests/test_stream.py |  47 ++++++++++++
 3 files changed, 221 insertions(+)
 create mode 100644 duckreg/stream.py
 create mode 100644 tests/test_stream.py

diff --git a/duckreg/__init__.py b/duckreg/__init__.py
index ae675c2..40133a2 100644
--- a/duckreg/__init__.py
+++ b/duckreg/__init__.py
@@ -1,3 +1,10 @@
 """
 .. include:: ../README.md
 """
+from .estimators import (
+    DuckRegression,
+    DuckMundlak,
+    DuckMundlakEventStudy,
+)
+from .regularized import DuckRidge
+from .stream import StreamingRegression
diff --git a/duckreg/stream.py b/duckreg/stream.py
new file mode 100644
index 0000000..efb7ddc
--- /dev/null
+++ b/duckreg/stream.py
@@ -0,0 +1,167 @@
+"""
+Streaming regression leveraging DuckDB's native Arrow IPC support.
+"""
+
+import warnings
+from dataclasses import dataclass
+from typing import Iterator, Optional
+
+import duckdb
+import numpy as np
+
+
+@dataclass
+class RegressionStats:
+    """Sufficient statistics (X'X, X'y, y'y, n) for streaming regression."""
+
+    XtX: Optional[np.ndarray] = None
+    Xty: Optional[np.ndarray] = None
+    yty: float = 0.0
+    n: int = 0
+    k: Optional[int] = None
+
+    def update(self, X: np.ndarray, y: np.ndarray) -> "RegressionStats":
+        """Accumulate one batch (X is n_batch x k, y has length n_batch)."""
+        n_batch, k_batch = X.shape
+
+        if self.XtX is None:
+            # Allocate lazily: k is only known once the first batch arrives.
+            self.k = k_batch
+            self.XtX = np.zeros((k_batch, k_batch))
+            self.Xty = np.zeros(k_batch)
+
+        self.XtX += X.T @ X
+        self.Xty += X.T @ y
+        self.yty += y @ y
+        self.n += n_batch
+        return self
+
+    def solve_ols(self) -> Optional[np.ndarray]:
+        """Return OLS coefficients, or None if no batch was seen or n < k."""
+        # Test XtX first: before any update() k is None, and `n < k` would
+        # raise TypeError instead of reporting "not solvable".
+        if self.XtX is None or self.n < self.k:
+            return None
+        self.check_condition()
+        return np.linalg.solve(self.XtX, self.Xty)
+
+    def solve_ridge(self, lambda_: float = 1.0) -> Optional[np.ndarray]:
+        """Return Ridge coefficients for penalty lambda_, or None if unfitted."""
+        if self.XtX is None:
+            return None
+        XtX_reg = self.XtX + lambda_ * np.eye(self.k)
+        return np.linalg.solve(XtX_reg, self.Xty)
+
+    def check_condition(self, threshold: float = 1e10) -> Optional[float]:
+        """Return cond(XtX) (None if unfitted); warn when it exceeds threshold."""
+        if self.XtX is None:
+            return None
+        cond = np.linalg.cond(self.XtX)
+        if cond > threshold:
+            warnings.warn(
+                f"High condition number: {cond:.2e}. Consider using Ridge regression."
+            )
+        return cond
+
+
+class DuckDBArrowStream:
+    """
+    Stream (X, y) batches from a DuckDB query via Arrow record batches.
+    """
+
+    def __init__(
+        self,
+        conn: duckdb.DuckDBPyConnection,
+        query: str,
+        chunk_size: int = 10000,
+        feature_cols: Optional[list[str]] = None,
+        target_col: Optional[str] = None,
+    ):
+        self.conn = conn
+        self.query = query
+        self.chunk_size = chunk_size
+        self.feature_cols = feature_cols
+        self.target_col = target_col
+
+    def __iter__(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
+        """Yield one (X, y) pair of numpy arrays per non-empty record batch."""
+        # fetch_arrow_table() materialises the whole result as a single table
+        # (its argument only sets the internal batch size), so read record
+        # batches from a reader instead to keep memory bounded.
+        reader = self.conn.execute(self.query).fetch_record_batch(self.chunk_size)
+
+        if self.feature_cols is None:
+            # Default: every column whose name starts with "x", sorted by name.
+            self.feature_cols = sorted(
+                name for name in reader.schema.names if name.startswith("x")
+            )
+        if self.target_col is None:
+            self.target_col = "y"
+        if not self.feature_cols:
+            raise ValueError("no feature columns to regress on")
+
+        for batch in reader:
+            if batch.num_rows == 0:
+                continue
+            # zero_copy_only=False: Arrow arrays that contain nulls cannot be
+            # exposed zero-copy and would otherwise raise.
+            X = np.column_stack(
+                [
+                    batch.column(name).to_numpy(zero_copy_only=False)
+                    for name in self.feature_cols
+                ]
+            )
+            y = batch.column(self.target_col).to_numpy(zero_copy_only=False)
+            yield X, y
+
+
+class StreamingRegression:
+    """
+    Streaming regression for duckreg using sufficient statistics.
+    Leverages DuckDB's native Arrow IPC support.
+    """
+
+    def __init__(
+        self, conn: duckdb.DuckDBPyConnection, query: str, chunk_size: int = 10000
+    ):
+        self.conn = conn
+        self.query = query
+        self.chunk_size = chunk_size
+        self.stats = RegressionStats()
+
+    def fit(self, feature_cols: list[str], target_col: str) -> "StreamingRegression":
+        """
+        Run the query and fold every batch into self.stats; returns self.
+
+        Statistics accumulate across calls, so calling fit() twice counts the
+        data twice.
+        """
+        stream = DuckDBArrowStream(
+            self.conn, self.query, self.chunk_size, feature_cols, target_col
+        )
+        for X, y in stream:
+            self.stats.update(X, y)
+        return self
+
+    def solve_ols(self) -> Optional[np.ndarray]:
+        """
+        Solve OLS regression (None if there is not enough data).
+        """
+        return self.stats.solve_ols()
+
+    def solve_ridge(self, lambda_: float = 1.0) -> Optional[np.ndarray]:
+        """
+        Solve Ridge regression (None if no data has been seen).
+        """
+        return self.stats.solve_ridge(lambda_)
+
+    @classmethod
+    def from_table(cls, conn: duckdb.DuckDBPyConnection, table_name: str, **kwargs):
+        """
+        Create a StreamingRegression instance from a table name.
+
+        NOTE: table_name is interpolated into the SQL text verbatim, so it
+        must come from trusted code, never from user input.
+        """
+        query = f"SELECT * FROM {table_name}"
+        return cls(conn, query, **kwargs)
diff --git a/tests/test_stream.py b/tests/test_stream.py
new file mode 100644
index 0000000..ab912ac
--- /dev/null
+++ b/tests/test_stream.py
@@ -0,0 +1,47 @@
+import duckdb
+import numpy as np
+import pytest
+from duckreg.stream import StreamingRegression
+
+
+@pytest.fixture
+def duckdb_conn():
+    """Create an in-memory DuckDB connection."""
+    conn = duckdb.connect(':memory:')
+    yield conn
+    conn.close()
+
+
+def test_streaming_regression(duckdb_conn):
+    """Test streaming regression with a simple example."""
+    # Create sample data
+    duckdb_conn.execute("""
+        CREATE TABLE regression_data AS
+        WITH features AS (
+            SELECT
+                random() as x0,
+                random() as x1,
+                random() as x2
+            FROM generate_series(1, 100000) t(i)
+        )
+        SELECT
+            x0,
+            x1,
+            x2,
+            2.0*x0 - 1.5*x1 + 0.8*x2 + 0.1*random() as y
+        FROM features
+    """)
+
+    # Perform streaming regression
+    stream_reg = StreamingRegression.from_table(duckdb_conn, "regression_data")
+    stream_reg.fit(feature_cols=["x0", "x1", "x2"], target_col="y")
+    beta = stream_reg.solve_ols()
+
+    # Check the results
+    true_beta = np.array([2.0, -1.5, 0.8])
+    assert np.allclose(beta, true_beta, atol=0.1)
+
+    # Check that the condition number warning is raised
+    with pytest.warns(UserWarning, match='High condition number'):
+        stream_reg.stats.XtX = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
+        stream_reg.stats.check_condition()

From e8b60c6f55a11973fe7d901a3af5fb4b66980690 Mon Sep 17 00:00:00 2001
From: Apoorva Lal
Date: Wed, 3 Sep 2025 21:34:21 -0700
Subject: [PATCH 2/2] add arrow dep

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index c19b5ac..1e5b3d9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,4 @@ tqdm
 duckdb
 numba
 pdoc
+pyarrow