Skip to content

Commit

Permalink
More robust DenseMatrix._get_col_stds (#436)
Browse files Browse the repository at this point in the history
* Add comment to stat_mat.standardize()

* More conservative approach to compute standard deviations.

* Format

* Update changelog.

* Add example to CI.

* try something

* Run ci on pushes on prs

* Update .github/workflows/ci.yml

Co-authored-by: Martin Stancsics <martin.stancsics@gmail.com>

* Add test.

* Test the actual problem.

* Format

* Update precision to sqrt(eps)

* Format

* Remove Jan's example file.

---------

Co-authored-by: Jan Tilly <jan.tilly@quantco.com>
Co-authored-by: Martin Stancsics <martin.stancsics@gmail.com>
3 people authored Jan 29, 2025

Verified

This commit was signed with the committer’s verified signature.
sauclovian-g David Holland
1 parent 6fe2703 commit 04a6f68
Showing 6 changed files with 56 additions and 7 deletions.
13 changes: 12 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
name: CI
on: [push]

on:
# We would like to trigger CI for any pull request action -
# both from QuantCo's branches as well as forks.
pull_request:
# In addition to pull requests, we want to run CI for pushes
# to the main branch and tags.
push:
branches:
- "main"
tags:
- "*"

jobs:
pre-commit-checks:
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -10,6 +10,10 @@ Changelog
Unreleased
----------

**Bug fix:**

- A more robust :meth:`DenseMatrix._get_col_stds` results in more accurate :meth:`StandardizedMatrix.sandwich` results.

**Other changes:**

- Build wheel for pypi on python 3.13.
4 changes: 2 additions & 2 deletions src/tabmat/dense_matrix.py
Original file line number Diff line number Diff line change
@@ -164,8 +164,8 @@ def _cross_sandwich(
raise TypeError

def _get_col_stds(self, weights: np.ndarray, col_means: np.ndarray) -> np.ndarray:
"""Get standard deviations of columns."""
sqrt_arg = transpose_square_dot_weights(self._array, weights) - col_means**2
"""Get standard deviations of columns using weights `weights`."""
sqrt_arg = transpose_square_dot_weights(self._array, weights, col_means)
# Minor floating point errors above can result in a very slightly
# negative sqrt_arg (e.g. -5e-16). We just set those values equal to
# zero.
6 changes: 3 additions & 3 deletions src/tabmat/ext/dense.pyx
Original file line number Diff line number Diff line change
@@ -100,7 +100,7 @@ def dense_matvec(np.ndarray X, floating[:] v, int[:] rows, int[:] cols):
raise Exception("The matrix X is not contiguous.")
return out

def transpose_square_dot_weights(np.ndarray X, floating[:] weights):
def transpose_square_dot_weights(np.ndarray X, floating[:] weights, floating[:] shift):
cdef floating* Xp = <floating*>X.data
cdef int nrows = weights.shape[0]
cdef int ncols = X.shape[1]
@@ -112,11 +112,11 @@ def transpose_square_dot_weights(np.ndarray X, floating[:] weights):
if X.flags["C_CONTIGUOUS"]:
for j in prange(ncols, nogil=True):
for i in range(nrows):
outp[j] = outp[j] + weights[i] * (Xp[i * ncols + j] ** 2)
outp[j] = outp[j] + weights[i] * ((Xp[i * ncols + j] - shift[j]) ** 2)
elif X.flags["F_CONTIGUOUS"]:
for j in prange(ncols, nogil=True):
for i in range(nrows):
outp[j] = outp[j] + weights[i] * (Xp[j * nrows + i] ** 2)
outp[j] = outp[j] + weights[i] * ((Xp[j * nrows + i] - shift[j]) ** 2)
else:
raise Exception("The matrix X is not contiguous.")
return out
10 changes: 9 additions & 1 deletion src/tabmat/standardized_mat.py
Original file line number Diff line number Diff line change
@@ -130,7 +130,15 @@ def sandwich(
if not hasattr(d, "dtype"):
d = np.asarray(d)
check_sandwich_compatible(self, d)

# stat_mat = mat * mult[newaxis, :] + shift[newaxis, :]
# stat_mat.T @ d[:, newaxis] * stat_mat
# = mult[:, newaxis] * mat.T @ d[:, newaxis] * mat * mult[newaxis, :] + (1)
# mult[:, newaxis] * mat.T @ d[:, newaxis] * np.outer(ones, shift) + (2)
# shift[:, newaxis] @ d[:, newaxis] * mat * mult[newaxis, :] + (3)
# shift[:, newaxis] @ d[:, newaxis] * shift[newaxis, :] (4)
#
# (1) = self.mat.sandwich(d) * np.outer(limited_mult, limited_mult)
# (2) = mult * self.transpose_matvec(d) * shift[newaxis, :]
if rows is not None or cols is not None:
setup_rows, setup_cols = setup_restrictions(self.shape, rows, cols)
if rows is not None:
26 changes: 26 additions & 0 deletions tests/test_matrices.py
Original file line number Diff line number Diff line change
@@ -813,3 +813,29 @@ def test_combine_names(mat_1, mat_2):

assert combined.column_names == mat_1.column_names + mat_2.column_names
assert combined.term_names == mat_1.term_names + mat_2.term_names


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_dense_matrix_get_col_stds(dtype):
    """Compare the column stds from ``DenseMatrix.standardize`` with NumPy.

    Regression test for https://github.com/Quantco/tabmat/issues/414.
    The first column of the fixture is almost constant (values differ only
    from the fourth decimal place onwards), while the other two columns vary
    widely. The standard deviations returned by ``standardize`` and the
    ``mult`` attribute of the standardized matrix are checked against
    ``np.std`` and its reciprocal, respectively.
    """
    raw_rows = [
        [46.231056, 126.05263, 144.46439],
        [46.231224, 128.66818, 0.7667693],
        [46.231186, 104.97506, 193.8872],
        [46.230835, 130.10156, 143.88954],
        [46.230896, 116.76007, 7.5629334],
    ]
    X = np.array(raw_rows, dtype=dtype)

    # Uniform weights that sum to one, in the dtype under test.
    n_rows = X.shape[0]
    weights = np.full(n_rows, 1 / n_rows, dtype=dtype)

    dense = tm.DenseMatrix(X)
    standardized_mat, _, col_stds = dense.standardize(
        weights, center_predictors=True, scale_predictors=True
    )

    # Population standard deviation (ddof=0) per column as the reference.
    expected_stds = np.std(X, axis=0, ddof=0)

    # sqrt since std = sqrt(var)
    tolerance = np.sqrt(np.finfo(dtype).eps)

    np.testing.assert_allclose(col_stds, expected_stds, rtol=tolerance)
    np.testing.assert_allclose(
        standardized_mat.mult, 1 / expected_stds, rtol=tolerance
    )

0 comments on commit 04a6f68

Please sign in to comment.