Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade treetime to v0.7 #431

Merged
merged 7 commits into from
Feb 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 45 additions & 30 deletions augur/ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections import defaultdict

def ancestral_sequence_inference(tree=None, aln=None, ref=None, infer_gtr=True,
marginal=False, fill_overhangs=True):
marginal=False, fill_overhangs=True, infer_tips=False):
"""infer ancestral sequences using TreeTime

Parameters
Expand All @@ -29,6 +29,11 @@ def ancestral_sequence_inference(tree=None, aln=None, ref=None, infer_gtr=True,
filled with the gap character ('-'). If set to True, these end-gaps are
converted to "ambiguous" characters ('N' for nucleotides, 'X' for
aminoacids). Otherwise, the alignment is treated as-is
infer_tips : bool
Since v0.7, TreeTime does not reconstruct tip states by default.
This is only relevant when tip-state are not exactly specified, e.g. via
characters that signify ambiguous states. To replace those with the
most-likely state, set infer_tips=True

Returns
-------
Expand All @@ -44,15 +49,16 @@ def ancestral_sequence_inference(tree=None, aln=None, ref=None, infer_gtr=True,
bool_marginal = (marginal == "marginal")

# only infer ancestral sequences, leave branch length untouched
tt.infer_ancestral_sequences(infer_gtr=infer_gtr, marginal=bool_marginal)
tt.infer_ancestral_sequences(infer_gtr=infer_gtr, marginal=bool_marginal,
reconstruct_tip_states=infer_tips)

print("\nInferred ancestral sequence states using TreeTime:"
"\n\tSagulenko et al. TreeTime: Maximum-likelihood phylodynamic analysis"
"\n\tVirus Evolution, vol 4, https://academic.oup.com/ve/article/4/1/vex042/4794731\n")

return tt

def collect_sequences_and_mutations(T, is_vcf=False):
def collect_mutations_and_sequences(tt, infer_tips=False, full_sequences=False, character_map=None):
"""iterates of the tree and produces dictionaries with
mutations and sequences for each node.

Expand All @@ -68,20 +74,23 @@ def collect_sequences_and_mutations(T, is_vcf=False):
dict
dictionary of mutations and sequences
"""
if character_map is None:
cm = lambda x:x
else:
cm = lambda x: character_map.get(x, x)

data = defaultdict(dict)
inc = 1 # convert python numbering to start-at-1
for n in T.find_clades():
if hasattr(n, "mutations"):
mutations_attr = n.__getattribute__("mutations")
data[n.name]['muts'] = [str(a)+str(int(pos)+inc)+str(d)
for a,pos,d in mutations_attr]
if not is_vcf:
for n in T.find_clades():
if hasattr(n, "sequence"):
sequence_attr = n.__getattribute__("sequence")
data[n.name]['sequence'] = ''.join(sequence_attr)
else:
data[n.name]['sequence'] = ''
for n in tt.tree.find_clades():
data[n.name]['muts'] = [a+str(int(pos)+inc)+cm(d)
for a,pos,d in n.mutations]

if full_sequences:
for n in tt.tree.find_clades():
try:
data[n.name]['sequence'] = tt.sequence(n,reconstructed=infer_tips, as_string=True)
except:
print("No sequence available for node ",n.name)

return data

Expand All @@ -96,8 +105,11 @@ def register_arguments(parser):
help="calculate joint or marginal maximum likelihood ancestral sequence states")
parser.add_argument('--vcf-reference', type=str, help='fasta file of the sequence the VCF was mapped to')
parser.add_argument('--output-vcf', type=str, help='name of output VCF file which will include ancestral seqs')
parser.add_argument('--keep-ambiguous', action="store_true", default=False,
help='do not infer nucleotides at ambiguous (N) sites on tip sequences (leave as N). Always true for VCF input.')
ambiguous = parser.add_mutually_exclusive_group()
ambiguous.add_argument('--keep-ambiguous', action="store_false", dest='infer_ambiguous',
help='do not infer nucleotides at ambiguous (N) sites on tip sequences (leave as N).')
ambiguous.add_argument('--infer-ambiguous', action="store_true",
help='infer nucleotides at ambiguous (N,W,R,..) sites on tip sequences and replace with most likely state.')
parser.add_argument('--keep-overhangs', action="store_true", default=False,
help='do not infer nucleotides for gaps (-) on either side of the alignment')

Expand Down Expand Up @@ -135,25 +147,28 @@ def run(args):
else:
aln = args.alignment

# Only allow recovery of ambig sites for Fasta-input if TreeTime is version 0.5.6 or newer
# Otherwise it returns nonsense.
# Enfore treetime 0.7 or later
from distutils.version import StrictVersion
import treetime
if args.keep_ambiguous and not is_vcf and StrictVersion(treetime.version) < StrictVersion('0.5.6'):
print("ERROR: Keeping ambiguous sites for Fasta-input requires TreeTime version 0.5.6 or newer."+
"\nYour version is "+treetime.version+
"\nUpdate TreeTime or run without the --keep-ambiguous flag.")
if StrictVersion(treetime.version) < StrictVersion('0.7.0'):
print("ERROR: this version of augur requires TreeTime 0.7 or later.")
return 1

tt = ancestral_sequence_inference(tree=T, aln=aln, ref=ref, marginal=args.inference,
fill_overhangs = not(args.keep_overhangs))

if is_vcf or args.keep_ambiguous:
# TreeTime overwrites ambig sites on tips during ancestral reconst.
# Put these back in tip sequences now, to avoid misleading
tt.recover_var_ambigs()
fill_overhangs = not(args.keep_overhangs),
infer_tips = args.infer_ambiguous)

character_map = {}
for x in tt.gtr.profile_map:
if tt.gtr.profile_map[x].sum()==tt.gtr.n_states:
# TreeTime treats all characters that are not valid IUPAC nucleotide chars as fully ambiguous
# To clean up auspice output, we map all those to 'N'
character_map[x] = 'N'
else:
character_map[x] = x

anc_seqs['nodes'] = collect_sequences_and_mutations(T, is_vcf)
anc_seqs['nodes'] = collect_mutations_and_sequences(tt, full_sequences=not is_vcf,
infer_tips=args.infer_ambiguous, character_map=character_map)
# add reference sequence to json structure. This is the sequence with
# respect to which mutations on the tree are defined.
if is_vcf:
Expand Down
2 changes: 1 addition & 1 deletion augur/data/schema-export-v1-tree.json
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
"type": "array",
"items": {
"oneOf": [
{"type": "string", "pattern": "^[ATCGN-][0-9]+[ATCGN-]$"},
{"type": "string", "pattern": "^[ATCGNYRWSKMDVHB-][0-9]+[ATCGNYRWSKMDVHB-]$"},
{"type": "string", "pattern": "^insertion [0-9]+-[0-9]+$", "$comment": "unused by auspice"},
{"type": "string", "pattern": "^deletion [0-9]+-[0-9]+$", "$comment": "unused by auspice"}
]
Expand Down
5 changes: 5 additions & 0 deletions augur/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def register_arguments(parser):
parser.add_argument('--output-node-data', type=str, help='file name to write branch lengths as node data')
parser.add_argument('--timetree', action="store_true", help="produce timetree using treetime")
parser.add_argument('--coalescent', help="coalescent time scale in units of inverse clock rate (float), optimize as scalar ('opt'), or skyline ('skyline')")
parser.add_argument('--gen-per-year', default=50, type=float, help="number of generations per year, relevant for skyline output('skyline')")
parser.add_argument('--clock-rate', type=float, help="fixed clock rate")
parser.add_argument('--clock-std-dev', type=float, help="standard deviation of the fixed clock_rate estimate")
parser.add_argument('--root', nargs="+", default='best', help="rooting mechanism ('best', least-squares', 'min_dev', 'oldest') "
Expand Down Expand Up @@ -202,6 +203,10 @@ def run(args):
node_data['clock'] = {'rate': tt.date2dist.clock_rate,
'intercept': tt.date2dist.intercept,
'rtt_Tmrca': -tt.date2dist.intercept/tt.date2dist.clock_rate}
if args.coalescent=='skyline':
skyline, conf = tt.merger_model.skyline_inferred(gen=args.gen_per_year, confidence=2)
node_data['skyline'] = [[float(x) for x in skyline.x], [float(y) for y in conf[0]],
[float(y) for y in skyline.y], [float(y) for y in conf[1]]]
attributes.extend(['numdate', 'clock_length', 'mutation_length', 'raw_date', 'date'])
if args.date_confidence:
attributes.append('num_date_confidence')
Expand Down
100 changes: 31 additions & 69 deletions augur/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
TINY = 1e-12

def mugration_inference(tree=None, seq_meta=None, field='country', confidence=True,
infer_gtr=True, root_state=None, missing='?', sampling_bias_correction=None):
missing='?', sampling_bias_correction=None):
"""
Infer likely ancestral states of a discrete character assuming a time reversible model.

Expand All @@ -24,10 +24,6 @@ def mugration_inference(tree=None, seq_meta=None, field='country', confidence=Tr
meta data field to use
confidence : bool, optional
calculate confidence values for inferences
infer_gtr : bool, optional
infer a GTR model for trait transitions (otherwises uses a flat model with rate 1)
root_state : None, optional
force the state of the root node (currently not implemented)
missing : str, optional
character that is to be interpreted as missing data, default='?'

Expand All @@ -40,95 +36,60 @@ def mugration_inference(tree=None, seq_meta=None, field='country', confidence=Tr
alphabet : dict
mapping of character states to
"""
from treetime import GTR
from treetime.wrappers import reconstruct_discrete_traits
from Bio.Align import MultipleSeqAlignment
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from Bio import Phylo

T = Phylo.read(tree, 'newick')
traits = {}
nodes = {n.name:n for n in T.get_terminals()}

# Determine alphabet only counting tips in the tree
places = set()
for name, meta in seq_meta.items():
if field in meta and name in nodes:
places.add(meta[field])
if root_state is not None:
places.add(root_state)

# construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
places = sorted(places)
nc = len(places)
if nc>180:
print("ERROR: geo_inference: can't have more than 180 places!", file=sys.stderr)
return None,None,None
elif nc==0:
print("ERROR: geo_inference: list of places is empty!", file=sys.stderr)
return None,None,None
elif nc==1:
print("WARNING: geo_inference: only one place found -- set every internal node to %s!"%places[0], file=sys.stderr)
alphabet = {'A':places[0]}
alphabet_values = ['A']
gtr = None
traits[name] = meta[field]
unique_states = list(set(traits.values()))

if len(unique_states)==0:
print("WARNING: no states found for discrete state reconstruction.")
for node in T.find_clades():
node.sequence=['A']
node.marginal_profile=np.array([[1.0]])
node.__setattr__(field, None)
return T, None, {}
elif len(unique_states)==1:
print("WARNING: only one state found for discrete state reconstruction:", unique_states)
for node in T.find_clades():
node.__setattr__(field, unique_states[0])
return T, None, {}
elif len(unique_states)<180:
tt, letter_to_state, reverse_alphabet = \
reconstruct_discrete_traits(T, traits, missing_data=missing,
sampling_bias_correction=sampling_bias_correction)
else:
# set up model
alphabet = {chr(65+i):place for i,place in enumerate(places)}
model = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
alphabet = np.array(sorted(alphabet.keys())))

missing_char = chr(65+nc)
alphabet[missing_char]=missing
model.profile_map[missing_char] = np.ones(nc)
model.ambiguous = missing_char
alphabet_rev = {v:k for k,v in alphabet.items()}

# construct pseudo alignment
pseudo_seqs = []
for name, meta in seq_meta.items():
if name in nodes:
s=alphabet_rev[meta[field]] if field in meta else missing_char
pseudo_seqs.append(SeqRecord(Seq(s), name=name, id=name))
aln = MultipleSeqAlignment(pseudo_seqs)

# set up treetime and infer
from treetime import TreeAnc
tt = TreeAnc(tree=tree, aln=aln, gtr=model, convert_upper=False, verbose=0)
tt.use_mutation_length = False
tt.infer_ancestral_sequences(infer_gtr=infer_gtr, store_compressed=False, pc=1.0,
marginal=True, normalized_rate=False)

if sampling_bias_correction:
tt.gtr.mu *= sampling_bias_correction
tt.infer_ancestral_sequences(infer_gtr=False, store_compressed=False,
marginal=True, normalized_rate=False)

T = tt.tree
gtr = tt.gtr
alphabet_values = tt.gtr.alphabet
print("ERROR: 180 or more distinct discrete states found. TreeTime is currently not set up to handle that many states.")
sys.exit(1)

if tt is None:
print("ERROR in discrete state reconstruction in TreeTime. Please look for errors above.")
sys.exit(1)

# attach inferred states as e.g. node.region = 'africa'
for node in T.find_clades():
node.__setattr__(field, alphabet[node.sequence[0]])
for node in tt.tree.find_clades():
node.__setattr__(field, letter_to_state[node.sequence[0]])

# if desired, attach entropy and confidence as e.g. node.region_entropy = 0.03
if confidence:
for node in T.find_clades():
for node in tt.tree.find_clades():
pdis = node.marginal_profile[0]
S = -np.sum(pdis*np.log(pdis+TINY))

marginal = [(alphabet[alphabet_values[i]], pdis[i]) for i in range(len(alphabet_values))]
marginal = [(letter_to_state[tt.gtr.alphabet[i]], pdis[i]) for i in range(len(tt.gtr.alphabet))]
marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
marginal = [(a, b) for a, b in marginal if b > 0.001][:4] #only take stuff over .1% and the top 4 elements
conf = {a:b for a,b in marginal}
node.__setattr__(field + "_entropy", S)
node.__setattr__(field + "_confidence", conf)

return T, gtr, alphabet
return tt.tree, tt.gtr, letter_to_state


def register_arguments(parser):
Expand Down Expand Up @@ -182,7 +143,8 @@ def run(args):
out_prefix = '.'.join(args.tree.split('.')[:-1])
for column in args.columns:
T, gtr, alphabet = mugration_inference(tree=tree_fname, seq_meta=traits,
field=column, confidence=args.confidence, sampling_bias_correction=args.sampling_bias_correction)
field=column, confidence=args.confidence,
sampling_bias_correction=args.sampling_bias_correction)
if T is None: # something went wrong
continue

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"jsonschema >=3.0.0, ==3.*",
"packaging >=19.2",
"pandas >=0.20.0, ==0.*",
"phylo-treetime >=0.5.6, <0.7",
"phylo-treetime ==0.7.*",
"snakemake >=5.4.0, ==5.*"
],
extras_require = {
Expand Down