11 changes: 11 additions & 0 deletions conftest.py
@@ -13,6 +13,7 @@
# limitations under the License.


import numpy as np
from pyspark.sql import SparkSession
import pytest
import os
@@ -50,6 +51,16 @@ def spark(spark_builder):
sess = spark_builder.getOrCreate()
return sess.newSession()

def pytest_addoption(parser):
parser.addoption('--random-seed', action='store', type=int, help='Seed to use for random number generator')

@pytest.fixture(scope="function")
def rg(pytestconfig):
seed = pytestconfig.getoption('random_seed')
seed_seq = np.random.SeedSequence(seed)
print(f'Creating random number generator with seed {seed_seq.entropy}')
return np.random.default_rng(seed_seq)
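
For context, any test that takes `rg` as an argument receives a seeded `numpy.random.Generator`; the entropy value printed at creation time can be passed back via `--random-seed` to replay a failing run. A minimal sketch (this test is hypothetical, not part of the change):

def test_simulated_phenotypes(rg):
    # rg is seeded from --random-seed, or from OS entropy when the option is omitted
    y = rg.standard_normal(100)
    assert y.shape == (100,)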

def pytest_runtest_setup(item):
min_spark_version = next((mark.args[0] for mark in item.iter_markers(name='min_spark')), None)
if min_spark_version:
183 changes: 183 additions & 0 deletions python/glow/gwas/approx_firth_correction.py
@@ -0,0 +1,183 @@
from typing import Any, Optional
import pandas as pd
from pandas import Series
import numpy as np
from dataclasses import dataclass
from typeguard import typechecked
from nptyping import Float, NDArray
from scipy import stats


@dataclass
class LogLikelihood:
pi: NDArray[(Any,), Float]
G: NDArray[(Any, Any), Float] # diag(pi(1-pi))
I: NDArray[(Any, Any), Float] # Fisher information matrix, X'GX
    deviance: Float  # -2 * penalized log likelihood


@dataclass
class FirthFit:
beta: NDArray[(Any,), Float]
log_likelihood: LogLikelihood


@dataclass
class ApproxFirthState:
logit_offset: NDArray[(Any, Any), Float]
null_model_deviance: NDArray[(Any,), Float]


@typechecked
def _calculate_log_likelihood(
beta: NDArray[(Any,), Float],
X: NDArray[(Any, Any), Float],
y: NDArray[(Any,), Float],
offset: NDArray[(Any,), Float],
eps: float = 1e-15) -> LogLikelihood:

pi = 1 - 1 / (np.exp(X @ beta + offset) + 1)
G = np.diagflat(pi * (1-pi))
I = np.atleast_2d(X.T @ G @ X)
unpenalized_log_likelihood = y @ np.log(pi + eps) + (1-y) @ np.log(1-pi + eps)
_, logdet = np.linalg.slogdet(I)
penalty = 0.5 * logdet
deviance = -2 * (unpenalized_log_likelihood + penalty)
return LogLikelihood(pi, G, I, deviance)
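
As a sanity check on the algebra above: `pi` is the sigmoid of the offset-adjusted linear predictor, `I = X'GX` is the Fisher information, and the returned deviance is -2 * (log-likelihood + 0.5 * log|I|), i.e. -2 times Firth's penalized log-likelihood. A quick numeric check, assuming this module's imports and definitions are in scope:

X = np.ones((4, 1))             # intercept-only design
y = np.array([0., 1., 0., 1.])
ll = _calculate_log_likelihood(np.zeros(1), X, y, np.zeros(4))
# With beta = 0, pi = 0.5 everywhere, so I = [[1.0]] and the penalty term vanishes;
# the deviance reduces to -2 * 4 * log(0.5) ≈ 5.545
assert np.isclose(ll.deviance, 5.545, atol=1e-3)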


@typechecked
def _fit_firth(
beta_init: NDArray[(Any,), Float],
X: NDArray[(Any, Any), Float],
y: NDArray[(Any,), Float],
offset: NDArray[(Any,), Float],
convergence_limit: float = 1e-5,
deviance_tolerance: float = 1e-6,
max_iter: int = 250,
max_step_size: int = 25,
max_half_steps: int = 25) -> Optional[FirthFit]:
'''
    Firth’s bias-reduced penalized-likelihood logistic regression.

:param beta_init: Initial beta values
:param X: Independent variable (covariate for null fit, genotype for SNP fit)
:param y: Dependent variable (phenotype)
    :param offset: Offset vector (phenotype offset alone for the null fit; phenotype plus covariate offset for the SNP fit)
    :param convergence_limit: Convergence is reached when every entry of the penalized score has magnitude below this value
    :param deviance_tolerance: Non-inferiority margin for the deviance when halving the step size
:param max_iter: Maximum number of Firth iterations
:param max_step_size: Maximum step size during a Firth iteration
:param max_half_steps: Maximum number of half-steps during a Firth iteration
    :return: None if the fit failed, otherwise the fitted coefficients and final log-likelihood
'''

n_iter = 0
beta = beta_init
log_likelihood = _calculate_log_likelihood(beta, X, y, offset)
while n_iter < max_iter:
invI = np.linalg.pinv(log_likelihood.I)

# build hat matrix
rootG_X = np.sqrt(log_likelihood.G) @ X
h = np.diagonal(rootG_X @ invI @ rootG_X.T)
U = X.T @ (y - log_likelihood.pi + h * (0.5 - log_likelihood.pi))

# f' / f''
delta = invI @ U

# force absolute step size to be less than max_step_size for each entry of beta
step_size = np.linalg.norm(delta, np.inf)
mx = step_size / max_step_size
if mx > 1:
delta = delta / mx

new_log_likelihood = _calculate_log_likelihood(beta + delta, X, y, offset)

# if the penalized log likelihood decreased, recompute with step-halving
n_half_steps = 0
while new_log_likelihood.deviance >= log_likelihood.deviance + deviance_tolerance:
if n_half_steps == max_half_steps:
print(f"Too many half-steps! {new_log_likelihood.deviance} vs {log_likelihood.deviance}")
return None
delta /= 2
new_log_likelihood = _calculate_log_likelihood(beta + delta, X, y, offset)
n_half_steps += 1

beta = beta + delta
log_likelihood = new_log_likelihood

if np.linalg.norm(U, np.inf) < convergence_limit:
break

n_iter += 1

if n_iter == max_iter:
print("Too many iterations!")
return None

return FirthFit(beta, log_likelihood)
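
For intuition, the loop above is Newton's method on Firth's modified score U = X'(y - pi + h(0.5 - pi)), with each step clipped to max_step_size and halved until the penalized deviance stops increasing. A toy invocation (illustrative data; assumes this module's imports and definitions are in scope):

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 2))
y = (rng.random(50) < 0.5).astype(np.float64)
fit = _fit_firth(np.zeros(2), X, y, np.zeros(50))
if fit is not None:  # None signals non-convergence
    print(fit.beta, fit.log_likelihood.deviance)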


@typechecked
def create_approx_firth_state(
Y: NDArray[(Any, Any), Float],
offset_df: Optional[pd.DataFrame],
C: NDArray[(Any, Any), Float],
Y_mask: NDArray[(Any, Any), Float],
fit_intercept: bool) -> ApproxFirthState:
'''
Performs the null fit for approximate Firth.

:return: Penalized log-likelihood of null fit and offset with covariate effects for SNP fit
'''

num_Y = Y.shape[1]
null_model_deviance = np.zeros(num_Y)
logit_offset = np.zeros(Y.shape)

for i in range(num_Y):
y = Y[:, i]
y_mask = Y_mask[:, i]
offset = offset_df.iloc[:, i].to_numpy() if offset_df is not None else np.zeros(y.shape)
b0_null_fit = np.zeros(C.shape[1])
if fit_intercept:
b0_null_fit[-1] = (0.5 + y.sum()) / (y_mask.sum() + 1)
b0_null_fit[-1] = np.log(b0_null_fit[-1] / (1 - b0_null_fit[-1])) - offset.mean()
# In regenie, this may retry with max_step_size=5, max_iter=5000
firth_fit_result = _fit_firth(b0_null_fit, C, y, offset)
if firth_fit_result is None:
raise ValueError("Null fit failed!")
null_model_deviance[i] = firth_fit_result.log_likelihood.deviance
logit_offset[:, i] = offset + (C @ firth_fit_result.beta)

return ApproxFirthState(logit_offset, null_model_deviance)
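
The returned logit_offset column for phenotype i is offset_i + C @ beta_null_i: the fitted covariate effects are folded into the offset so that each per-SNP fit only has to estimate a single genotype coefficient. A shape-level sketch (toy data; note the intercept column is assumed to be last in C, matching the b0_null_fit[-1] initialization above):

rng = np.random.default_rng(1)
C = np.column_stack([rng.standard_normal(100), np.ones(100)])  # intercept last
Y = (rng.random((100, 3)) < 0.5).astype(np.float64)
state = create_approx_firth_state(Y, None, C, np.ones_like(Y), fit_intercept=True)
assert state.logit_offset.shape == Y.shape
assert state.null_model_deviance.shape == (3,)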


@typechecked
def correct_approx_firth(
x_res: NDArray[(Any,), Float],
y_res: NDArray[(Any,), Float],
logit_offset: NDArray[(Any,), Float],
null_model_deviance: Float) -> Optional[Series]:
'''
Calculate LRT statistics for a SNP using the approximate Firth method.

:return: None if the Firth fit did not converge, LRT statistics otherwise
'''

firth_fit = _fit_firth(
np.zeros(1),
np.expand_dims(x_res, axis=1),
y_res,
logit_offset
)
if firth_fit is None:
return None
# Likelihood-ratio test
tvalue = -1 * (firth_fit.log_likelihood.deviance - null_model_deviance)
pvalue = stats.chi2.sf(tvalue, 1)
effect = firth_fit.beta.item()
    # Standard error from the inverse Fisher information (the negative Hessian of the
    # unpenalized log-likelihood); take the square root to go from variance to SE
    stderr = np.sqrt(np.linalg.pinv(firth_fit.log_likelihood.I).diagonal()[-1])
return Series({'tvalue': tvalue, 'pvalue': pvalue, 'effect': effect, 'stderr': stderr})
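
Putting the two halves together: tvalue is the likelihood-ratio statistic deviance_null - deviance_full, referred to a chi-squared distribution with one degree of freedom (a deviance drop of about 3.84 corresponds to p ≈ 0.05). An end-to-end sketch on toy data (in the real flow, x_res and y_res come from the score-test stage in log_reg.py; here raw genotype and phenotype values stand in for the residualized ones):

rng = np.random.default_rng(2)
n = 200
C = np.column_stack([rng.standard_normal(n), np.ones(n)])
y = (rng.random(n) < 0.4).astype(np.float64)
state = create_approx_firth_state(y[:, None], None, C, np.ones((n, 1)), True)
snp = rng.binomial(2, 0.3, n).astype(np.float64)
row = correct_approx_firth(snp, y, state.logit_offset[:, 0], state.null_model_deviance[0])
if row is not None:
    print(row['tvalue'], row['pvalue'])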
54 changes: 41 additions & 13 deletions python/glow/gwas/log_reg.py
@@ -11,10 +11,12 @@
import opt_einsum as oe
from . import functions as gwas_fx
from .functions import _VALUES_COLUMN_NAME
from .approx_firth_correction import *

__all__ = ['logistic_regression']

fallback_none = 'none'
correction_none = 'none'
correction_approx_firth = 'approx-firth'


@typechecked
@@ -23,14 +25,14 @@ def logistic_regression(
phenotype_df: pd.DataFrame,
covariate_df: pd.DataFrame = pd.DataFrame({}),
offset_df: pd.DataFrame = pd.DataFrame({}),
correction: str = 'none', # TODO: Make approx-firth default
correction: str = correction_approx_firth,
pvalue_threshold: float = 0.05,
fit_intercept: bool = True,
values_column: str = 'values',
dt: type = np.float64) -> DataFrame:
'''
Uses logistic regression to test for association between genotypes and one or more binary
    phenotypes. This is a distributed version of the method from regenie:
https://www.biorxiv.org/content/10.1101/2020.06.19.162354v2

On the driver node, we fit a logistic regression model based on the covariates for each
@@ -94,14 +96,14 @@

state = gwas_fx._loco_make_state(
Y, phenotype_df, offset_df,
lambda y, pdf, odf: _create_log_reg_state(y, pdf, odf, C, Y_mask))
lambda y, pdf, odf: _create_log_reg_state(y, pdf, odf, C, Y_mask, correction, fit_intercept))

phenotype_names = phenotype_df.columns.to_series().astype('str')

def map_func(pdf_iterator):
for pdf in pdf_iterator:
yield gwas_fx._loco_dispatch(pdf, state, _logistic_regression_inner,
C, Y_mask, correction, phenotype_names)
yield gwas_fx._loco_dispatch(pdf, state, _logistic_regression_inner, C, Y_mask,
correction, pvalue_threshold, phenotype_names)

return genotype_df.mapInPandas(map_func, result_struct)

Expand Down Expand Up @@ -135,6 +137,7 @@ class LogRegState:
inv_CtGammaC: NDArray[(Any, Any), Float]
gamma: NDArray[(Any, Any), Float]
Y_res: NDArray[(Any, Any), Float]
approx_firth_state: Optional[ApproxFirthState]


@typechecked
@@ -143,7 +146,9 @@ def _create_log_reg_state(
phenotype_df: pd.DataFrame, # Unused, only to share code with lin_reg.py
offset_df: Optional[pd.DataFrame],
C: NDArray[(Any, Any), Float],
Y_mask: NDArray[(Any, Any), Float]) -> LogRegState:
Y_mask: NDArray[(Any, Any), Float],
correction: str,
fit_intercept: bool) -> LogRegState:
Y_pred = np.row_stack([
_logistic_null_model_predictions(
Y[:, i], C, Y_mask[:, i],
@@ -153,7 +158,13 @@
gamma = Y_pred * (1 - Y_pred)
CtGammaC = C.T @ (gamma[:, :, None] * C)
inv_CtGammaC = np.linalg.inv(CtGammaC)
return LogRegState(inv_CtGammaC, gamma, (Y - Y_pred.T) * Y_mask)

if correction == correction_approx_firth:
approx_firth_state = create_approx_firth_state(Y, offset_df, C, Y_mask, fit_intercept)
else:
approx_firth_state = None

return LogRegState(inv_CtGammaC, gamma, (Y - Y_pred.T) * Y_mask, approx_firth_state)


def _logistic_residualize(X: NDArray[(Any, Any), Float], C: NDArray[(Any, Any), Float],
@@ -169,7 +180,7 @@

def _logistic_regression_inner(genotype_pdf: pd.DataFrame, log_reg_state: LogRegState,
C: NDArray[(Any, Any), Float], Y_mask: NDArray[(Any, Any), Float],
fallback_method: str, phenotype_names: pd.Series) -> pd.DataFrame:
correction: str, pvalue_threshold: float, phenotype_names: pd.Series) -> pd.DataFrame:
'''
Tests a block of genotypes for association with binary traits. We first residualize
the genotypes based on the null model fit, then perform a fast score test to check for
@@ -189,13 +200,30 @@
t_values = np.ravel(num / denom)
p_values = stats.chi2.sf(t_values, 1)

if fallback_method != fallback_none:
# TODO: Call approx firth here
()

del genotype_pdf[_VALUES_COLUMN_NAME]
out_df = pd.concat([genotype_pdf] * log_reg_state.Y_res.shape[1])
out_df['tvalue'] = list(np.ravel(t_values))
out_df['pvalue'] = list(np.ravel(p_values))
out_df['phenotype'] = phenotype_names.repeat(genotype_pdf.shape[0]).tolist()

    if correction != correction_none:
        if correction == correction_approx_firth:
            # Rows of out_df are phenotype-major (all SNPs for the first phenotype,
            # then all SNPs for the second, ...), so the row position encodes both the
            # SNP index and the phenotype index. Reset the index so positions are
            # unique; pd.concat above duplicates the original genotype_pdf index.
            out_df.reset_index(drop=True, inplace=True)
            correction_indices = out_df.index[out_df['pvalue'] < pvalue_threshold]
            for correction_idx in correction_indices:
                snp_index = correction_idx % genotype_pdf.shape[0]
                phenotype_index = correction_idx // genotype_pdf.shape[0]
                approx_firth_snp_fit = correct_approx_firth(
                    X_res[snp_index][phenotype_index],
                    log_reg_state.Y_res[:, phenotype_index],
                    log_reg_state.approx_firth_state.logit_offset[:, phenotype_index],
                    log_reg_state.approx_firth_state.null_model_deviance[phenotype_index],
                )
                if approx_firth_snp_fit is not None:
                    # Assign via .loc so the update writes back to out_df; chained
                    # indexing like out_df.iloc[i]['tvalue'] would modify a copy
                    out_df.loc[correction_idx, 'tvalue'] = approx_firth_snp_fit.tvalue
                    out_df.loc[correction_idx, 'pvalue'] = approx_firth_snp_fit.pvalue
                else:
                    print(f"Could not correct {out_df.loc[correction_idx]}")
        else:
            raise ValueError(f"Only supported correction method is {correction_approx_firth}")

return out_df
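
With these changes, callers get approximate Firth correction by default for any association whose score-test p-value falls below pvalue_threshold. A usage sketch (the input DataFrames are assumed to be prepared as in the existing glow docs; genotype_df is a Spark DataFrame with a `values` array column):

import glow.gwas.log_reg as lr

results = lr.logistic_regression(
    genotype_df,
    phenotype_df,
    covariate_df,
    correction=lr.correction_approx_firth,  # now the default
    pvalue_threshold=0.05)                  # only hits below this are corrected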