Skip to content

Commit

Permalink
Merge pull request #59 from EducationalTestingService/add-weights-as-…
Browse files Browse the repository at this point in the history
…attribute

Save weights as an attribute
  • Loading branch information
desilinguist authored Jun 2, 2020
2 parents 39b9f9e + 433e8b5 commit d3df9c6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
8 changes: 5 additions & 3 deletions factor_analyzer/factor_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
with optional rotation using Varimax or Promax.
:author: Jeremy Biggs ([email protected])
:author: Nitin Madnani ([email protected])
:date: 10/25/2017
:organization: ETS
"""
Expand Down Expand Up @@ -288,6 +289,7 @@ def __init__(self,
self.corr_ = None
self.loadings_ = None
self.rotation_matrix_ = None
self.weights_ = None

@staticmethod
def _fit_uls_objective(psi, corr_mtx, n_factors):
Expand Down Expand Up @@ -734,13 +736,13 @@ def transform(self, X):
structure = self.loadings_

try:
weights = np.linalg.solve(self.corr_, structure)
self.weights_ = np.linalg.solve(self.corr_, structure)
except Exception as error:
warnings.warn('Unable to calculate the factor score weights; '
'factor loadings used instead: {}'.format(error))
weights = self.loadings_
self.weights_ = self.loadings_

scores = np.dot(X_scale, weights)
scores = np.dot(X_scale, self.weights_)
return scores

def get_eigenvalues(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_factor_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def test_calculate_kmo():

class TestFactorAnalyzer:

def test_analyze_weights(self):

data = pd.DataFrame({'A': [2, 4, 5, 6, 8, 9],
'B': [4, 8, 9, 10, 16, 18],
'C': [6, 12, 15, 12, 26, 27]})

fa = FactorAnalyzer(rotation=None)
fa.fit(data)
_ = fa.transform(data)
expected_weights = np.array(([[0.33536334, -2.72509646, 0],
[0.33916605, -0.29388849, 0],
[0.33444588, 3.03060826, 0]]))
assert_array_almost_equal(expected_weights, fa.weights_)

def test_analyze_impute_mean(self):

data = pd.DataFrame({'A': [2, 4, 5, 6, 8, 9],
Expand Down

0 comments on commit d3df9c6

Please sign in to comment.