Skip to content

Commit

Permalink
Merge pull request #420 from nextstrain/fix-weighted-frequencies-rebase
Browse files Browse the repository at this point in the history
Fix weighted frequencies when weight keys are unrepresented
  • Loading branch information
trvrb authored Dec 11, 2019
2 parents 3b14fe3 + bf65e28 commit a456cba
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 6 deletions.
22 changes: 22 additions & 0 deletions augur/frequency_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,28 @@ def estimate(self, tree):
# a given value sums to the proportion of tips that have that value.
# If weights are not defined, estimate frequencies such that they sum to 1.
if self.weights:
# Determine which weight attributes are represented in the
# tree. Drop any unrepresented attributes and renormalize the
# remaining proportions per attribute to sum to one before
# estimating frequencies.
weights_represented = set([
tip.attr[self.weights_attribute]
for tip in tree.find_clades(terminal=True)
])

if len(self.weights) != len(weights_represented):
# Remove unrepresented weights.
weights_unrepresented = set(self.weights.keys()) - weights_represented
for weight in weights_unrepresented:
del self.weights[weight]

# Renormalize the remaining weights.
weight_total = sum(self.weights.values())
for key, value in self.weights.items():
self.weights[key] = value / weight_total

# Estimate frequencies for all tips within each weight attribute
# group.
weight_keys, weight_values = zip(*sorted(self.weights.items()))
proportions = np.array(weight_values) / np.array(weight_values).sum(axis=0)
frequencies = {}
Expand Down
49 changes: 43 additions & 6 deletions tests/python3/test_frequencies.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
"""
Unit tests for frequency estimation
"""
import Bio
import json
import sys
from pathlib import Path
import numpy as np
from pathlib import Path
import pytest
import Bio

from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies
from augur.utils import json_to_tree
import sys
import os

# we assume (and assert) that this script is running from the tests/ directory
sys.path.append(str(Path(__file__).parent.parent.parent))

from augur.frequency_estimators import get_pivots, TreeKdeFrequencies, AlignmentKdeFrequencies
from augur.utils import json_to_tree

# Define regions to use for testing weighted frequencies.
REGIONS = [
('africa', 1.02),
Expand Down Expand Up @@ -95,6 +96,9 @@ def test_estimate(self, tree):
assert hasattr(kde_frequencies, "frequencies")
assert list(frequencies.values())[0].shape == kde_frequencies.pivots.shape

# Frequencies should sum to 1 at all pivots.
assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots))

def test_estimate_with_time_interval(self, tree):
"""Test frequency estimation with a given time interval.
"""
Expand Down Expand Up @@ -124,6 +128,9 @@ def test_weighted_estimate(self, tree):
assert hasattr(kde_frequencies, "frequencies")
assert list(frequencies.values())[0].shape == kde_frequencies.pivots.shape

# Frequencies should sum to 1 at all pivots.
assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots))

# Estimate unweighted frequencies to compare with weighted frequencies.
unweighted_kde_frequencies = TreeKdeFrequencies()
unweighted_frequencies = unweighted_kde_frequencies.estimate(tree)
Expand All @@ -136,6 +143,36 @@ def test_weighted_estimate(self, tree):
unweighted_frequencies[clade_to_test.name]
)

def test_weighted_estimate_with_unrepresented_weights(self, tree):
"""Test frequency estimation with weighted tips when any of the weight
attributes is unrepresented.
In this case, normalization of frequencies to the proportions
represented by the weights should be followed by a second normalization
to sum to 1.
"""
# Drop all tips sampled from Africa from the tree. Despite dropping a
# populous region, the estimated frequencies should still sum to 1
# below.
tips_from_africa = [
tip
for tip in tree.find_clades(terminal=True)
if tip.attr["region"] == "africa"
]
for tip in tips_from_africa:
tree.prune(tip)

# Estimate weighted frequencies.
weights = {region[0]: region[1] for region in REGIONS}
kde_frequencies = TreeKdeFrequencies(
weights=weights,
weights_attribute="region"
)
frequencies = kde_frequencies.estimate(tree)

# Frequencies should sum to 1 at all pivots.
assert np.allclose(np.array(list(frequencies.values())).sum(axis=0), np.ones_like(kde_frequencies.pivots))

def test_only_tip_estimates(self, tree):
"""Test frequency estimation for only tips in a given tree.
"""
Expand Down

0 comments on commit a456cba

Please sign in to comment.