Skip to content

Commit

Permalink
filter: Use print_err and AugurError
Browse files Browse the repository at this point in the history
This also simplifies the implementation of validate_arguments() to raise
AugurErrors directly instead of returning a boolean to be translated to
a proper error message by the caller.
  • Loading branch information
victorlin committed Jul 21, 2022
1 parent 76566bc commit f657832
Show file tree
Hide file tree
Showing 12 changed files with 1,916 additions and 68 deletions.
1,875 changes: 1,875 additions & 0 deletions augur/filter.py

Large diffs are not rendered by default.

60 changes: 15 additions & 45 deletions augur/filter/filter_and_subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import numpy as np
import os
import pandas as pd
import sys
from tempfile import NamedTemporaryFile

from augur.errors import AugurError
from augur.index import index_sequences, index_vcf
from augur.io import open_file, read_metadata, read_sequences, write_sequences, is_vcf as filename_is_vcf, write_vcf
from augur.io import open_file, read_metadata, read_sequences, write_sequences, is_vcf as filename_is_vcf, write_vcf, print_err
from .io import cleanup_outputs, read_priority_scores
from .rules import apply_filters, construct_filters
from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, create_queues_by_group, get_groups_for_subsampling
Expand All @@ -23,41 +22,25 @@


def validate_arguments(args):
"""Validate arguments and return a boolean representing whether all validation
rules succeeded.
"""Validate arguments.
Parameters
----------
args : argparse.Namespace
Parsed arguments from argparse
Returns
-------
bool :
Validation succeeded.
"""
# Don't allow sequence output when no sequence input is provided.
if args.output and not args.sequences:
print(
"ERROR: You need to provide sequences to output sequences.",
file=sys.stderr)
return False
raise AugurError("You need to provide sequences to output sequences.")

# Confirm that at least one output was requested.
if not any((args.output, args.output_metadata, args.output_strains)):
print(
"ERROR: You need to select at least one output.",
file=sys.stderr)
return False
raise AugurError("You need to select at least one output.")

# Don't allow filtering on sequence-based information, if no sequences or
# sequence index is provided.
if not args.sequences and not args.sequence_index and any(getattr(args, arg) for arg in SEQUENCE_ONLY_FILTERS):
print(
"ERROR: You need to provide a sequence index or sequences to filter on sequence-specific information.",
file=sys.stderr)
return False
raise AugurError("You need to provide a sequence index or sequences to filter on sequence-specific information.")

# Set flags if VCF
is_vcf = filename_is_vcf(args.sequences)
Expand All @@ -66,29 +49,20 @@ def validate_arguments(args):
if is_vcf:
from shutil import which
if which("vcftools") is None:
print("ERROR: 'vcftools' is not installed! This is required for VCF data. "
"Please see the augur install instructions to install it.",
file=sys.stderr)
return False
raise AugurError("'vcftools' is not installed! This is required for VCF data. "
"Please see the augur install instructions to install it.")

# If user requested grouping, confirm that other required inputs are provided, too.
if args.group_by and not any((args.sequences_per_group, args.subsample_max_sequences)):
print(
"ERROR: You must specify a number of sequences per group or maximum sequences to subsample.",
file=sys.stderr
)
return False

return True
raise AugurError("You must specify a number of sequences per group or maximum sequences to subsample.")


def filter_and_subsample(args):
'''
filter and subsample a set of sequences into an analysis set
'''
# Validate arguments before attempting any I/O.
if not validate_arguments(args):
return 1
validate_arguments(args)

# Determine whether the sequence index exists or whether should be
# generated. We need to generate an index if the input sequences are in a
Expand All @@ -111,10 +85,9 @@ def filter_and_subsample(args):
with NamedTemporaryFile(delete=False) as sequence_index_file:
sequence_index_path = sequence_index_file.name

print(
print_err(
"Note: You did not provide a sequence index, so Augur will generate one.",
"You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.",
file=sys.stderr
"You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`."
)

if is_vcf:
Expand Down Expand Up @@ -357,8 +330,7 @@ def filter_and_subsample(args):
args.probabilistic_sampling,
)
except TooManyGroupsError as error:
print(f"ERROR: {error}", file=sys.stderr)
sys.exit(1)
raise AugurError(error)

if (probabilistic_used):
print(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
Expand Down Expand Up @@ -488,10 +460,9 @@ def filter_and_subsample(args):
# Warn the user if the expected strains from the sequence index are
# not a superset of the observed strains.
if sequence_strains is not None and observed_sequence_strains > sequence_strains:
print(
print_err(
"WARNING: The sequence index is out of sync with the provided sequences.",
"Metadata and strain output may not match sequence output.",
file=sys.stderr
"Metadata and strain output may not match sequence output."
)

# Update the set of available sequence strains.
Expand Down Expand Up @@ -550,7 +521,6 @@ def filter_and_subsample(args):
print("\t%i of these were dropped because of subsampling criteria%s" % (num_excluded_subsamp, seed_txt))

if total_strains_passed == 0:
print("ERROR: All samples have been dropped! Check filter rules and metadata file format.", file=sys.stderr)
return 1
raise AugurError("All samples have been dropped! Check filter rules and metadata file format.")

print(f"{total_strains_passed} strains passed all filters")
6 changes: 3 additions & 3 deletions augur/filter/io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import os
import sys
import numpy as np
import pandas as pd
from collections import defaultdict

from augur.errors import AugurError


def read_priority_scores(fname):
def constant_factory(value):
Expand All @@ -17,8 +18,7 @@ def constant_factory(value):
for elems in (line.strip().split('\t') if '\t' in line else line.strip().split() for line in pfile.readlines())
})
except Exception as e:
print(f"ERROR: missing or malformed priority scores file {fname}", file=sys.stderr)
raise e
raise AugurError(f"missing or malformed priority scores file {fname}")


def filter_kwargs_to_str(kwargs):
Expand Down
5 changes: 2 additions & 3 deletions augur/filter/rules.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import operator
import re
import sys
import numpy as np
import pandas as pd

from augur.dates import numeric_date, is_date_ambiguous, get_numerical_dates
from augur.errors import AugurError
from augur.io import is_vcf as filename_is_vcf
from augur.io import is_vcf as filename_is_vcf, print_err
from augur.utils import read_strains
from .io import filter_kwargs_to_str

Expand Down Expand Up @@ -596,7 +595,7 @@ def construct_filters(args, sequence_index):
is_vcf = filename_is_vcf(args.sequences)

if is_vcf: #doesn't make sense for VCF, ignore.
print("WARNING: Cannot use min_length for VCF files. Ignoring...", file=sys.stderr)
print_err("WARNING: Cannot use min_length for VCF files. Ignoring...")
else:
exclude_by.append((
filter_by_sequence_length,
Expand Down
15 changes: 8 additions & 7 deletions augur/filter/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
from typing import Collection

from augur.io import print_err
from .errors import FilterException


Expand Down Expand Up @@ -113,16 +114,16 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
if 'year' in group_by_set or 'month' in group_by_set:

if 'year' in metadata.columns and 'year' in group_by_set:
print(f"WARNING: `--group-by year` uses the generated year value from the 'date' column. The custom 'year' column in the metadata is ignored for grouping purposes.", file=sys.stderr)
print_err(f"WARNING: `--group-by year` uses the generated year value from the 'date' column. The custom 'year' column in the metadata is ignored for grouping purposes.")
metadata.drop('year', axis=1, inplace=True)
if 'month' in metadata.columns and 'month' in group_by_set:
print(f"WARNING: `--group-by month` uses the generated month value from the 'date' column. The custom 'month' column in the metadata is ignored for grouping purposes.", file=sys.stderr)
print_err(f"WARNING: `--group-by month` uses the generated month value from the 'date' column. The custom 'month' column in the metadata is ignored for grouping purposes.")
metadata.drop('month', axis=1, inplace=True)

if 'date' not in metadata:
# set year/month/day = unknown
print(f"WARNING: A 'date' column could not be found to group-by year or month.", file=sys.stderr)
print(f"Filtering by group may behave differently than expected!", file=sys.stderr)
print_err(f"WARNING: A 'date' column could not be found to group-by year or month.")
print_err(f"Filtering by group may behave differently than expected!")
df_dates = pd.DataFrame({'year': 'unknown', 'month': 'unknown'}, index=metadata.index)
metadata = pd.concat([metadata, df_dates], axis=1)
else:
Expand Down Expand Up @@ -162,8 +163,8 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):

unknown_groups = group_by_set - set(metadata.columns)
if unknown_groups:
print(f"WARNING: Some of the specified group-by categories couldn't be found: {', '.join(unknown_groups)}", file=sys.stderr)
print("Filtering by group may behave differently than expected!", file=sys.stderr)
print_err(f"WARNING: Some of the specified group-by categories couldn't be found: {', '.join(unknown_groups)}")
print_err("Filtering by group may behave differently than expected!")
for group in unknown_groups:
metadata[group] = 'unknown'

Expand Down Expand Up @@ -346,7 +347,7 @@ def calculate_sequences_per_group(target_max_value, counts_per_group, allow_prob
)
except TooManyGroupsError as error:
if allow_probabilistic:
print(f"WARNING: {error}", file=sys.stderr)
print_err(f"WARNING: {error}")
sequences_per_group = _calculate_fractional_sequences_per_group(
target_max_value,
counts_per_group,
Expand Down
7 changes: 5 additions & 2 deletions tests/filter/test_filter_and_subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import augur.filter.io
import augur.filter.rules
from augur.io import read_metadata
from augur.errors import AugurError

@pytest.fixture
def argparser():
Expand Down Expand Up @@ -79,9 +80,10 @@ def test_read_priority_scores_valid(self, mock_priorities_file_valid):
assert priorities["strain42"] == -np.inf, "Default priority is negative infinity for unlisted sequences"

def test_read_priority_scores_malformed(self, mock_priorities_file_malformed):
with pytest.raises(ValueError):
with pytest.raises(AugurError) as e_info:
# builtins.open is stubbed, but we need a valid file to satisfy the existence check
augur.filter.io.read_priority_scores("tests/builds/tb/data/lee_2015.vcf")
assert str(e_info.value) == "ERROR: missing or malformed priority scores file tests/builds/tb/data/lee_2015.vcf"

def test_read_priority_scores_valid_with_spaces_and_tabs(self, mock_priorities_file_valid_with_spaces_and_tabs):
# builtins.open is stubbed, but we need a valid file to satisfy the existence check
Expand All @@ -92,8 +94,9 @@ def test_read_priority_scores_valid_with_spaces_and_tabs(self, mock_priorities_f
assert priorities == {"strain 1": 5, "strain 2": 6, "strain 3": 8}

def test_read_priority_scores_does_not_exist(self):
with pytest.raises(FileNotFoundError):
with pytest.raises(AugurError) as e_info:
augur.filter.io.read_priority_scores("/does/not/exist.txt")
assert str(e_info.value) == "ERROR: missing or malformed priority scores file /does/not/exist.txt"

def test_filter_on_query_good(self, tmpdir, sequences):
"""Basic filter_on_query test"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ This should fail because the requested filters rely on sequence information.
> --min-length 10000 \
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
ERROR: You need to provide a sequence index or sequences to filter on sequence-specific information.
[1]
[2]
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ This should produce no results because the intersection of metadata and sequence
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.
ERROR: All samples have been dropped! Check filter rules and metadata file format.
[1]
[2]
$ wc -l "$TMP/filtered_strains.txt"
\s*0 .* (re)
$ rm -f "$TMP/filtered_strains.txt"
Expand All @@ -30,7 +30,7 @@ Repeat with sequence and strain outputs. We should get the same results.
> --output-sequences "$TMP/filtered.fasta" > /dev/null
Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.
ERROR: All samples have been dropped! Check filter rules and metadata file format.
[1]
[2]
$ wc -l "$TMP/filtered_strains.txt"
\s*0 .* (re)
$ grep "^>" "$TMP/filtered.fasta" | wc -l
Expand All @@ -47,7 +47,7 @@ Since we expect metadata to be filtered by presence of strains in input sequence
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`.
ERROR: All samples have been dropped! Check filter rules and metadata file format.
[1]
[2]
$ wc -l "$TMP/filtered_strains.txt"
\s*0 .* (re)
$ rm -f "$TMP/filtered_strains.txt"
2 changes: 1 addition & 1 deletion tests/functional/filter/cram/filter-no-outputs-error.t
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ Try to filter without any outputs.
> --metadata filter/data/metadata.tsv \
> --min-length 10000 > /dev/null
ERROR: You need to select at least one output.
[1]
[2]
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ This should fail.
> --min-length 10000 \
> --output "$TMP/filtered.fasta" > /dev/null
ERROR: You need to provide sequences to output sequences.
[1]
[2]
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ This should fail, as probabilistic sampling is explicitly disabled.
> --no-probabilistic-sampling \
> --output "$TMP/filtered.fasta"
ERROR: Asked to provide at most 5 sequences, but there are 8 groups.
[1]
[2]
$ rm -f "$TMP/filtered.fasta"
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ This should fail with a helpful error message.
> --group-by year month \
> --output-strains "$TMP/filtered_strains.txt" > /dev/null
ERROR: You must specify a number of sequences per group or maximum sequences to subsample.
[1]
[2]

0 comments on commit f657832

Please sign in to comment.