-
Notifications
You must be signed in to change notification settings - Fork 128
/
Copy pathfilter.py
1878 lines (1544 loc) · 74.2 KB
/
filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Filter and subsample a sequence set.
"""
from Bio import SeqIO
from collections import defaultdict
import csv
import datetime
import heapq
import itertools
import json
import numpy as np
import operator
import os
import pandas as pd
import random
import re
import sys
from tempfile import NamedTemporaryFile
import treetime.utils
from typing import Collection
from .index import index_sequences, index_vcf
from .io import open_file, read_metadata, read_sequences, write_sequences
from .utils import is_vcf as filename_is_vcf, read_vcf, read_strains, get_numerical_dates, run_shell_command, shquote, is_date_ambiguous
# Character that marks a comment line in metadata/strain list files.
comment_char = '#'
# Names of filters that only make sense when sequence data is available
# (both corresponding filter functions below require a sequence index).
# NOTE(review): defined here but referenced elsewhere in the file — confirm usage.
SEQUENCE_ONLY_FILTERS = (
    "min_length",
    "non_nucleotide",
)
class FilterException(Exception):
    """Representation of an error that occurred during filtering.

    Raised by filtering helpers (e.g., group-by validation in
    ``get_groups_for_subsampling``) when user-provided options cannot be
    applied to the given metadata.
    """
    pass
def write_vcf(input_filename, output_filename, dropped_samps):
    """Write a VCF excluding the given samples by shelling out to vcftools.

    Parameters
    ----------
    input_filename : str
        Path to the input VCF. Gzip-compressed input is detected by filename
        (via ``_filename_gz``, defined elsewhere in this module) and read
        with ``--gzvcf`` instead of ``--vcf``.
    output_filename : str
        Path for the filtered VCF. A gzip filename causes the vcftools
        output to be piped through ``gzip -c``.
    dropped_samps : list
        Sample names to remove, each passed as a ``--remove-indv`` flag.
    """
    if _filename_gz(input_filename):
        input_arg = "--gzvcf"
    else:
        input_arg = "--vcf"
    if _filename_gz(output_filename):
        output_pipe = "| gzip -c"
    else:
        output_pipe = ""
    # One shell-quoted --remove-indv flag per dropped sample.
    drop_args = ["--remove-indv " + shquote(s) for s in dropped_samps]
    # The command is assembled as a single shell string (redirection and the
    # optional gzip pipe require a shell).
    call = ["vcftools"] + drop_args + [input_arg, shquote(input_filename), "--recode --stdout", output_pipe, ">", shquote(output_filename)]
    print("Filtering samples using VCFTools with the call:")
    print(" ".join(call))
    run_shell_command(" ".join(call), raise_errors = True)
    # remove vcftools log file
    try:
        os.remove('out.log')
    except OSError:
        pass
def read_priority_scores(fname):
    """Parse a priority-scores file into a mapping of strain name to score.

    Each non-blank line holds a strain name and a numeric score separated by
    a tab (or, as a fallback, any whitespace). Strains absent from the file
    default to a priority of negative infinity so they sort last.

    Parameters
    ----------
    fname : str
        Path to the priority scores file.

    Returns
    -------
    defaultdict :
        Mapping of strain name to float priority, defaulting to ``-inf``.

    Raises
    ------
    Exception
        Re-raised after printing a diagnostic to stderr when the file is
        missing or a line cannot be parsed into a name and a numeric score.
    """
    def constant_factory(value):
        return lambda: value
    try:
        with open(fname, encoding='utf-8') as pfile:
            return defaultdict(constant_factory(-np.inf), {
                elems[0]: float(elems[1])
                for elems in (
                    line.strip().split('\t') if '\t' in line else line.strip().split()
                    # Iterate the file lazily instead of readlines().
                    for line in pfile
                )
                # Skip blank lines (e.g., a trailing newline) instead of
                # failing on them as if the file were malformed.
                if elems
            })
    except Exception as e:
        print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr)
        raise e
# Define metadata filters.
def filter_by_exclude_all(metadata):
    """Exclude every strain, ignoring the metadata content entirely.

    Placeholder filter so "exclude everything" can participate in the same
    generic loop as the other metadata filters.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name

    Returns
    -------
    set[str]:
        Empty set of strains

    >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"])
    >>> filter_by_exclude_all(metadata)
    set()
    """
    return set()
def filter_by_exclude(metadata, exclude_file):
    """Drop strains listed in a file of excluded strain names.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    exclude_file : str
        Path to a file of strain names to exclude

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    strains_to_exclude = read_strains(exclude_file)
    all_strains = set(metadata.index.values)
    return all_strains - strains_to_exclude
def parse_filter_query(query):
    """Split an augur filter-style query into column, operator, and value.

    Queries follow the pattern ``"property=value"`` (equality) or
    ``"property!=value"`` (inequality).

    Parameters
    ----------
    query : str
        augur filter-style query string

    Returns
    -------
    str :
        Name of column to query
    callable :
        Operator function to test equality or non-equality of values
    str :
        Value of column to query

    >>> parse_filter_query("property=value")
    ('property', <built-in function eq>, 'value')
    >>> parse_filter_query("property!=value")
    ('property', <built-in function ne>, 'value')
    """
    column, value = re.split(r'!?=', query)
    if "!=" in query:
        return column, operator.ne, value
    return column, operator.eq, value
def filter_by_exclude_where(metadata, exclude_where):
    """Exclude strains whose metadata match the given exclusion query.

    Unlike pandas query syntax, exclusion queries follow the pattern
    ``"property=value"`` or ``"property!=value"``. Values are compared as
    lowercased strings, so both sides are stringified and lowercased before
    comparison.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    exclude_where : str
        Filter query used to exclude strains

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    column, op, value = parse_filter_query(exclude_where)
    if column not in metadata.columns:
        # The requested column does not exist, so the filter is a no-op.
        return set(metadata.index.values)
    # Boolean mask of rows matching the exclusion query, comparing
    # lowercased string values on both sides.
    matches_exclusion = op(
        metadata[column].astype(str).str.lower(),
        value.lower()
    )
    # Strains that do NOT match the exclusion query pass the filter.
    return set(metadata[~matches_exclusion].index.values)
def filter_by_query(metadata, query):
    """Apply a pandas query string to the metadata and return passing strains.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    query : str
        Query string for the dataframe.

    Returns
    -------
    set[str]:
        Strains that pass the filter

    >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"])
    >>> filter_by_query(metadata, "region == 'Africa'")
    {'strain1'}
    """
    matching_rows = metadata.query(query)
    return set(matching_rows.index.values)
def filter_by_ambiguous_date(metadata, date_column="date", ambiguity="any"):
    """Drop strains whose date has the given level of ambiguity.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    date_column : str
        Column in the dataframe with dates.
    ambiguity : str
        Level of date ambiguity to filter metadata by
        (interpreted by ``is_date_ambiguous``).

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    if date_column not in metadata.columns:
        # No date column to inspect; quietly skip this filter.
        return set(metadata.index.values)
    unambiguous = ~metadata[date_column].apply(
        lambda value: is_date_ambiguous(value, ambiguity)
    )
    return set(metadata[unambiguous].index.values)
def filter_by_date(metadata, date_column="date", min_date=None, max_date=None):
    """Drop strains whose numerical dates fall outside the given bounds.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    date_column : str
        Column in the dataframe with dates.
    min_date : float
        Minimum date (numeric); strains whose maximum possible date is
        earlier are excluded.
    max_date : float
        Maximum date (numeric); strains whose minimum possible date is
        later are excluded.

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    all_strains = set(metadata.index.values)
    # Without a bound or a date column there is nothing to filter on.
    if (not min_date and not max_date) or date_column not in metadata.columns:
        return all_strains
    dates = get_numerical_dates(metadata, date_col=date_column, fmt="%Y-%m-%d")

    def has_usable_date(strain):
        # Ambiguous dates come back as collections of candidate values; a
        # strain is usable when its value is a scalar or all candidates are
        # truthy.
        value = dates[strain]
        return np.isscalar(value) or all(value)

    passed = {strain for strain in all_strains if dates[strain] is not None}
    if min_date:
        passed = {strain for strain in passed if has_usable_date(strain) and np.max(dates[strain]) >= min_date}
    if max_date:
        passed = {strain for strain in passed if has_usable_date(strain) and np.min(dates[strain]) <= max_date}
    return passed
def filter_by_sequence_index(metadata, sequence_index):
    """Keep only strains that also appear in the given sequence index.

    Effectively intersects the strain ids of the metadata and the sequence
    index.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    sequence_index : pandas.DataFrame
        Sequence index

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    return set(metadata.index.values) & set(sequence_index.index.values)
def filter_by_sequence_length(metadata, sequence_index, min_length=0):
    """Filter metadata by sequence length from a given sequence index.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    sequence_index : pandas.DataFrame
        Sequence index with per-strain counts in columns "A", "C", "G", "T"
    min_length : int
        Minimum number of standard nucleotide characters (A, C, G, or T) in each sequence

    Returns
    -------
    set[str]:
        Strains that pass the filter

    >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"])
    >>> sequence_index = pd.DataFrame([{"strain": "strain1", "A": 7000, "C": 7000, "G": 7000, "T": 7000}, {"strain": "strain2", "A": 6500, "C": 6500, "G": 6500, "T": 6500}]).set_index("strain")
    >>> filter_by_sequence_length(metadata, sequence_index, min_length=27000)
    {'strain1'}

    It is possible for the sequence index to be missing strains present in the metadata.

    >>> sequence_index = pd.DataFrame([{"strain": "strain3", "A": 7000, "C": 7000, "G": 7000, "T": 7000}, {"strain": "strain2", "A": 6500, "C": 6500, "G": 6500, "T": 6500}]).set_index("strain")
    >>> filter_by_sequence_length(metadata, sequence_index, min_length=27000)
    set()
    """
    strains = set(metadata.index.values)
    filtered_sequence_index = sequence_index.loc[
        sequence_index.index.intersection(strains)
    ]
    # Compute the standard-nucleotide total as a standalone Series instead of
    # assigning a new column into the sliced frame: the original chained
    # assignment triggered pandas' SettingWithCopyWarning and risked writing
    # into the caller's sequence index.
    acgt_totals = filtered_sequence_index.loc[:, ["A", "C", "G", "T"]].sum(axis=1)
    return set(acgt_totals[acgt_totals >= min_length].index.values)
def filter_by_non_nucleotide(metadata, sequence_index):
    """Keep only strains whose sequences contain no invalid nucleotides.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    sequence_index : pandas.DataFrame
        Sequence index with an "invalid_nucleotides" count per strain

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    known_strains = set(metadata.index.values)
    relevant_index = sequence_index.loc[
        sequence_index.index.intersection(known_strains)
    ]
    # A strain passes when its sequence has zero invalid nucleotides.
    is_valid = relevant_index["invalid_nucleotides"] == 0
    return set(relevant_index[is_valid].index.values)
def include(metadata, include_file):
    """Select the strains named in the given include file.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    include_file : str
        Path to a file of strain names to include

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    requested_strains = read_strains(include_file)
    return requested_strains & set(metadata.index.values)
def include_by_include_where(metadata, include_where):
    """Force-include strains whose metadata match the given query.

    Unlike pandas query syntax, inclusion queries follow the pattern
    ``"property=value"`` or ``"property!=value"``. Values are compared as
    lowercased strings, so both sides are stringified and lowercased before
    comparison.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata indexed by strain name
    include_where : str
        Filter query used to include strains

    Returns
    -------
    set[str]:
        Strains that pass the filter
    """
    column, op, value = parse_filter_query(include_where)
    if column not in metadata.columns:
        # The requested column does not exist; nothing can match.
        return set()
    # Boolean mask of rows matching the inclusion query, comparing
    # lowercased string values on both sides.
    matches_inclusion = op(
        metadata[column].astype(str).str.lower(),
        value.lower()
    )
    return set(metadata[matches_inclusion].index.values)
def construct_filters(args, sequence_index):
    """Construct lists of filters and inclusion criteria based on user-provided
    arguments.

    Exclusion filters are appended in a fixed order; since they are applied
    sequentially downstream, this order determines which filter is reported
    as the reason a strain was excluded.

    Parameters
    ----------
    args : argparse.Namespace
        Command line arguments provided by the user.
    sequence_index : pandas.DataFrame
        Sequence index for the provided arguments.

    Returns
    -------
    list :
        A list of 2-element tuples with a callable to use as a filter and a
        dictionary of kwargs to pass to the callable.
    list :
        A list of 2-element tuples with a callable and dictionary of kwargs that
        determines whether to force include strains in the final output.
    """
    exclude_by = []
    include_by = []
    # Force include sequences specified in file(s).
    if args.include:
        # Collect the union of all given strains to include.
        for include_file in args.include:
            include_by.append((
                include,
                {
                    "include_file": include_file,
                }
            ))
    # Add sequences with particular metadata attributes.
    if args.include_where:
        for include_where in args.include_where:
            include_by.append((
                include_by_include_where,
                {
                    "include_where": include_where,
                }
            ))
    # Exclude all strains by default.
    if args.exclude_all:
        exclude_by.append((filter_by_exclude_all, {}))
    # Filter by sequence index.
    if sequence_index is not None:
        exclude_by.append((
            filter_by_sequence_index,
            {
                "sequence_index": sequence_index,
            },
        ))
    # Remove strains explicitly excluded by name.
    if args.exclude:
        for exclude_file in args.exclude:
            exclude_by.append((
                filter_by_exclude,
                {
                    "exclude_file": exclude_file,
                }
            ))
    # Exclude strains by metadata field like 'host=camel'.
    if args.exclude_where:
        for exclude_where in args.exclude_where:
            exclude_by.append((
                filter_by_exclude_where,
                {"exclude_where": exclude_where}
            ))
    # Exclude strains by metadata, using pandas querying.
    if args.query:
        exclude_by.append((
            filter_by_query,
            {"query": args.query}
        ))
    # Filter by ambiguous dates.
    if args.exclude_ambiguous_dates_by:
        exclude_by.append((
            filter_by_ambiguous_date,
            {
                "date_column": "date",
                "ambiguity": args.exclude_ambiguous_dates_by,
            }
        ))
    # Filter by date.
    if args.min_date or args.max_date:
        exclude_by.append((
            filter_by_date,
            {
                "date_column": "date",
                "min_date": args.min_date,
                "max_date": args.max_date,
            }
        ))
    # Filter by sequence length.
    if args.min_length:
        # Skip VCF files and warn the user that the min length filter does not
        # make sense for VCFs.
        is_vcf = filename_is_vcf(args.sequences)
        if is_vcf: #doesn't make sense for VCF, ignore.
            print("WARNING: Cannot use min_length for VCF files. Ignoring...")
        else:
            exclude_by.append((
                filter_by_sequence_length,
                {
                    "sequence_index": sequence_index,
                    "min_length": args.min_length,
                }
            ))
    # Exclude sequences with non-nucleotide characters.
    if args.non_nucleotide:
        exclude_by.append((
            filter_by_non_nucleotide,
            {
                "sequence_index": sequence_index,
            }
        ))
    return exclude_by, include_by
def filter_kwargs_to_str(kwargs):
    """Serialize a filter function's kwargs to a JSON string for reporting.

    The structured string can later be parsed back into Python data for more
    sophisticated reporting. DataFrame-valued kwargs (e.g., sequence indices)
    are omitted, and floats are rounded to two decimal places for readable,
    reproducible output. Keys are emitted in sorted order so the result does
    not depend on the input ordering.

    Parameters
    ----------
    kwargs : dict
        Dictionary of kwargs passed to a given filter function.

    Returns
    -------
    str :
        JSON list of ``[key, value]`` pairs.

    >>> filter_kwargs_to_str({"min_length": 27000})
    '[["min_length", 27000]]'
    """
    serializable_items = []
    for key in sorted(kwargs):
        value = kwargs[key]
        if isinstance(value, pd.DataFrame):
            # DataFrames are too large to include in a report.
            continue
        if isinstance(value, float):
            # Fixed precision for readability and reproducibility.
            value = round(value, 2)
        serializable_items.append((key, value))
    return json.dumps(serializable_items)
def apply_filters(metadata, exclude_by, include_by):
    """Apply exclusion and force-inclusion filters to the given metadata.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata to filter
    exclude_by : list[tuple]
        A list of 2-element tuples with a callable to filter by in the first
        index and a dictionary of kwargs to pass to the function in the second
        index.
    include_by : list[tuple]
        A list of 2-element tuples in the same format as the ``exclude_by``
        argument.

    Returns
    -------
    set :
        Strains to keep (those that passed all filters)
    list[dict] :
        Strains to exclude along with the function that filtered them and the
        arguments used to run the function.
    list[dict] :
        Strains to force-include along with the function that filtered them
        and the arguments used to run the function.
    """
    strains_to_keep = set(metadata.index.values)
    strains_to_filter = []
    strains_to_force_include = []
    distinct_strains_to_force_include = set()

    # Collect strains to force-include regardless of the exclusion filters,
    # recording which inclusion rule selected each one.
    for include_function, include_kwargs in include_by:
        passed = metadata.pipe(include_function, **include_kwargs)
        distinct_strains_to_force_include |= passed
        if passed:
            reason = include_function.__name__
            reason_kwargs = filter_kwargs_to_str(include_kwargs)
            strains_to_force_include.extend(
                {"strain": strain, "filter": reason, "kwargs": reason_kwargs}
                for strain in passed
            )

    # Apply each exclusion filter in order, narrowing the keep set and
    # recording why each failing strain was dropped.
    for filter_function, filter_kwargs in exclude_by:
        passed = metadata.pipe(filter_function, **filter_kwargs)
        failed = strains_to_keep - passed
        strains_to_keep &= passed
        if failed:
            reason = filter_function.__name__
            reason_kwargs = filter_kwargs_to_str(filter_kwargs)
            strains_to_filter.extend(
                {"strain": strain, "filter": reason, "kwargs": reason_kwargs}
                for strain in failed
            )
        # Nothing left to filter; later filters cannot change the outcome.
        if not strains_to_keep:
            break

    return strains_to_keep, strains_to_filter, strains_to_force_include
def get_groups_for_subsampling(strains, metadata, group_by=None):
    """Return a list of groups for each given strain based on the corresponding
    metadata and group by column.

    Parameters
    ----------
    strains : list
        A list of strains to get groups for.
    metadata : pandas.DataFrame
        Metadata to inspect for the given strains.
    group_by : list
        A list of metadata (or calculated) columns to group records by.

    Returns
    -------
    dict :
        A mapping of strain names to tuples corresponding to the values of the strain's group.
    list :
        A list of dictionaries with strains that were skipped from grouping and the reason why (see also: `apply_filters` output).

    >>> strains = ["strain1", "strain2"]
    >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020-01-01", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
    >>> group_by = ["region"]
    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
    >>> group_by_strain
    {'strain1': ('Africa',), 'strain2': ('Europe',)}
    >>> skipped_strains
    []

    If we group by year or month, these groups are calculated from the date
    string.

    >>> group_by = ["year", "month"]
    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
    >>> group_by_strain
    {'strain1': (2020, (2020, 1)), 'strain2': (2020, (2020, 2))}

    If we omit the grouping columns, the result will group by a dummy column.

    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata)
    >>> group_by_strain
    {'strain1': ('_dummy',), 'strain2': ('_dummy',)}

    If we try to group by columns that don't exist, we get an error.

    >>> group_by = ["missing_column"]
    >>> get_groups_for_subsampling(strains, metadata, group_by)
    Traceback (most recent call last):
      ...
    augur.filter.FilterException: The specified group-by categories (['missing_column']) were not found. No sequences-per-group sampling will be done.

    If we try to group by some columns that exist and some that don't, we allow
    grouping to continue and print a warning message to stderr.

    >>> group_by = ["year", "month", "missing_column"]
    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, group_by)
    >>> group_by_strain
    {'strain1': (2020, (2020, 1), 'unknown'), 'strain2': (2020, (2020, 2), 'unknown')}

    If we group by year month and some records don't have that information in
    their date fields, we should skip those records from the group output and
    track which records were skipped for which reasons.

    >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["year"])
    >>> group_by_strain
    {'strain2': (2020,)}
    >>> skipped_strains
    [{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_year', 'kwargs': ''}]

    Similarly, if we group by month, we should skip records that don't have
    month information in their date fields.

    >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain")
    >>> group_by_strain, skipped_strains = get_groups_for_subsampling(strains, metadata, ["month"])
    >>> group_by_strain
    {'strain2': ((2020, 2),)}
    >>> skipped_strains
    [{'strain': 'strain1', 'filter': 'skip_group_by_with_ambiguous_month', 'kwargs': ''}]
    """
    # Restrict the metadata to the requested strains.
    metadata = metadata.loc[list(strains)]
    group_by_strain = {}
    skipped_strains = []
    if metadata.empty:
        return group_by_strain, skipped_strains
    if not group_by or group_by == ('_dummy',):
        # No grouping requested: place every strain in one dummy group.
        group_by_strain = {strain: ('_dummy',) for strain in strains}
        return group_by_strain, skipped_strains
    group_by_set = set(group_by)
    # If we could not find any requested categories, we cannot complete subsampling.
    if 'date' not in metadata and group_by_set <= {'year', 'month'}:
        raise FilterException(f"The specified group-by categories ({group_by}) were not found. No sequences-per-group sampling will be done. Note that using 'year' or 'year month' requires a column called 'date'.")
    if not group_by_set & (set(metadata.columns) | {'year', 'month'}):
        raise FilterException(f"The specified group-by categories ({group_by}) were not found. No sequences-per-group sampling will be done.")
    # date requested
    if 'year' in group_by_set or 'month' in group_by_set:
        if 'date' not in metadata:
            # set year/month/day = unknown
            print(f"WARNING: A 'date' column could not be found to group-by year or month.", file=sys.stderr)
            print(f"Filtering by group may behave differently than expected!", file=sys.stderr)
            df_dates = pd.DataFrame({'year': 'unknown', 'month': 'unknown'}, index=metadata.index)
            metadata = pd.concat([metadata, df_dates], axis=1)
        else:
            # replace date with year/month/day as nullable ints
            date_cols = ['year', 'month', 'day']
            df_dates = metadata['date'].str.split('-', n=2, expand=True)
            df_dates = df_dates.set_axis(date_cols[:len(df_dates.columns)], axis=1)
            missing_date_cols = set(date_cols) - set(df_dates.columns)
            for col in missing_date_cols:
                df_dates[col] = pd.NA
            for col in date_cols:
                # errors='coerce' turns non-numeric parts (e.g. "XX") into NA.
                df_dates[col] = pd.to_numeric(df_dates[col], errors='coerce').astype(pd.Int64Dtype())
            metadata = pd.concat([metadata.drop('date', axis=1), df_dates], axis=1)
            if 'year' in group_by_set:
                # skip ambiguous years
                df_skip = metadata[metadata['year'].isnull()]
                metadata.dropna(subset=['year'], inplace=True)
                for strain in df_skip.index:
                    skipped_strains.append({
                        "strain": strain,
                        "filter": "skip_group_by_with_ambiguous_year",
                        "kwargs": "",
                    })
            if 'month' in group_by_set:
                # skip ambiguous months
                df_skip = metadata[metadata['month'].isnull()]
                metadata.dropna(subset=['month'], inplace=True)
                for strain in df_skip.index:
                    skipped_strains.append({
                        "strain": strain,
                        "filter": "skip_group_by_with_ambiguous_month",
                        "kwargs": "",
                    })
                # month = (year, month)
                metadata['month'] = list(zip(metadata['year'], metadata['month']))
    # TODO: support group by day
    unknown_groups = group_by_set - set(metadata.columns)
    if unknown_groups:
        print(f"WARNING: Some of the specified group-by categories couldn't be found: {', '.join(unknown_groups)}", file=sys.stderr)
        print("Filtering by group may behave differently than expected!", file=sys.stderr)
        for group in unknown_groups:
            # Unknown categories contribute a constant 'unknown' value.
            metadata[group] = 'unknown'
    # Map each remaining strain to the tuple of its group-by column values.
    group_by_strain = dict(zip(metadata.index, metadata[group_by].apply(tuple, axis=1)))
    return group_by_strain, skipped_strains
class PriorityQueue:
"""A priority queue implementation that automatically replaces lower priority
items in the heap with incoming higher priority items.
Add a single record to a heap with a maximum of 2 records.
>>> queue = PriorityQueue(max_size=2)
>>> queue.add({"strain": "strain1"}, 0.5)
1
Add another record with a higher priority. The queue should be at its maximum
size.
>>> queue.add({"strain": "strain2"}, 1.0)
2
>>> queue.heap
[(0.5, 0, {'strain': 'strain1'}), (1.0, 1, {'strain': 'strain2'})]
>>> list(queue.get_items())
[{'strain': 'strain1'}, {'strain': 'strain2'}]
Add a higher priority record that causes the queue to exceed its maximum
size. The resulting queue should contain the two highest priority records
after the lowest priority record is removed.
>>> queue.add({"strain": "strain3"}, 2.0)
2
>>> list(queue.get_items())
[{'strain': 'strain2'}, {'strain': 'strain3'}]
Add a record with the same priority as another record, forcing the duplicate
to be resolved by removing the oldest entry.
>>> queue.add({"strain": "strain4"}, 1.0)
2
>>> list(queue.get_items())
[{'strain': 'strain4'}, {'strain': 'strain3'}]
"""