Skip to content

Commit

Permalink
Demote global variables to class attributes
Browse files Browse the repository at this point in the history
This will reduce emcee pooling performance (should look at this later), but restore sanity
  • Loading branch information
JelleAalbers committed Aug 24, 2022
1 parent 1050b41 commit f267267
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 63 deletions.
98 changes: 44 additions & 54 deletions paltas/Analysis/hierarchical_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,6 @@
from scipy import special
import numba

# Global error filters for python warnings.
LINALGWARNING = True

# The predicted samples need to be et as a global variable for the pooling to
# be efficient when done by emcee. This will have shape (num_params,num_samps,
# batch_size).
predict_samps_hier = None

# As with the predicted samples, the predicted mu and cov for the analytical
# calculations should also be set at the global level for optimal
# performance.
mu_pred_array = None
prec_pred_array = None
mu_pred_array_ensemble = None
prec_pred_array_ensemble = None


def log_p_xi_omega(predict_samps_hier,hyperparameters,eval_func_xi_omega):
""" Calculate log p(xi|omega), the probability of the lens paramaters given
Expand Down Expand Up @@ -165,6 +149,18 @@ class ProbabilityClass:
it easier to write fast evaluation functions using numba.
"""

LINALGWARNING = True

# The predicted samples; will have shape (num_params,num_samps, batch_size).
predict_samps_hier = None

# The predicted mu and cov for the analytical calculations
mu_pred_array = None
prec_pred_array = None
mu_pred_array_ensemble = None
prec_pred_array_ensemble = None


def __init__(self,eval_func_xi_omega_i,eval_func_xi_omega,
eval_func_omega):
# Save these functions to the class for later use.
Expand All @@ -175,8 +171,7 @@ def __init__(self,eval_func_xi_omega_i,eval_func_xi_omega,
self.samples_init = False

def set_samples(self,predict_samps_input=None,predict_samps_hier_input=None):
""" Set the global lens samples value. Using a global helps avoid data
being pickled.
""" Set the lens samples value.
Args:
predict_samps_input (np.array): An array of shape (n_samps,n_lenses,
Expand All @@ -192,12 +187,11 @@ def set_samples(self,predict_samps_input=None,predict_samps_hier_input=None):
package use predict_samps_hier (n_params,n_samps,n_lenses)
convention.
"""
# Set the global samples variable
global predict_samps_hier
# Set the samples attribute
if predict_samps_hier_input is not None:
predict_samps_hier = predict_samps_hier_input
self.predict_samps_hier = predict_samps_hier_input
elif predict_samps_input is not None:
predict_samps_hier = np.ascontiguousarray(np.transpose(
self.predict_samps_hier = np.ascontiguousarray(np.transpose(
predict_samps_input,[2,0,1]))
else:
raise ValueError('Either predict_samps_input or ' +
Expand All @@ -206,7 +200,7 @@ def set_samples(self,predict_samps_input=None,predict_samps_hier_input=None):

# Calculate the probability of the sample on the interim training
# distribution
self.p_samps_omega_i = self.eval_func_xi_omega_i(predict_samps_hier)
self.p_samps_omega_i = self.eval_func_xi_omega_i(self.predict_samps_hier)

def log_post_omega(self,hyperparameters):
""" Given the predicted samples, calculate the log posterior of a
Expand All @@ -227,8 +221,6 @@ def log_post_omega(self,hyperparameters):
if self.samples_init is False:
raise RuntimeError('Must set samples or behaviour is ill-defined.')

global predict_samps_hier

# Start with the prior on omega
lprior = log_p_omega(hyperparameters,self.eval_func_omega)

Expand All @@ -237,7 +229,9 @@ def log_post_omega(self,hyperparameters):
return lprior

# Calculate the probability of each datapoint given omega
p_samps_omega = log_p_xi_omega(predict_samps_hier,hyperparameters,
p_samps_omega = log_p_xi_omega(
self.predict_samps_hier,
hyperparameters,
self.eval_func_xi_omega)

# We can use our pre-calculated value of p_samps_omega_i.
Expand Down Expand Up @@ -278,7 +272,7 @@ def __init__(self,mu_omega_i,cov_omega_i,eval_func_omega):
self.predictions_init = False

def set_predictions(self,mu_pred_array_input,prec_pred_array_input):
""" Set the global lens mean and covariance prediction values.
""" Set the lens mean and covariance prediction values.
Args:
mu_pred_array_input (np.array): An array of shape (n_lenses,
Expand All @@ -288,11 +282,8 @@ def set_predictions(self,mu_pred_array_input,prec_pred_array_input):
n_params,n_params) that represents the predicted precision
matrix on each lens.
"""
# Call up the globals and set them.
global mu_pred_array
global prec_pred_array
mu_pred_array = mu_pred_array_input
prec_pred_array = prec_pred_array_input
self.mu_pred_array = mu_pred_array_input
self.prec_pred_array = prec_pred_array_input

# Set the flag for the predictions being initialized
self.predictions_init = True
Expand Down Expand Up @@ -352,10 +343,6 @@ def log_post_omega(self,hyperparameters):
raise RuntimeError('Must set predictions or behaviour is '
+'ill-defined.')

global mu_pred_array
global prec_pred_array
global LINALGWARNING

# Start with the prior on omega
lprior = log_p_omega(hyperparameters,self.eval_func_omega)

Expand All @@ -370,21 +357,26 @@ def log_post_omega(self,hyperparameters):
prec_omega = np.linalg.inv(cov_omega)
except np.linalg.LinAlgError:
# Singular covariance matrix
if LINALGWARNING:
if self.LINALGWARNING:
warnings.warn('Singular covariance matrix',
category=RuntimeWarning)
LINALGWARNING = False
self.LINALGWARNING = False
return -np.inf

try:
like_ratio = self.log_integral_product(mu_pred_array,prec_pred_array,
self.mu_omega_i,self.prec_omega_i,mu_omega,prec_omega)
like_ratio = self.log_integral_product(
self.mu_pred_array,
self.prec_pred_array,
self.mu_omega_i,
self.prec_omega_i,
mu_omega,
prec_omega)
except np.linalg.LinAlgError:
# Something else was singular, too bad
if LINALGWARNING:
if self.LINALGWARNING:
warnings.warn('Singular covariance matrix',
category=RuntimeWarning)
LINALGWARNING = False
self.LINALGWARNING = False
return -np.inf

# Return the likelihood and the prior combined
Expand All @@ -408,7 +400,7 @@ class ProbabilityClassEnsemble(ProbabilityClassAnalytical):
"""

def set_predictions(self,mu_pred_array_input,prec_pred_array_input):
""" Set the global lens mean and covariance prediction values.
""" Set the lens mean and covariance prediction values.
Args:
mu_pred_array_input (np.array): An array of shape (n_ensembles,
Expand All @@ -418,15 +410,12 @@ def set_predictions(self,mu_pred_array_input,prec_pred_array_input):
n_lenses,n_params,n_params) that represents the predicted
precision matrix on each lens.
"""
# Call up the globals and set them.
global mu_pred_array_ensemble
global prec_pred_array_ensemble
mu_pred_array_ensemble = mu_pred_array_input
prec_pred_array_ensemble = prec_pred_array_input
self.mu_pred_array_ensemble = mu_pred_array_input
self.prec_pred_array_ensemble = prec_pred_array_input

# Set the flag for the predictions being initialized
self.predictions_init = True
self.n_ensemble = len(mu_pred_array_ensemble)
self.n_ensemble = len(self.mu_pred_array_ensemble)

@staticmethod
@numba.njit
Expand Down Expand Up @@ -491,9 +480,6 @@ def log_post_omega(self,hyperparameters):
raise RuntimeError('Must set predictions or behaviour is '
+'ill-defined.')

global mu_pred_array_ensemble
global prec_pred_array_ensemble

# Extract mu_omega and prec_omega from the provided hyperparameters
mu_omega = hyperparameters[:len(hyperparameters)//2]
cov_omega = np.diag(np.exp(hyperparameters[len(hyperparameters)//2:]*2))
Expand All @@ -506,8 +492,12 @@ def log_post_omega(self,hyperparameters):
if lprior == -np.inf:
return lprior

like_ratio = self.log_integral_product(mu_pred_array_ensemble,
prec_pred_array_ensemble,self.mu_omega_i,self.prec_omega_i,mu_omega,
like_ratio = self.log_integral_product(
self.mu_pred_array_ensemble,
self.prec_pred_array_ensemble,
self.mu_omega_i,
self.prec_omega_i,
mu_omega,
prec_omega)

# Return the likelihood and the prior combined
Expand Down
16 changes: 7 additions & 9 deletions test/analysis_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,16 +1215,16 @@ def eval_func_xi_omega_i(predict_samps_hier):
# Try setting the samples with predict_samps_hier_input
prob_class.set_samples(predict_samps_hier_input=predict_samps_hier_input)
self.assertFalse(
Analysis.hierarchical_inference.predict_samps_hier is None)
prob_class.predict_samps_hier is None)
np.testing.assert_almost_equal(prob_class.p_samps_omega_i,
np.sum(predict_samps_hier_input,axis=0))

# Try setting the samples with predict_samps_input
prob_class.set_samples(predict_samps_input=predict_samps_input)
self.assertFalse(
Analysis.hierarchical_inference.predict_samps_hier is None)
prob_class.predict_samps_hier is None)
np.testing.assert_array_equal(
Analysis.hierarchical_inference.predict_samps_hier,
prob_class.predict_samps_hier,
predict_samps_hier_input)
np.testing.assert_almost_equal(prob_class.p_samps_omega_i,
np.sum(predict_samps_hier_input,axis=0))
Expand Down Expand Up @@ -1303,8 +1303,8 @@ def test_set_predictions(self):

# Try setting the predictions
prob_class.set_predictions(mu_pred_array_input,prec_pred_array_input)
self.assertFalse(Analysis.hierarchical_inference.mu_pred_array is None)
self.assertFalse(Analysis.hierarchical_inference.prec_pred_array is None)
self.assertFalse(prob_class.mu_pred_array is None)
self.assertFalse(prob_class.prec_pred_array is None)

def test_log_integral_product(self):
# Make sure that the log integral product just sums the log of each
Expand Down Expand Up @@ -1405,10 +1405,8 @@ def test_set_predictions(self):

# Try setting the predictions
prob_class.set_predictions(mu_pred_array_input,prec_pred_array_input)
self.assertFalse(
Analysis.hierarchical_inference.mu_pred_array_ensemble is None)
self.assertFalse(
Analysis.hierarchical_inference.prec_pred_array_ensemble is None)
self.assertFalse(prob_class.mu_pred_array_ensemble is None)
self.assertFalse(prob_class.prec_pred_array_ensemble is None)

def test_log_integral_product(self):
# Make sure that the log integral product just sums the log of each
Expand Down

0 comments on commit f267267

Please sign in to comment.