Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions duckreg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
.. include:: ../README.md
"""
from .estimators import (
DuckRegression,
DuckMundlak,
DuckMundlakEventStudy,
)
from .regularized import DuckRidge
from .stream import StreamingRegression
150 changes: 150 additions & 0 deletions duckreg/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
Streaming regression leveraging DuckDB's native Arrow IPC support.
"""

import numpy as np
import duckdb
from dataclasses import dataclass
from typing import Optional, Iterator


@dataclass
class RegressionStats:
"""Sufficient statistics for streaming regression."""

XtX: Optional[np.ndarray] = None
Xty: Optional[np.ndarray] = None
yty: float = 0.0
n: int = 0
k: Optional[int] = None

def update(self, X: np.ndarray, y: np.ndarray) -> "RegressionStats":
"""Update statistics with new batch."""
n_batch, k_batch = X.shape

if self.XtX is None:
self.k = k_batch
self.XtX = np.zeros((k_batch, k_batch))
self.Xty = np.zeros(k_batch)

self.XtX += X.T @ X
self.Xty += X.T @ y
self.yty += y @ y
self.n += n_batch
return self

def solve_ols(self) -> np.ndarray:
"""Compute OLS coefficients."""
if self.n < self.k:
Copy link

Copilot AI Sep 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition self.n < self.k checks if there are fewer observations than features, but self.k could be None if no data has been processed yet. This will raise a TypeError when comparing int with NoneType.

Suggested change
if self.n < self.k:
if self.k is None or self.n < self.k:

Copilot uses AI. Check for mistakes.
return None
self.check_condition()
return np.linalg.solve(self.XtX, self.Xty)

def solve_ridge(self, lambda_: float = 1.0) -> np.ndarray:
"""Compute Ridge coefficients."""
if self.XtX is None:
return None
XtX_reg = self.XtX + lambda_ * np.eye(self.k)
return np.linalg.solve(XtX_reg, self.Xty)

def check_condition(self, threshold: float = 1e10):
"""Check the condition number of the XtX matrix."""
if self.XtX is None:
return None
cond = np.linalg.cond(self.XtX)
if cond > threshold:
import warnings

warnings.warn(
f"High condition number: {cond:.2e}. Consider using Ridge regression."
)
return cond


class DuckDBArrowStream:
"""
Stream data from DuckDB using native Arrow IPC support.
"""

def __init__(
self,
conn: duckdb.DuckDBPyConnection,
query: str,
chunk_size: int = 10000,
feature_cols: list[str] = None,
target_col: str = None,
Comment on lines +74 to +75
Copy link

Copilot AI Sep 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using mutable default arguments (None) should be replaced with proper Optional type annotations and default to None explicitly. The type hints should be Optional[list[str]] and Optional[str].

Suggested change
feature_cols: list[str] = None,
target_col: str = None,
feature_cols: Optional[list[str]] = None,
target_col: Optional[str] = None,

Copilot uses AI. Check for mistakes.
):
self.conn = conn
self.query = query
self.chunk_size = chunk_size
self.feature_cols = feature_cols
self.target_col = target_col

def __iter__(self) -> Iterator[tuple[np.ndarray, np.ndarray]]:
"""Stream data in chunks using DuckDB's Arrow support."""
result = self.conn.execute(self.query)

while True:
arrow_chunk = result.fetch_arrow_table(self.chunk_size)

if arrow_chunk is None or arrow_chunk.num_rows == 0:
break

if self.feature_cols is None:
self.feature_cols = sorted(
[col for col in arrow_chunk.column_names if col.startswith("x")]
)

if self.target_col is None:
self.target_col = "y"

X = np.column_stack(
[arrow_chunk[col].to_numpy() for col in self.feature_cols]
)
y = arrow_chunk[self.target_col].to_numpy()

yield (X, y)


class StreamingRegression:
"""
Streaming regression for duckreg using sufficient statistics.
Leverages DuckDB's native Arrow IPC support.
"""

def __init__(
self, conn: duckdb.DuckDBPyConnection, query: str, chunk_size: int = 10000
):
self.conn = conn
self.query = query
self.chunk_size = chunk_size
self.stats = RegressionStats()

def fit(self, feature_cols: list[str], target_col: str):
"""
Perform streaming regression.
"""
stream = DuckDBArrowStream(
self.conn, self.query, self.chunk_size, feature_cols, target_col
)
for X, y in stream:
self.stats.update(X, y)
return self

def solve_ols(self):
"""
Solve OLS regression.
"""
return self.stats.solve_ols()

def solve_ridge(self, lambda_: float = 1.0):
"""
Solve Ridge regression.
"""
return self.stats.solve_ridge(lambda_)

@classmethod
def from_table(cls, conn: duckdb.DuckDBPyConnection, table_name: str, **kwargs):
"""Create a StreamingRegression instance from a table name."""
query = f"SELECT * FROM {table_name}"
return cls(conn, query, **kwargs)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ tqdm
duckdb
numba
pdoc
pyarrow
47 changes: 47 additions & 0 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import duckdb
import numpy as np
import pytest
from duckreg.stream import StreamingRegression


@pytest.fixture
def duckdb_conn():
"""Create an in-memory DuckDB connection."""
conn = duckdb.connect(':memory:')
yield conn
conn.close()


def test_streaming_regression(duckdb_conn):
"""Test streaming regression with a simple example."""
# Create sample data
duckdb_conn.execute("""
CREATE TABLE regression_data AS
WITH features AS (
SELECT
random() as x0,
random() as x1,
random() as x2
FROM generate_series(1, 100000) t(i)
)
SELECT
x0,
x1,
x2,
2.0*x0 - 1.5*x1 + 0.8*x2 + 0.1*random() as y
FROM features
""")

# Perform streaming regression
stream_reg = StreamingRegression.from_table(duckdb_conn, "regression_data")
stream_reg.fit(feature_cols=["x0", "x1", "x2"], target_col="y")
beta = stream_reg.solve_ols()

# Check the results
true_beta = np.array([2.0, -1.5, 0.8])
assert np.allclose(beta, true_beta, atol=0.1)

# Check that the condition number warning is raised
with pytest.warns(UserWarning, match='High condition number'):
stream_reg.stats.XtX = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
stream_reg.stats.check_condition()