diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py
index d4ff79e08..164663bb9 100644
--- a/tensor2tensor/utils/metrics.py
+++ b/tensor2tensor/utils/metrics.py
@@ -43,6 +43,7 @@ class Metrics(object):
   APPROX_BLEU = "approx_bleu_score"
   RMSE = "rmse"
   LOG_POISSON = "log_poisson"
+  PEARSON = "pearson"
   R2 = "r_squared"
   ROUGE_2_F = "rouge_2_fscore"
   ROUGE_L_F = "rouge_L_fscore"
@@ -741,6 +742,23 @@ def from_characters(raw, lookup_):
   return distance / reference_length, reference_length
 
 
+def pearson_correlation_coefficient(predictions, labels, weights_fn=None):
+  """Calculate the Pearson correlation coefficient.
+
+  Args:
+    predictions: The raw predictions.
+    labels: The actual labels.
+    weights_fn: Weighting function. Unused by this metric.
+
+  Returns:
+    The Pearson correlation coefficient and a constant weight of 1.0.
+  """
+  del weights_fn  # Unused; the streaming op does not take weights here.
+  _, pearson = tf.contrib.metrics.streaming_pearson_correlation(
+      predictions, labels)
+  return pearson, tf.constant(1.0)
+
+
 # Metrics are functions that take predictions and labels and return
 # a tensor of metrics and a tensor of weights.
 # If the function has "features" as an argument, it will receive the whole
@@ -756,6 +774,7 @@ def from_characters(raw, lookup_):
     Metrics.APPROX_BLEU: bleu_hook.bleu_score,
     Metrics.RMSE: padded_rmse,
     Metrics.LOG_POISSON: padded_log_poisson,
+    Metrics.PEARSON: pearson_correlation_coefficient,
     Metrics.R2: padded_variance_explained,
     Metrics.ROUGE_2_F: rouge.rouge_2_fscore,
     Metrics.ROUGE_L_F: rouge.rouge_l_fscore,
diff --git a/tensor2tensor/utils/metrics_test.py b/tensor2tensor/utils/metrics_test.py
index b6228483c..454c06b8d 100644
--- a/tensor2tensor/utils/metrics_test.py
+++ b/tensor2tensor/utils/metrics_test.py
@@ -319,6 +319,21 @@ def testMultilabelMatch3(self):
       actual = session.run(a)
       self.assertAlmostEqual(actual, expected, places=6)
 
+  def testPearsonCorrelationCoefficient(self):
+    predictions = np.random.rand(12, 1)
+    targets = np.random.rand(12, 1)
+
+    expected = np.corrcoef(np.squeeze(predictions), np.squeeze(targets))[0][1]
+    with self.test_session() as session:
+      pearson, _ = metrics.pearson_correlation_coefficient(
+          tf.constant(predictions, dtype=tf.float32),
+          tf.constant(targets, dtype=tf.float32))
+      session.run(tf.global_variables_initializer())
+      session.run(tf.local_variables_initializer())
+      actual = session.run(pearson)
+      # Inputs are cast to float32, so compare with a looser tolerance.
+      self.assertAlmostEqual(actual, expected, places=5)
+
 
 if __name__ == '__main__':
   tf.test.main()
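
Usage sketch (illustrative only, not part of the patch): once Metrics.PEARSON is registered in METRICS_FNS, a Problem subclass can request it from its eval_metrics() hook. The problem class below is hypothetical and assumes the standard tensor2tensor Problem/registry API:

  # Hypothetical regression problem; only eval_metrics() is relevant here.
  from tensor2tensor.data_generators import problem
  from tensor2tensor.utils import metrics
  from tensor2tensor.utils import registry


  @registry.register_problem
  class MyRegressionProblem(problem.Problem):

    def eval_metrics(self):
      # Ask the eval loop to compute RMSE alongside the new Pearson metric;
      # both keys resolve to entries in metrics.METRICS_FNS.
      return [metrics.Metrics.RMSE, metrics.Metrics.PEARSON]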