From 068a2165814e32ab773797a96b9bbe2afc87f960 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 31 Jan 2024 14:47:36 -0800 Subject: [PATCH 01/15] io: Add class interfaces for working with SQLite3 This will be used in future commits by a new implementation of augur filter. --- augur/io/sqlite3.py | 172 ++++++++++++++++++++++++ docs/api/developer/augur.io.rst | 1 + docs/api/developer/augur.io.sqlite3.rst | 7 + 3 files changed, 180 insertions(+) create mode 100644 augur/io/sqlite3.py create mode 100644 docs/api/developer/augur.io.sqlite3.rst diff --git a/augur/io/sqlite3.py b/augur/io/sqlite3.py new file mode 100644 index 000000000..71f989c51 --- /dev/null +++ b/augur/io/sqlite3.py @@ -0,0 +1,172 @@ +import csv +import random +import re +import string +import sqlite3 +from contextlib import AbstractContextManager +from typing import Any, Dict, Iterable +from urllib.parse import urlencode +from .file import open_file + + +class Sqlite3Database(AbstractContextManager): + """Represents a SQLite3 database. + + This class should be used as a context manager where the runtime context of + an instance reflects a connection to the database. + """ + + def __init__(self, file: str, **connect_uri_params): + """ + Parameters + ---------- + file + Database file + connect_uri_params + Parameters passed to sqlite3.connect() in the form of URI parameters. + Examples: https://docs.python.org/3/library/sqlite3.html#how-to-work-with-sqlite-uris + """ + + self.file = file + """Database file.""" + + self.connection: sqlite3.Connection = None + """SQLite3 database connection.""" + + self.connect_uri_params = connect_uri_params + """Parameters passed to sqlite3.connect() in the form of URI parameters.""" + + # Default to opening the database in read-only mode + if 'mode' not in self.connect_uri_params: + self.connect_uri_params['mode'] = 'ro' + + def _connect(self, **connect_uri_params): + """Return a new connection to the SQLite database. + + Intended for use with the context manager.""" + encoded_params = urlencode(connect_uri_params) + return sqlite3.connect(f'file:{self.file}?{encoded_params}', uri=True) + + def __enter__(self): + self.connection = self._connect(**self.connect_uri_params).__enter__() + + # Allow access by column name. + self.connection.row_factory = sqlite3.Row + + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.connection.__exit__(exc_type, exc_value, exc_traceback) + + def tables(self): + """Yield table names in the database.""" + for row in self.connection.execute(f"SELECT name FROM sqlite_master WHERE type='table'"): + yield str(row["name"]) + + def columns(self, table: str): + """Yield columns names from a table.""" + self._assert_table_existence(table) + + for row in self.connection.execute(f"SELECT name FROM pragma_table_info({sanitize_identifier(table)})"): + yield str(row["name"]) + + def _assert_table_existence(self, table: str): + """Assert that the table exists.""" + assert table in self.tables() + + def create_table(self, table: str, columns: Iterable[str]): + """Create a table with all columns having the TEXT type affinity.""" + + # Create table with TEXT type affinity for all columns. + # FIXME: STRICT requires SQLite version 3.37.0. Do we actually need it? 
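        # Illustrative sketch, not part of this patch: if STRICT should only be
        # applied where the runtime library supports it, the keyword could be
        # gated on the library version, for example:
        #
        #     strict_keyword = " STRICT" if sqlite3.sqlite_version_info >= (3, 37, 0) else ""
        #
        # and interpolated into the CREATE TABLE statement instead of the
        # hard-coded STRICT below.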
+ try: + self.connection.execute(f""" + CREATE TABLE {sanitize_identifier(table)} ( + {','.join(f'{sanitize_identifier(column)} TEXT' for column in columns)} + ) STRICT + """) + except sqlite3.OperationalError as e: + raise Exception(f'Failed to create table.') from e + + def insert(self, table: str, columns: Iterable[str], rows: Iterable[Dict[str, Any]]): + """Insert rows into a table.""" + self._assert_table_existence(table) + + # This intermediate dict serves two purposes: + # 1. Generates valid placeholder names. Column names cannot be used + # directly since placeholder names cannot be quote-sanitized in the + # insert statement, and column names are user-defined. + # 2. Ensures the list of column names and the list of placeholders in + # the insert statement are 1:1. + column_placeholder_mapping = {column: _generate_placeholder(column) for column in columns} + + insert_statement = f""" + INSERT INTO {sanitize_identifier(table)} + ({','.join([sanitize_identifier(column) for column in column_placeholder_mapping.keys()])}) + VALUES ({','.join([f':{placeholder}' for placeholder in column_placeholder_mapping.values()])}) + """ + rows_with_placeholders = ( + {column_placeholder_mapping[key]: value + for key, value in row.items() + if key in columns + } + for row in rows + ) + try: + self.connection.executemany(insert_statement, rows_with_placeholders) + except sqlite3.ProgrammingError as e: + raise Exception("Failed to insert rows.") from e + + def query_to_file(self, query: str, path: str, header: bool = True, delimiter: str = '\t'): + """Query the database and write results to a tabular file. + + Parameters + ---------- + query + SQLite query string. + path + Path to the output file. + header + Write out the column names. + delimiter + Delimiter between columns. + """ + with open_file(path, mode="w") as output_file: + cursor = self.connection.cursor() + cursor.execute(query) + writer = csv.writer(output_file, delimiter=delimiter, lineterminator='\n') + if header: + column_names = [column[0] for column in cursor.description] + writer.writerow(column_names) + for row in cursor: + writer.writerow(row) + + +def sanitize_identifier(identifier: str): + """Sanitize a SQLite identifier. + + Note: We can (and probably should) do more here¹. However, column names in + the database are used in the output, which should be as close as possible + to the input names. + ¹ https://stackoverflow.com/a/6701665 + """ + + # Escape existing double quotes + identifier = identifier.replace('"', '""') + + # Wrap inside double quotes + return f'"{identifier}"' + + +def _generate_placeholder(identifier: str): + """Generate a placeholder¹ name that doesn't need escaping. 
+ + ¹ https://www.sqlite.org/lang_expr.html#varparam + """ + # Remove everything but alphanumeric characters and underscores + stripped_identifier = re.sub(r'\W+', '', identifier) + + # Generate a random suffix to avoid collisions + random_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) + + return f'{stripped_identifier}_{random_suffix}' diff --git a/docs/api/developer/augur.io.rst b/docs/api/developer/augur.io.rst index d6db9a2b2..e996a8364 100644 --- a/docs/api/developer/augur.io.rst +++ b/docs/api/developer/augur.io.rst @@ -18,5 +18,6 @@ Submodules augur.io.print augur.io.sequences augur.io.shell_command_runner + augur.io.sqlite3 augur.io.strains augur.io.vcf diff --git a/docs/api/developer/augur.io.sqlite3.rst b/docs/api/developer/augur.io.sqlite3.rst new file mode 100644 index 000000000..9f20ff0e3 --- /dev/null +++ b/docs/api/developer/augur.io.sqlite3.rst @@ -0,0 +1,7 @@ +augur.io.sqlite3 module +======================= + +.. automodule:: augur.io.sqlite3 + :members: + :undoc-members: + :show-inheritance: From 28f49a9d659f4c7abddb34647de88454b2b10ea5 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 31 Jan 2024 14:48:11 -0800 Subject: [PATCH 02/15] io: Split Metadata class into TabularFile and File These will be used in future commits by a new implementation of augur filter. --- augur/frequencies.py | 2 +- augur/io/file.py | 14 +++ augur/io/metadata.py | 64 ++----------- augur/io/tabular_file.py | 97 ++++++++++++++++++++ augur/refine.py | 2 +- docs/api/developer/augur.io.rst | 1 + docs/api/developer/augur.io.tabular_file.rst | 7 ++ 7 files changed, 131 insertions(+), 56 deletions(-) create mode 100644 augur/io/tabular_file.py create mode 100644 docs/api/developer/augur.io.tabular_file.rst diff --git a/augur/frequencies.py b/augur/frequencies.py index b0abfb397..a41ad3994 100644 --- a/augur/frequencies.py +++ b/augur/frequencies.py @@ -91,7 +91,7 @@ def format_frequencies(freq): def run(args): try: - metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) + metadata_object = Metadata(args.metadata, args.metadata_id_columns, delimiters=args.metadata_delimiters) except InvalidDelimiter: raise AugurError( f"Could not determine the delimiter of {args.metadata!r}. 
" diff --git a/augur/io/file.py b/augur/io/file.py index 1aa5f59af..c9f29f365 100644 --- a/augur/io/file.py +++ b/augur/io/file.py @@ -68,3 +68,17 @@ def open_file(path_or_buffer, mode="r", **kwargs): else: raise TypeError(f"Type {type(path_or_buffer)} is not supported.") + + +class File: + """Represents a file.""" + + path: str + """Path to the file on disk.""" + + def __init__(self, path: str): + self.path = path + + def open(self, **kwargs): + """Open the file with auto-compression/decompression.""" + return open_file(self.path, **kwargs) diff --git a/augur/io/metadata.py b/augur/io/metadata.py index de474162d..bb47f865e 100644 --- a/augur/io/metadata.py +++ b/augur/io/metadata.py @@ -1,6 +1,6 @@ import csv import os -from typing import Iterable, Sequence +from typing import Sequence import pandas as pd import pyfastx import python_calamine as calamine @@ -13,19 +13,14 @@ from augur.io.print import print_err from augur.types import DataErrorMethod from .file import PANDAS_READ_CSV_OPTIONS, open_file +from .tabular_file import DEFAULT_DELIMITERS, InvalidDelimiter, TabularFile, get_delimiter -DEFAULT_DELIMITERS = (',', '\t') - DEFAULT_ID_COLUMNS = ("strain", "name") METADATA_DATE_COLUMN = 'date' -class InvalidDelimiter(Exception): - pass - - def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, columns=None, id_columns=DEFAULT_ID_COLUMNS, chunk_size=None, dtype=None): r"""Read metadata from a given filename and into a pandas `DataFrame` or `TextFileReader` object. @@ -82,7 +77,7 @@ def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, columns=None, id """ kwargs = { - "sep": _get_delimiter(metadata_file, delimiters), + "sep": get_delimiter(metadata_file, delimiters), "engine": "c", "skipinitialspace": True, "na_filter": False, @@ -265,7 +260,7 @@ def visible_worksheet(s: calamine.SheetMetadata) -> bool: handle = chain(table_sample_file, handle) try: - # Note: this sort of duplicates _get_delimiter(), but it's easier if + # Note: this sort of duplicates get_delimiter(), but it's easier if # this is separate since it handles non-seekable buffers. dialect = csv.Sniffer().sniff(table_sample, delimiters) except csv.Error as error: @@ -559,46 +554,24 @@ def write_records_to_tsv(records, output_file): tsv_writer.writerow(record) -class Metadata: +class Metadata(TabularFile): """Represents a metadata file.""" - path: str - """Path to the file on disk.""" - - delimiter: str - """Inferred delimiter of metadata.""" - - columns: Sequence[str] - """Columns extracted from the first row in the metadata file.""" - id_column: str """Inferred ID column.""" - def __init__(self, path: str, delimiters: Sequence[str], id_columns: Sequence[str]): + def __init__(self, path: str, id_columns: Sequence[str], **kwargs): """ Parameters ---------- path Path of the metadata file. - delimiters - Possible delimiters to use, in order of precedence. id_columns - Possible ID columns to use, in order of precedence. + Possible ID columns, in order of precedence. + **kwargs + See TabularFile.__init__() for more parameters. """ - self.path = path - - # Infer the dialect. - self.delimiter = _get_delimiter(self.path, delimiters) - - # Infer the column names. - with self.open() as f: - reader = csv.reader(f, delimiter=self.delimiter) - try: - self.columns = next(reader) - except StopIteration: - raise AugurError(f"{self.path}: Expected a header row but it is empty.") - - # Infer the ID column. 
+ super().__init__(path, **kwargs) self.id_column = self._find_id_column(id_columns) def open(self, **kwargs): @@ -637,20 +610,3 @@ def rows(self, strict: bool = True): # This is distinct from a blank value (empty string). raise AugurError(f"{self.path}: Line {reader.line_num} is missing at least one column. The inferred delimiter is {self.delimiter!r}.") yield row - - -def _get_delimiter(path: str, valid_delimiters: Iterable[str]): - """Get the delimiter of a file given a list of valid delimiters.""" - - for delimiter in valid_delimiters: - if len(delimiter) != 1: - raise AugurError(f"Delimiters must be single-character strings. {delimiter!r} does not satisfy that condition.") - - with open_file(path, newline='') as file: - try: - # Infer the delimiter from the first line. - return csv.Sniffer().sniff(file.readline(), "".join(valid_delimiters)).delimiter - except csv.Error as error: - # This assumes all csv.Errors imply a delimiter issue. That might - # change in a future Python version. - raise InvalidDelimiter from error diff --git a/augur/io/tabular_file.py b/augur/io/tabular_file.py new file mode 100644 index 000000000..be1771fd7 --- /dev/null +++ b/augur/io/tabular_file.py @@ -0,0 +1,97 @@ +import csv +from typing import Iterable, Sequence + +from augur.errors import AugurError +from .file import File, open_file + + +DEFAULT_DELIMITERS = (',', '\t') + + +class InvalidDelimiter(Exception): + pass + + +class TabularFile(File): + """Represents a tabular file with a delimiter and required header line. + + A pandas DataFrame can provide these properties, however this class is more + suitable for large files as it does not store the rows in memory. + + The properties on this class can also be used on sub-classes for more + context-specific usage. + """ + + columns: Sequence[str] + """Columns extracted from the first row in the file.""" + + delimiter: str + """Inferred delimiter of file.""" + + def __init__(self, path: str, delimiters: Sequence[str] = None): + """ + Parameters + ---------- + path + Path of the tabular file. + delimiters + Possible delimiters to use, in order of precedence. + """ + super().__init__(path) + + if delimiters is None: + delimiters = DEFAULT_DELIMITERS + + # Infer the dialect. + self.delimiter = get_delimiter(self.path, delimiters) + + # Infer the column names. + with self.open() as f: + reader = csv.reader(f, delimiter=self.delimiter) + try: + self.columns = next(reader) + except StopIteration: + raise AugurError(f"{self.path}: Expected a header row but it is empty.") + + + def rows(self, strict: bool = True): + """Yield rows in a dictionary format. Empty lines are ignored. + + Parameters + ---------- + strict + If True, raise an error when a row contains more or less than the number of expected columns. + """ + with self.open() as f: + reader = csv.DictReader(f, delimiter=self.delimiter, fieldnames=self.columns, restkey=None, restval=None) + + # Skip the header row. + next(reader) + + # NOTE: Empty lines are ignored by csv.DictReader. + # + for row in reader: + if strict: + if None in row.keys(): + raise AugurError(f"{self.path}: Line {reader.line_num} contains at least one extra column. The inferred delimiter is {self.delimiter!r}.") + if None in row.values(): + # This is distinct from a blank value (empty string). + raise AugurError(f"{self.path}: Line {reader.line_num} is missing at least one column. 
The inferred delimiter is {self.delimiter!r}.") + yield row + + +def get_delimiter(path: str, valid_delimiters: Iterable[str]): + """Get the delimiter of a file given a list of valid delimiters.""" + + for delimiter in valid_delimiters: + if len(delimiter) != 1: + raise AugurError(f"Delimiters must be single-character strings. {delimiter!r} does not satisfy that condition.") + + with open_file(path, newline='') as file: + try: + # Infer the delimiter from the first line. + return csv.Sniffer().sniff(file.readline(), "".join(valid_delimiters)).delimiter + except csv.Error as error: + # This assumes all csv.Errors imply a delimiter issue. That might + # change in a future Python version. + raise InvalidDelimiter from error diff --git a/augur/refine.py b/augur/refine.py index 158c86afd..41356e44d 100644 --- a/augur/refine.py +++ b/augur/refine.py @@ -216,7 +216,7 @@ def run(args): return 1 try: - metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) + metadata_object = Metadata(args.metadata, args.metadata_id_columns, delimiters=args.metadata_delimiters) except InvalidDelimiter: raise AugurError( f"Could not determine the delimiter of {args.metadata!r}. " diff --git a/docs/api/developer/augur.io.rst b/docs/api/developer/augur.io.rst index e996a8364..e14050657 100644 --- a/docs/api/developer/augur.io.rst +++ b/docs/api/developer/augur.io.rst @@ -20,4 +20,5 @@ Submodules augur.io.shell_command_runner augur.io.sqlite3 augur.io.strains + augur.io.tabular_file augur.io.vcf diff --git a/docs/api/developer/augur.io.tabular_file.rst b/docs/api/developer/augur.io.tabular_file.rst new file mode 100644 index 000000000..c6efb90ec --- /dev/null +++ b/docs/api/developer/augur.io.tabular_file.rst @@ -0,0 +1,7 @@ +augur.io.tabular\_file module +============================= + +.. automodule:: augur.io.tabular_file + :members: + :undoc-members: + :show-inheritance: From acc4913b4e65ab0a12362590258aa21bf09c42a8 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Wed, 31 Jan 2024 14:59:26 -0800 Subject: [PATCH 03/15] Allow custom delimiter, header, columns to TabularFile To be used in future commits. --- augur/io/tabular_file.py | 64 +++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/augur/io/tabular_file.py b/augur/io/tabular_file.py index be1771fd7..b95bb7553 100644 --- a/augur/io/tabular_file.py +++ b/augur/io/tabular_file.py @@ -13,7 +13,7 @@ class InvalidDelimiter(Exception): class TabularFile(File): - """Represents a tabular file with a delimiter and required header line. + """Represents a tabular file with a delimiter and optional header line. A pandas DataFrame can provide these properties, however this class is more suitable for large files as it does not store the rows in memory. @@ -28,31 +28,59 @@ class TabularFile(File): delimiter: str """Inferred delimiter of file.""" - def __init__(self, path: str, delimiters: Sequence[str] = None): + header: bool + """Whether the first of the file represents column names.""" + + def __init__(self, path: str, delimiter: str = None, delimiters: Sequence[str] = None, + header: bool = True, columns: Sequence[str] = None): """ Parameters ---------- path Path of the tabular file. + delimiter + Use this as the delimiter. delimiters Possible delimiters to use, in order of precedence. + header + If true, the first line will be used as column names and not a row of data. + columns + If set, this will be used for column names. 
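            (Illustrative aside, not part of the patch: with these options a
            headerless file can be described explicitly, e.g. for a hypothetical
            two-column priorities file:

                TabularFile("priorities.tsv", delimiter="\t", header=False,
                            columns=["strain", "priority"])

            Passing both delimiter and delimiters, or header=True together with
            explicit columns, raises a ValueError per the checks below.)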
""" super().__init__(path) - if delimiters is None: - delimiters = DEFAULT_DELIMITERS - - # Infer the dialect. - self.delimiter = get_delimiter(self.path, delimiters) - - # Infer the column names. - with self.open() as f: - reader = csv.reader(f, delimiter=self.delimiter) - try: - self.columns = next(reader) - except StopIteration: - raise AugurError(f"{self.path}: Expected a header row but it is empty.") - + if delimiter: + # Delimiter is given as an argument. + if delimiters: + raise ValueError("At most one of delimiter and delimiters can be set.") + self.delimiter = delimiter + else: + # Delimiter is inferred from possible delimiters. + if not delimiters: + delimiters = DEFAULT_DELIMITERS + + # Infer the dialect. + self.delimiter = get_delimiter(self.path, delimiters) + + if header: + # Infer column names from the header. + if columns: + raise ValueError("Tabular file must have either a header row or column names specified, but not both.") + else: + with self.open() as f: + reader = csv.reader(f, delimiter=self.delimiter) + try: + self.header = True + self.columns = next(reader) + except StopIteration: + raise AugurError(f"{self.path}: Expected a header row but it is empty.") + else: + # Column names should be given as an argument. + if not columns: + raise ValueError("Tabular file must have either a header row or column names specified.") + else: + self.header = False + self.columns = columns def rows(self, strict: bool = True): """Yield rows in a dictionary format. Empty lines are ignored. @@ -65,8 +93,8 @@ def rows(self, strict: bool = True): with self.open() as f: reader = csv.DictReader(f, delimiter=self.delimiter, fieldnames=self.columns, restkey=None, restval=None) - # Skip the header row. - next(reader) + if self.header: + next(reader) # NOTE: Empty lines are ignored by csv.DictReader. # From e5348e86f8893e72bc9c1efbbd6219b8c41472a7 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 30 Dec 2022 20:41:29 -0800 Subject: [PATCH 04/15] io/sqlite3: Add methods to manage primary indexes This serves multiple purposes: 1. Detect duplicates. 2. Speed up strain-based queries. 3. Provide a marker for a single primary index column (akin to a pandas DataFrame index column). Also adds DuplicateError, a new exception class to expose duplicates for custom error messages. This will be used in a future commit by a new implementation of augur filter. --- augur/io/sqlite3.py | 57 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/augur/io/sqlite3.py b/augur/io/sqlite3.py index 71f989c51..217afd43b 100644 --- a/augur/io/sqlite3.py +++ b/augur/io/sqlite3.py @@ -9,6 +9,11 @@ from .file import open_file +class DuplicateError(Exception): + def __init__(self, duplicated_values: Iterable): + self.duplicated_values = duplicated_values + + class Sqlite3Database(AbstractContextManager): """Represents a SQLite3 database. @@ -117,6 +122,58 @@ def insert(self, table: str, columns: Iterable[str], rows: Iterable[Dict[str, An except sqlite3.ProgrammingError as e: raise Exception("Failed to insert rows.") from e + def create_primary_index(self, table: str, column: str): + """Create a primary index in a table with a unique constraint. + + This defines a "primary index" as a singular index per table that can be added after table creation. + + If SQLite allowed setting PRIMARY KEYs after table creation, this would not be necessary. + + Avoid setting PRIMARY KEYs before inserting data for optimal data insertion speeds¹. 
+ ¹ https://stackoverflow.com/q/1711631/4410590 + """ + self._assert_table_existence(table) + index = f'primary_index_{table}' + try: + self.connection.execute(f""" + CREATE UNIQUE INDEX {sanitize_identifier(index)} + ON {sanitize_identifier(table)} ({sanitize_identifier(column)}) + """) + except sqlite3.IntegrityError: + results = self.connection.execute(f""" + SELECT + {sanitize_identifier(column)} as id, + COUNT(*) AS count + FROM {table} + GROUP BY {sanitize_identifier(column)} + HAVING count > 1 + """) + duplicates = (row['id'] for row in results) + raise DuplicateError(duplicates) + + def get_primary_index(self, table: str): + """Get the primary index of a table.""" + + index = f'primary_index_{table}' + + results = self.connection.execute(f""" + SELECT + ii.name + FROM sqlite_master AS m, + pragma_index_list(m.name) AS il, + pragma_index_info(il.name) AS ii + WHERE + m.type = 'table' + AND m.tbl_name = '{table}' + AND il.name = '{index}'; + """) + columns = [result[0] for result in results] + + # Check that there is exactly one primary index. + assert len(columns) == 1 + + return columns[0] + def query_to_file(self, query: str, path: str, header: bool = True, delimiter: str = '\t'): """Query the database and write results to a tabular file. From 0c7133421dfd82813d79a12bb6cce40828a289ea Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 10 Mar 2023 10:52:56 -0800 Subject: [PATCH 05/15] io/sqlite3: Add debugging function This is unused but can come in handy when debugging a query. --- augur/io/sqlite3.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/augur/io/sqlite3.py b/augur/io/sqlite3.py index 217afd43b..47713dde3 100644 --- a/augur/io/sqlite3.py +++ b/augur/io/sqlite3.py @@ -1,4 +1,5 @@ import csv +import pandas as pd import random import re import string @@ -174,6 +175,12 @@ def get_primary_index(self, table: str): return columns[0] + def show_query(self, query: str, nrows=10): + """Print the first nrows of query results.""" + # Use pandas for a nicely formatted output. + df_chunks = pd.read_sql_query(query, self.connection, chunksize=nrows) + print(next(df_chunks)) + def query_to_file(self, query: str, path: str, header: bool = True, delimiter: str = '\t'): """Query the database and write results to a tabular file. From 22b80543f4e37608c700fcb01925c79705be5f2b Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 10 Mar 2023 10:54:10 -0800 Subject: [PATCH 06/15] dates: Allow ambiguity resolution method in get_numerical_date_from_value() This will be used in a future commit by a new implementation of augur filter. 
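As an illustrative sketch of the intended behavior (the function and parameter
names come from the diff below; the example values are hypothetical):

    get_numerical_date_from_value("2020-01-XX", fmt="%Y-%m-%d", ambiguity_resolver="min")
    # numeric date for 2020-01-01, the earliest date in the ambiguous range

    get_numerical_date_from_value("2020-01-XX", fmt="%Y-%m-%d", ambiguity_resolver="max")
    # numeric date for 2020-01-31, the latest date in the ambiguous range

    get_numerical_date_from_value("2020-01-XX", fmt="%Y-%m-%d")
    # default 'both': the [min, max] pair, matching the previous behavior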
--- augur/dates/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/augur/dates/__init__.py b/augur/dates/__init__.py index 119954027..8517479bb 100644 --- a/augur/dates/__init__.py +++ b/augur/dates/__init__.py @@ -107,7 +107,7 @@ def is_date_ambiguous(date, ambiguous_by): "X" in day and ambiguous_by in ("any", "day") )) -def get_numerical_date_from_value(value, fmt=None, min_max_year=None): +def get_numerical_date_from_value(value, fmt=None, min_max_year=None, ambiguity_resolver='both'): value = str(value) if re.match(r'^-*\d+\.\d+$', value): # numeric date which can be negative @@ -120,7 +120,15 @@ def get_numerical_date_from_value(value, fmt=None, min_max_year=None): ambig_date = AmbiguousDate(value, fmt=fmt).range(min_max_year=min_max_year) except InvalidDate as error: raise AugurError(str(error)) from error - return [treetime.utils.numeric_date(d) for d in ambig_date] + ambig_date_numeric = [treetime.utils.numeric_date(d) for d in ambig_date] + if ambiguity_resolver == 'both': + return ambig_date_numeric + elif ambiguity_resolver == 'min': + return ambig_date_numeric[0] + elif ambiguity_resolver == 'max': + return ambig_date_numeric[1] + else: + raise Exception(f"Invalid value for ambiguity_resolver: {ambiguity_resolver!r}.") try: return treetime.utils.numeric_date(datetime.datetime.strptime(value, fmt)) except: From 9f08218610d4e324a13d1720ad0b71940bb59ec6 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Tue, 28 Mar 2023 15:43:23 -0700 Subject: [PATCH 07/15] filter/run: Split into smaller functions by using shared variables in constants.py These variables are temporary to bridge the transition from using shared variables to a shared database. They will be removed in a future commit. --- augur/filter/_run.py | 222 +++------------------ augur/filter/constants.py | 10 + augur/filter/include_exclude_rules.py | 12 +- augur/filter/io.py | 128 +++++++++++- augur/filter/report.py | 69 +++++++ docs/api/developer/augur.filter.report.rst | 7 + docs/api/developer/augur.filter.rst | 1 + 7 files changed, 243 insertions(+), 206 deletions(-) create mode 100644 augur/filter/report.py create mode 100644 docs/api/developer/augur.filter.report.rst diff --git a/augur/filter/_run.py b/augur/filter/_run.py index cca387f13..172bd902a 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -1,93 +1,29 @@ from collections import defaultdict import csv import itertools -import json import numpy as np -import os -import pandas as pd -from tempfile import NamedTemporaryFile from augur.errors import AugurError -from augur.index import ( - index_sequences, - index_vcf, - ID_COLUMN as SEQUENCE_INDEX_ID_COLUMN, - DELIMITER as SEQUENCE_INDEX_DELIMITER, -) -from augur.io.file import PANDAS_READ_CSV_OPTIONS, open_file +from augur.io.file import open_file from augur.io.metadata import InvalidDelimiter, Metadata, read_metadata -from augur.io.sequences import read_sequences, write_sequences from augur.io.print import print_err -from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf -from augur.types import EmptyOutputReportingMethod -from . import include_exclude_rules -from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs +from augur.io.tabular_file import InvalidDelimiter +from . 
import constants +from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs, import_sequence_index, read_and_output_sequences from .include_exclude_rules import apply_filters, construct_filters +from .report import print_report from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling, get_weighted_group_sizes def run(args): - # Determine whether the sequence index exists or whether should be - # generated. We need to generate an index if the input sequences are in a - # VCF, if sequence output has been requested (so we can filter strains by - # sequences that are present), or if any other sequence-based filters have - # been requested. - sequence_strains = None - sequence_index_path = args.sequence_index - build_sequence_index = False - is_vcf = filename_is_vcf(args.sequences) - - # Don't build sequence index with --exclude-all since the only way to add - # strains back in with this flag are the `--include` or `--include-where` - # options, so we know we don't need a sequence index to apply any additional - # filters. - if sequence_index_path is None and args.sequences and not args.exclude_all: - build_sequence_index = True - - if build_sequence_index: - # Generate the sequence index on the fly, for backwards compatibility - # with older workflows that don't generate the index ahead of time. - # Create a temporary index using a random filename to avoid collisions - # between multiple filter commands. - with NamedTemporaryFile(delete=False) as sequence_index_file: - sequence_index_path = sequence_index_file.name - - print_err( - "Note: You did not provide a sequence index, so Augur will generate one.", - "You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`." - ) - - if is_vcf: - index_vcf(args.sequences, sequence_index_path) - else: - index_sequences(args.sequences, sequence_index_path) - - # Load the sequence index, if a path exists. - sequence_index = None - if sequence_index_path: - sequence_index = pd.read_csv( - sequence_index_path, - sep=SEQUENCE_INDEX_DELIMITER, - index_col=SEQUENCE_INDEX_ID_COLUMN, - dtype={SEQUENCE_INDEX_ID_COLUMN: "string"}, - **PANDAS_READ_CSV_OPTIONS, - ) - - # Remove temporary index file, if it exists. - if build_sequence_index: - os.unlink(sequence_index_path) - - sequence_strains = set(sequence_index.index.values) + import_sequence_index(args) ##################################### #Filtering steps ##################################### # Setup filters. - exclude_by, include_by = construct_filters( - args, - sequence_index, - ) + exclude_by, include_by = construct_filters(args) # Setup grouping. We handle the following major use cases: # @@ -154,13 +90,13 @@ def run(args): # Load metadata. Metadata are the source of truth for which sequences we # want to keep in filtered output. 
- metadata_strains = set() - valid_strains = set() # TODO: rename this more clearly - all_sequences_to_include = set() - filter_counts = defaultdict(int) + constants.metadata_strains = set() + constants.valid_strains = set() # TODO: rename this more clearly + constants.all_sequences_to_include = set() + constants.filter_counts = defaultdict(int) try: - metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) + metadata_object = Metadata(args.metadata, args.metadata_id_columns, delimiters=args.metadata_delimiters) except InvalidDelimiter: raise AugurError( f"Could not determine the delimiter of {args.metadata!r}. " @@ -180,14 +116,14 @@ def run(args): for metadata in metadata_reader: duplicate_strains = ( set(metadata.index[metadata.index.duplicated()]) | - (set(metadata.index) & metadata_strains) + (set(metadata.index) & constants.metadata_strains) ) if len(duplicate_strains) > 0: cleanup_outputs(args) raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains))) # Maintain list of all strains seen. - metadata_strains.update(set(metadata.index.values)) + constants.metadata_strains.update(set(metadata.index.values)) # Filter metadata. seq_keep, sequences_to_filter, sequences_to_include = apply_filters( @@ -195,7 +131,7 @@ def run(args): exclude_by, include_by, ) - valid_strains.update(seq_keep) + constants.valid_strains.update(seq_keep) # Track distinct strains to include, so we can write their # corresponding metadata, strains, or sequences later, as needed. @@ -203,13 +139,13 @@ def run(args): record["strain"] for record in sequences_to_include } - all_sequences_to_include.update(distinct_force_included_strains) + constants.all_sequences_to_include.update(distinct_force_included_strains) # Track reasons for filtered or force-included strains, so we can # report total numbers filtered and included at the end. Optionally, # write out these reasons to a log file. for filtered_strain in itertools.chain(sequences_to_filter, sequences_to_include): - filter_counts[(filtered_strain["filter"], filtered_strain["kwargs"])] += 1 + constants.filter_counts[(filtered_strain["filter"], filtered_strain["kwargs"])] += 1 # Log the names of strains that were filtered or force-included, # so we can properly account for each strain (e.g., including @@ -307,9 +243,9 @@ def run(args): # have passed filters. metadata_reader = read_metadata( args.metadata, - delimiters=args.metadata_delimiters, + delimiters=[metadata_object.delimiter], columns=useful_metadata_columns, - id_columns=args.metadata_id_columns, + id_columns=[metadata_object.id_column], chunk_size=args.metadata_chunk_size, dtype="string", ) @@ -318,11 +254,11 @@ def run(args): # metadata a second time. TODO: We could store these in memory # during the first pass, but we want to minimize overall memory # usage at the moment. - seq_keep = set(metadata.index.values) & valid_strains + seq_keep = set(metadata.index.values) & constants.valid_strains # Prevent force-included strains from being considered in this # second pass, as in the first pass. - seq_keep = seq_keep - all_sequences_to_include + seq_keep = seq_keep - constants.all_sequences_to_include group_by_strain = get_groups_for_subsampling( seq_keep, @@ -339,7 +275,7 @@ def run(args): # If we have any records in queues, we have grouped results and need to # stream the highest priority records to the requested outputs. 
- num_excluded_subsamp = 0 + constants.num_excluded_subsamp = 0 if queues_by_group: # Populate the set of strains to keep from the records in queues. subsampled_strains = set() @@ -355,8 +291,8 @@ def run(args): # Count and optionally log strains that were not included due to # subsampling. - strains_filtered_by_subsampling = valid_strains - subsampled_strains - num_excluded_subsamp = len(strains_filtered_by_subsampling) + strains_filtered_by_subsampling = constants.valid_strains - subsampled_strains + constants.num_excluded_subsamp = len(strains_filtered_by_subsampling) if output_log_writer: for strain in strains_filtered_by_subsampling: output_log_writer.writerow({ @@ -365,115 +301,13 @@ def run(args): "kwargs": "", }) - valid_strains = subsampled_strains - - # Force inclusion of specific strains after filtering and subsampling. - valid_strains = valid_strains | all_sequences_to_include - - # Write output starting with sequences, if they've been requested. It is - # possible for the input sequences and sequence index to be out of sync - # (e.g., the index is a superset of the given sequences input), so we need - # to update the set of strains to keep based on which strains are actually - # available. - if is_vcf: - if args.output: - # Get the samples to be deleted, not to keep, for VCF - dropped_samps = list(sequence_strains - valid_strains) - write_vcf(args.sequences, args.output, dropped_samps) - elif args.sequences: - sequences = read_sequences(args.sequences) - - # If the user requested sequence output, stream to disk all sequences - # that passed all filters to avoid reading sequences into memory first. - # Even if we aren't emitting sequences, we track the observed strain - # names in the sequence file as part of the single pass to allow - # comparison with the provided sequence index. - if args.output: - observed_sequence_strains = set() - with open_file(args.output, "wt") as output_handle: - for sequence in sequences: - observed_sequence_strains.add(sequence.id) - - if sequence.id in valid_strains: - write_sequences(sequence, output_handle, 'fasta') - else: - observed_sequence_strains = {sequence.id for sequence in sequences} - - if sequence_strains != observed_sequence_strains: - # Warn the user if the expected strains from the sequence index are - # not a superset of the observed strains. - if sequence_strains is not None and observed_sequence_strains > sequence_strains: - print_err( - "WARNING: The sequence index is out of sync with the provided sequences.", - "Metadata and strain output may not match sequence output." - ) + constants.valid_strains = subsampled_strains - # Update the set of available sequence strains. - sequence_strains = observed_sequence_strains + read_and_output_sequences(args) if args.output_metadata or args.output_strains: write_metadata_based_outputs(args.metadata, args.metadata_delimiters, args.metadata_id_columns, args.output_metadata, - args.output_strains, valid_strains) - - # Calculate the number of strains that don't exist in either metadata or - # sequences. - num_excluded_by_lack_of_metadata = 0 - if sequence_strains: - num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains) - - - # Calculate the number of strains passed and filtered. 
- total_strains_passed = len(valid_strains) - total_strains_filtered = len(metadata_strains) + num_excluded_by_lack_of_metadata - total_strains_passed - - print_err(f"{total_strains_filtered} {'strain was' if total_strains_filtered == 1 else 'strains were'} dropped during filtering") - - if num_excluded_by_lack_of_metadata: - print_err(f"\t{num_excluded_by_lack_of_metadata} had no metadata") - - report_template_by_filter_name = { - include_exclude_rules.filter_by_sequence_index.__name__: "{count} had no sequence data", - include_exclude_rules.filter_by_exclude_all.__name__: "{count} {were} dropped by `--exclude-all`", - include_exclude_rules.filter_by_exclude.__name__: "{count} {were} dropped because {they} {were} in {exclude_file}", - include_exclude_rules.filter_by_exclude_where.__name__: "{count} {were} dropped because of '{exclude_where}'", - include_exclude_rules.filter_by_query.__name__: "{count} {were} filtered out by the query: \"{query}\"", - include_exclude_rules.filter_by_ambiguous_date.__name__: "{count} {were} dropped because of their ambiguous date in {ambiguity}", - include_exclude_rules.filter_by_min_date.__name__: "{count} {were} dropped because {they} {were} earlier than {min_date} or missing a date", - include_exclude_rules.filter_by_max_date.__name__: "{count} {were} dropped because {they} {were} later than {max_date} or missing a date", - include_exclude_rules.filter_by_min_length.__name__: "{count} {were} dropped because {they} {were} shorter than the minimum length of {min_length}bp when only counting standard nucleotide characters A, C, G, or T (case-insensitive)", - include_exclude_rules.filter_by_max_length.__name__: "{count} {were} dropped because {they} {were} longer than the maximum length of {max_length}bp when only counting standard nucleotide characters A, C, G, or T (case-insensitive)", - include_exclude_rules.filter_by_non_nucleotide.__name__: "{count} {were} dropped because {they} had non-nucleotide characters", - include_exclude_rules.skip_group_by_with_ambiguous_year.__name__: "{count} {were} dropped during grouping due to ambiguous year information", - include_exclude_rules.skip_group_by_with_ambiguous_month.__name__: "{count} {were} dropped during grouping due to ambiguous month information", - include_exclude_rules.skip_group_by_with_ambiguous_day.__name__: "{count} {were} dropped during grouping due to ambiguous day information", - include_exclude_rules.force_include_strains.__name__: "{count} {were} added back because {they} {were} in {include_file}", - include_exclude_rules.force_include_where.__name__: "{count} {were} added back because of '{include_where}'", - } - for (filter_name, filter_kwargs), count in filter_counts.items(): - if filter_kwargs: - parameters = dict(json.loads(filter_kwargs)) - else: - parameters = {} - - parameters["count"] = count - parameters["were"] = "was" if count == 1 else "were" - parameters["they"] = "it" if count == 1 else "they" - print_err("\t" + report_template_by_filter_name[filter_name].format(**parameters)) - - if (group_by and args.sequences_per_group) or args.subsample_max_sequences: - seed_txt = ", using seed {}".format(args.subsample_seed) if args.subsample_seed else "" - print_err(f"\t{num_excluded_subsamp} {'was' if num_excluded_subsamp == 1 else 'were'} dropped because of subsampling criteria{seed_txt}") - - if total_strains_passed == 0: - empty_results_message = "All samples have been dropped! Check filter rules and metadata file format." 
- if args.empty_output_reporting is EmptyOutputReportingMethod.ERROR: - raise AugurError(empty_results_message) - elif args.empty_output_reporting is EmptyOutputReportingMethod.WARN: - print_err(f"WARNING: {empty_results_message}") - elif args.empty_output_reporting is EmptyOutputReportingMethod.SILENT: - pass - else: - raise ValueError(f"Encountered unhandled --empty-output-reporting method {args.empty_output_reporting!r}") + args.output_strains, constants.valid_strains) - print_err(f"{total_strains_passed} {'strain' if total_strains_passed == 1 else 'strains'} passed all filters") + print_report(args) diff --git a/augur/filter/constants.py b/augur/filter/constants.py index 139ff4092..a028e4d17 100644 --- a/augur/filter/constants.py +++ b/augur/filter/constants.py @@ -1,3 +1,13 @@ +# Shared variables set at run time. +# TODO: Remove these with the database implementation. +sequence_index = None +sequence_strains = None +metadata_strains = None +valid_strains = None +all_sequences_to_include = None +filter_counts = None +num_excluded_subsamp = None + # Generated date columns. DATE_YEAR_COLUMN = 'year' DATE_MONTH_COLUMN = 'month' diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py index 05e58481b..35ff21c1d 100644 --- a/augur/filter/include_exclude_rules.py +++ b/augur/filter/include_exclude_rules.py @@ -589,7 +589,7 @@ def force_include_where(metadata, include_where) -> FilterFunctionReturn: return included -def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[FilterOption]]: +def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: """Construct lists of filters and inclusion criteria based on user-provided arguments. @@ -597,8 +597,6 @@ def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[Fi ---------- args : argparse.Namespace Command line arguments provided by the user. - sequence_index : pandas.DataFrame - Sequence index for the provided arguments. """ exclude_by: List[FilterOption] = [] include_by: List[FilterOption] = [] @@ -629,11 +627,11 @@ def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[Fi exclude_by.append((filter_by_exclude_all, {})) # Filter by sequence index. 
- if sequence_index is not None: + if constants.sequence_index is not None: exclude_by.append(( filter_by_sequence_index, { - "sequence_index": sequence_index, + "sequence_index": constants.sequence_index, }, )) @@ -705,7 +703,7 @@ def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[Fi exclude_by.append(( filter_by_min_length, { - "sequence_index": sequence_index, + "sequence_index": constants.sequence_index, "min_length": args.min_length, } )) @@ -726,7 +724,7 @@ def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[Fi exclude_by.append(( filter_by_non_nucleotide, { - "sequence_index": sequence_index, + "sequence_index": constants.sequence_index, } )) diff --git a/augur/filter/io.py b/augur/filter/io.py index 0e03b12e5..5fe145743 100644 --- a/augur/filter/io.py +++ b/augur/filter/io.py @@ -6,14 +6,24 @@ from textwrap import dedent from typing import Sequence, Set import numpy as np +import pandas as pd +from tempfile import NamedTemporaryFile from collections import defaultdict from xopen import xopen from augur.errors import AugurError -from augur.io.file import open_file +from augur.index import ( + index_sequences, + index_vcf, + ID_COLUMN as SEQUENCE_INDEX_ID_COLUMN, + DELIMITER as SEQUENCE_INDEX_DELIMITER, +) +from augur.io.file import PANDAS_READ_CSV_OPTIONS, open_file from augur.io.metadata import Metadata, METADATA_DATE_COLUMN from augur.io.print import print_err -from .constants import GROUP_BY_GENERATED_COLUMNS +from augur.io.sequences import read_sequences, write_sequences +from augur.io.vcf import is_vcf, write_vcf +from . import constants from .include_exclude_rules import extract_variables, parse_filter_query @@ -29,12 +39,12 @@ def get_useful_metadata_columns(args: Namespace, id_column: str, all_columns: Se if (args.exclude_ambiguous_dates_by or args.min_date or args.max_date - or (args.group_by and GROUP_BY_GENERATED_COLUMNS.intersection(args.group_by))): + or (args.group_by and constants.GROUP_BY_GENERATED_COLUMNS.intersection(args.group_by))): columns.add(METADATA_DATE_COLUMN) if args.group_by: group_by_set = set(args.group_by) - requested_generated_columns = group_by_set & GROUP_BY_GENERATED_COLUMNS + requested_generated_columns = group_by_set & constants.GROUP_BY_GENERATED_COLUMNS # Add columns used for grouping. columns.update(group_by_set - requested_generated_columns) @@ -103,7 +113,7 @@ def write_metadata_based_outputs(input_metadata_path: str, delimiters: Sequence[ Write output metadata and/or strains file given input metadata information and a set of IDs to write. """ - input_metadata = Metadata(input_metadata_path, delimiters, id_columns) + input_metadata = Metadata(input_metadata_path, id_columns, delimiters=delimiters) # Handle all outputs with one pass of metadata. This requires using # conditionals both outside of and inside the loop through metadata rows. @@ -158,6 +168,114 @@ def column_type_pair(input: str): return (column, dtype) +def import_sequence_index(args): + # Determine whether the sequence index exists or whether should be + # generated. We need to generate an index if the input sequences are in a + # VCF, if sequence output has been requested (so we can filter strains by + # sequences that are present), or if any other sequence-based filters have + # been requested. 
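    # Illustrative note, not part of this patch: the on-the-fly indexing below can
    # be avoided by precomputing the index once and reusing it across runs,
    # assuming the standard augur CLI options:
    #
    #     augur index --sequences sequences.fasta --output sequence_index.tsv
    #     augur filter --sequence-index sequence_index.tsv ...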
+ sequence_index_path = args.sequence_index + build_sequence_index = False + + # Don't build sequence index with --exclude-all since the only way to add + # strains back in with this flag are the `--include` or `--include-where` + # options, so we know we don't need a sequence index to apply any additional + # filters. + if sequence_index_path is None and args.sequences and not args.exclude_all: + build_sequence_index = True + + if build_sequence_index: + sequence_index_path = _generate_sequence_index(args.sequences) + + # Load the sequence index, if a path exists. + if sequence_index_path: + constants.sequence_index = pd.read_csv( + sequence_index_path, + sep=SEQUENCE_INDEX_DELIMITER, + index_col=SEQUENCE_INDEX_ID_COLUMN, + dtype={SEQUENCE_INDEX_ID_COLUMN: "string"}, + **PANDAS_READ_CSV_OPTIONS, + ) + + # Remove temporary index file, if it exists. + if build_sequence_index: + os.unlink(sequence_index_path) + + constants.sequence_strains = set(constants.sequence_index.index.values) + + +def _generate_sequence_index(sequences_file): + """Generate a sequence index file. + """ + # Generate the sequence index on the fly, for backwards compatibility + # with older workflows that don't generate the index ahead of time. + # Create a temporary index using a random filename to avoid collisions + # between multiple filter commands. + with NamedTemporaryFile(delete=False) as sequence_index_file: + sequence_index_path = sequence_index_file.name + + print_err( + "Note: You did not provide a sequence index, so Augur will generate one.", + "You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`." + ) + + # FIXME: call a function in index_sequences which already handles VCF vs. FASTA + if is_vcf(sequences_file): + index_vcf(sequences_file, sequence_index_path) + else: + index_sequences(sequences_file, sequence_index_path) + + return sequence_index_path + + +def read_and_output_sequences(args): + """Read sequences and output all that passed filtering. + """ + # Force inclusion of specific strains after filtering and subsampling. + constants.valid_strains = constants.valid_strains | constants.all_sequences_to_include + + # Write output starting with sequences, if they've been requested. It is + # possible for the input sequences and sequence index to be out of sync + # (e.g., the index is a superset of the given sequences input), so we need + # to update the set of strains to keep based on which strains are actually + # available. + if is_vcf(args.sequences): + if args.output: + # Get the samples to be deleted, not to keep, for VCF + dropped_samps = list(constants.sequence_strains - constants.valid_strains) + write_vcf(args.sequences, args.output, dropped_samps) + elif args.sequences: + sequences = read_sequences(args.sequences) + + # If the user requested sequence output, stream to disk all sequences + # that passed all filters to avoid reading sequences into memory first. + # Even if we aren't emitting sequences, we track the observed strain + # names in the sequence file as part of the single pass to allow + # comparison with the provided sequence index. 
+ if args.output: + observed_sequence_strains = set() + with open_file(args.output, "wt") as output_handle: + for sequence in sequences: + observed_sequence_strains.add(sequence.id) + + if sequence.id in constants.valid_strains: + write_sequences(sequence, output_handle, 'fasta') + else: + observed_sequence_strains = {sequence.id for sequence in sequences} + + if constants.sequence_strains != observed_sequence_strains: + # Warn the user if the expected strains from the sequence index are + # not a superset of the observed strains. + if constants.sequence_strains is not None and observed_sequence_strains > constants.sequence_strains: + print_err( + "WARNING: The sequence index is out of sync with the provided sequences.", + "Metadata and strain output may not match sequence output." + ) + + # Update the set of available sequence strains. + constants.sequence_strains = observed_sequence_strains + + def cleanup_outputs(args): """Remove output files. Useful when terminating midway through a loop of metadata chunks.""" if args.output: diff --git a/augur/filter/report.py b/augur/filter/report.py new file mode 100644 index 000000000..125f9495b --- /dev/null +++ b/augur/filter/report.py @@ -0,0 +1,69 @@ +import json +from augur.errors import AugurError +from augur.io.print import print_err +from augur.types import EmptyOutputReportingMethod +from . import constants, include_exclude_rules + + +def print_report(args): + """Print a report of how many strains were dropped and reasoning.""" + # Calculate the number of strains that don't exist in either metadata or + # sequences. + num_excluded_by_lack_of_metadata = 0 + if constants.sequence_strains: + num_excluded_by_lack_of_metadata = len(constants.sequence_strains - constants.metadata_strains) + + # Calculate the number of strains passed and filtered. 
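    # Illustrative numbers, not from the patch: with 100 metadata strains, 5
    # sequence-only strains lacking metadata, and 60 strains passing, the report
    # below counts 100 + 5 - 60 = 45 strains as dropped.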
+ total_strains_passed = len(constants.valid_strains) + total_strains_filtered = len(constants.metadata_strains) + num_excluded_by_lack_of_metadata - total_strains_passed + + print_err(f"{total_strains_filtered} {'strain was' if total_strains_filtered == 1 else 'strains were'} dropped during filtering") + + if num_excluded_by_lack_of_metadata: + print_err(f"\t{num_excluded_by_lack_of_metadata} had no metadata") + + report_template_by_filter_name = { + include_exclude_rules.filter_by_sequence_index.__name__: "{count} had no sequence data", + include_exclude_rules.filter_by_exclude_all.__name__: "{count} {were} dropped by `--exclude-all`", + include_exclude_rules.filter_by_exclude.__name__: "{count} {were} dropped because {they} {were} in {exclude_file}", + include_exclude_rules.filter_by_exclude_where.__name__: "{count} {were} dropped because of '{exclude_where}'", + include_exclude_rules.filter_by_query.__name__: "{count} {were} filtered out by the query: \"{query}\"", + include_exclude_rules.filter_by_ambiguous_date.__name__: "{count} {were} dropped because of their ambiguous date in {ambiguity}", + include_exclude_rules.filter_by_min_date.__name__: "{count} {were} dropped because {they} {were} earlier than {min_date} or missing a date", + include_exclude_rules.filter_by_max_date.__name__: "{count} {were} dropped because {they} {were} later than {max_date} or missing a date", + include_exclude_rules.filter_by_min_length.__name__: "{count} {were} dropped because {they} {were} shorter than the minimum length of {min_length}bp when only counting standard nucleotide characters A, C, G, or T (case-insensitive)", + include_exclude_rules.filter_by_max_length.__name__: "{count} {were} dropped because {they} {were} longer than the maximum length of {max_length}bp when only counting standard nucleotide characters A, C, G, or T (case-insensitive)", + include_exclude_rules.filter_by_non_nucleotide.__name__: "{count} {were} dropped because {they} had non-nucleotide characters", + include_exclude_rules.skip_group_by_with_ambiguous_year.__name__: "{count} {were} dropped during grouping due to ambiguous year information", + include_exclude_rules.skip_group_by_with_ambiguous_month.__name__: "{count} {were} dropped during grouping due to ambiguous month information", + include_exclude_rules.skip_group_by_with_ambiguous_day.__name__: "{count} {were} dropped during grouping due to ambiguous day information", + include_exclude_rules.force_include_strains.__name__: "{count} {were} added back because {they} {were} in {include_file}", + include_exclude_rules.force_include_where.__name__: "{count} {were} added back because of '{include_where}'", + } + for (filter_name, filter_kwargs), count in constants.filter_counts.items(): + if filter_kwargs: + parameters = dict(json.loads(filter_kwargs)) + else: + parameters = {} + + parameters["count"] = count + parameters["were"] = "was" if count == 1 else "were" + parameters["they"] = "it" if count == 1 else "they" + print_err("\t" + report_template_by_filter_name[filter_name].format(**parameters)) + + if (args.group_by and args.sequences_per_group) or args.subsample_max_sequences: + seed_txt = ", using seed {}".format(args.subsample_seed) if args.subsample_seed else "" + print_err(f"\t{constants.num_excluded_subsamp} {'was' if constants.num_excluded_subsamp == 1 else 'were'} dropped because of subsampling criteria{seed_txt}") + + if total_strains_passed == 0: + empty_results_message = "All samples have been dropped! Check filter rules and metadata file format." 
+ if args.empty_output_reporting is EmptyOutputReportingMethod.ERROR: + raise AugurError(empty_results_message) + elif args.empty_output_reporting is EmptyOutputReportingMethod.WARN: + print_err(f"WARNING: {empty_results_message}") + elif args.empty_output_reporting is EmptyOutputReportingMethod.SILENT: + pass + else: + raise ValueError(f"Encountered unhandled --empty-output-reporting method {args.empty_output_reporting!r}") + + print_err(f"{total_strains_passed} {'strain' if total_strains_passed == 1 else 'strains'} passed all filters") diff --git a/docs/api/developer/augur.filter.report.rst b/docs/api/developer/augur.filter.report.rst new file mode 100644 index 000000000..534cfc13c --- /dev/null +++ b/docs/api/developer/augur.filter.report.rst @@ -0,0 +1,7 @@ +augur.filter.report module +========================== + +.. automodule:: augur.filter.report + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/developer/augur.filter.rst b/docs/api/developer/augur.filter.rst index c24ddbac3..59c2b7058 100644 --- a/docs/api/developer/augur.filter.rst +++ b/docs/api/developer/augur.filter.rst @@ -15,6 +15,7 @@ Submodules augur.filter.constants augur.filter.include_exclude_rules augur.filter.io + augur.filter.report augur.filter.subsample augur.filter.validate_arguments augur.filter.weights_file From 2e46b3a888f4a2a723bcfbb8aa38ca5392877bb5 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Tue, 28 Mar 2023 11:26:36 -0700 Subject: [PATCH 08/15] (1/3) filter: Rewrite using SQLite3 WARNING: This change won't work as-is. Broken out as a separate commit for easier review. Remove metadata input and examples in include_exclude_rules. These are not applicable in the SQLite-based implementation. --- augur/filter/include_exclude_rules.py | 272 +++----------------------- 1 file changed, 31 insertions(+), 241 deletions(-) diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py index 35ff21c1d..2f98afef9 100644 --- a/augur/filter/include_exclude_rules.py +++ b/augur/filter/include_exclude_rules.py @@ -34,46 +34,22 @@ FilterOption = Tuple[FilterFunction, FilterFunctionKwargs] -def filter_by_exclude_all(metadata) -> FilterFunctionReturn: +def filter_by_exclude_all() -> FilterFunctionReturn: """Exclude all strains regardless of the given metadata content. This is a placeholder function that can be called as part of a generalized loop through all possible functions. - - Parameters - ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) - >>> filter_by_exclude_all(metadata) - set() """ return set() -def filter_by_exclude(metadata, exclude_file) -> FilterFunctionReturn: - """Exclude the given set of strains from the given metadata. +def filter_by_exclude(exclude_file) -> FilterFunctionReturn: + """Exclude the given strains. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name exclude_file : str - Filename with strain names to exclude from the given metadata - - Examples - -------- - >>> import os - >>> from tempfile import NamedTemporaryFile - >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) - >>> with NamedTemporaryFile(delete=False) as exclude_file: - ... 
characters_written = exclude_file.write(b'strain1') - >>> filter_by_exclude(metadata, exclude_file.name) - {'strain2'} - >>> os.unlink(exclude_file.name) + Filename with strain names to exclude """ excluded_strains = read_strains(exclude_file) return set(metadata.index.values) - excluded_strains @@ -113,8 +89,8 @@ def parse_filter_query(query): return column, op, value -def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn: - """Exclude all strains from the given metadata that match the given exclusion query. +def filter_by_exclude_where(exclude_where) -> FilterFunctionReturn: + """Exclude all strains that match the given exclusion query. Unlike pandas query syntax, exclusion queries should follow the pattern of `"property=value"` or `"property!=value"`. Additionally, this filter treats @@ -123,26 +99,8 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn: Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name exclude_where : str Filter query used to exclude strains - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) - >>> filter_by_exclude_where(metadata, "region!=Europe") - {'strain2'} - >>> filter_by_exclude_where(metadata, "region=Europe") - {'strain1'} - >>> filter_by_exclude_where(metadata, "region=europe") - {'strain1'} - - If the column referenced in the given query does not exist, skip the filter. - - >>> sorted(filter_by_exclude_where(metadata, "missing_column=value")) - ['strain1', 'strain2'] - """ column, op, value = parse_filter_query(exclude_where) if column in metadata.columns: @@ -165,27 +123,16 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn: return filtered -def filter_by_query(metadata: pd.DataFrame, query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: +def filter_by_query(query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: """Filter metadata in the given pandas DataFrame with a query string and return the strain names that pass the filter. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name query : str Query string for the dataframe. column_types : str Dict mapping of data type - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) - >>> filter_by_query(metadata, "region == 'Africa'") - {'strain1'} - >>> filter_by_query(metadata, "region == 'North America'") - set() - """ # Create a copy to prevent modification of the original DataFrame. metadata_copy = metadata.copy() @@ -263,32 +210,15 @@ def _string_to_boolean(s: str): raise ValueError(f"Unable to convert {s!r} to a boolean value.") -def filter_by_ambiguous_date(metadata, date_column, ambiguity) -> FilterFunctionReturn: - """Filter metadata in the given pandas DataFrame where values in the given date - column have a given level of ambiguity. +def filter_by_ambiguous_date(date_column, ambiguity) -> FilterFunctionReturn: + """Filter where values in the given date column have a given level of ambiguity. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name date_column : str Column in the dataframe with dates. 
ambiguity : str - Level of date ambiguity to filter metadata by - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-XX"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) - >>> filter_by_ambiguous_date(metadata, date_column="date", ambiguity="any") - {'strain2'} - >>> sorted(filter_by_ambiguous_date(metadata, date_column="date", ambiguity="month")) - ['strain1', 'strain2'] - - If the requested date column does not exist, we quietly skip this filter. - - >>> sorted(filter_by_ambiguous_date(metadata, date_column="missing_column", ambiguity="any")) - ['strain1', 'strain2'] - + Level of date ambiguity to filter by """ if date_column in metadata.columns: date_is_ambiguous = metadata[date_column].apply( @@ -301,45 +231,30 @@ def filter_by_ambiguous_date(metadata, date_column, ambiguity) -> FilterFunction return filtered -def skip_group_by_with_ambiguous_year(metadata, date_column) -> FilterFunctionReturn: +def skip_group_by_with_ambiguous_year(date_column) -> FilterFunctionReturn: """Alias to filter_by_ambiguous_date for year. This is to have a named function available for the filter reason.""" - return filter_by_ambiguous_date(metadata, date_column, ambiguity="year") + return filter_by_ambiguous_date(date_column, ambiguity="year") -def skip_group_by_with_ambiguous_month(metadata, date_column) -> FilterFunctionReturn: +def skip_group_by_with_ambiguous_month(date_column) -> FilterFunctionReturn: """Alias to filter_by_ambiguous_date for month. This is to have a named function available for the filter reason.""" - return filter_by_ambiguous_date(metadata, date_column, ambiguity="month") + return filter_by_ambiguous_date(date_column, ambiguity="month") -def skip_group_by_with_ambiguous_day(metadata, date_column) -> FilterFunctionReturn: +def skip_group_by_with_ambiguous_day(date_column) -> FilterFunctionReturn: """Alias to filter_by_ambiguous_date for day. This is to have a named function available for the filter reason.""" - return filter_by_ambiguous_date(metadata, date_column, ambiguity="day") + return filter_by_ambiguous_date(date_column, ambiguity="day") -def filter_by_min_date(metadata, date_column, min_date) -> FilterFunctionReturn: - """Filter metadata by minimum date. +def filter_by_min_date(date_column, min_date) -> FilterFunctionReturn: + """Filter by minimum date. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name date_column : str Column in the dataframe with dates. min_date : float Minimum date - - Examples - -------- - >>> from augur.dates import numeric_date - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) - >>> filter_by_min_date(metadata, date_column="date", min_date=numeric_date("2020-01-02")) - {'strain2'} - - If the requested date column does not exist, we quietly skip this filter. - - >>> sorted(filter_by_min_date(metadata, date_column="missing_column", min_date=numeric_date("2020-01-02"))) - ['strain1', 'strain2'] - """ strains = set(metadata.index.values) @@ -357,30 +272,15 @@ def filter_by_min_date(metadata, date_column, min_date) -> FilterFunctionReturn: return filtered -def filter_by_max_date(metadata, date_column, max_date) -> FilterFunctionReturn: - """Filter metadata by maximum date. +def filter_by_max_date(date_column, max_date) -> FilterFunctionReturn: + """Filter by maximum date. 
Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name date_column : str Column in the dataframe with dates. max_date : float Maximum date - - Examples - -------- - >>> from augur.dates import numeric_date - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) - >>> filter_by_max_date(metadata, date_column="date", max_date=numeric_date("2020-01-01")) - {'strain1'} - - If the requested date column does not exist, we quietly skip this filter. - - >>> sorted(filter_by_max_date(metadata, date_column="missing_column", max_date=numeric_date("2020-01-01"))) - ['strain1', 'strain2'] - """ strains = set(metadata.index.values) @@ -398,25 +298,15 @@ def filter_by_max_date(metadata, date_column, max_date) -> FilterFunctionReturn: return filtered -def filter_by_sequence_index(metadata, sequence_index) -> FilterFunctionReturn: - """Filter metadata by presence of corresponding entries in a given sequence +def filter_by_sequence_index(sequence_index) -> FilterFunctionReturn: + """Filter by presence of corresponding entries in a given sequence index. This filter effectively intersects the strain ids in the metadata and sequence index. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name sequence_index : pandas.DataFrame Sequence index - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) - >>> sequence_index = pd.DataFrame([{"strain": "strain1", "ACGT": 28000}]).set_index("strain") - >>> filter_by_sequence_index(metadata, sequence_index) - {'strain1'} - """ metadata_strains = set(metadata.index.values) sequence_index_strains = set(sequence_index.index.values) @@ -425,30 +315,14 @@ def filter_by_sequence_index(metadata, sequence_index) -> FilterFunctionReturn: def filter_by_min_length(metadata, sequence_index, min_length) -> FilterFunctionReturn: - """Filter metadata by sequence length from a given sequence index. + """Filter by sequence length from a given sequence index. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name sequence_index : pandas.DataFrame Sequence index min_length : int Minimum number of standard nucleotide characters (A, C, G, or T) in each sequence - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) - >>> sequence_index = pd.DataFrame([{"strain": "strain1", "A": 7000, "C": 7000, "G": 7000, "T": 7000}, {"strain": "strain2", "A": 6500, "C": 6500, "G": 6500, "T": 6500}]).set_index("strain") - >>> filter_by_min_length(metadata, sequence_index, min_length=27000) - {'strain1'} - - It is possible for the sequence index to be missing strains present in the metadata. 
- - >>> sequence_index = pd.DataFrame([{"strain": "strain3", "A": 7000, "C": 7000, "G": 7000, "T": 7000}, {"strain": "strain2", "A": 6500, "C": 6500, "G": 6500, "T": 6500}]).set_index("strain") - >>> filter_by_min_length(metadata, sequence_index, min_length=27000) - set() - """ strains = set(metadata.index.values) filtered_sequence_index = sequence_index.loc[ @@ -459,24 +333,15 @@ def filter_by_min_length(metadata, sequence_index, min_length) -> FilterFunction return set(filtered_sequence_index[filtered_sequence_index["ACGT"] >= min_length].index.values) -def filter_by_max_length(metadata, sequence_index, max_length) -> FilterFunctionReturn: +def filter_by_max_length(sequence_index, max_length) -> FilterFunctionReturn: """Filter metadata by sequence length from a given sequence index. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name sequence_index : pandas.DataFrame Sequence index max_length : int Maximum number of standard nucleotide characters (A, C, G, or T) in each sequence - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) - >>> sequence_index = pd.DataFrame([{"strain": "strain1", "A": 7000, "C": 7000, "G": 7000, "T": 7000}, {"strain": "strain2", "A": 6500, "C": 6500, "G": 6500, "T": 6500}]).set_index("strain") - >>> filter_by_max_length(metadata, sequence_index, max_length=27000) - {'strain2'} """ strains = set(metadata.index.values) filtered_sequence_index = sequence_index.loc[ @@ -487,23 +352,13 @@ def filter_by_max_length(metadata, sequence_index, max_length) -> FilterFunction return set(filtered_sequence_index[filtered_sequence_index["ACGT"] <= max_length].index.values) -def filter_by_non_nucleotide(metadata, sequence_index) -> FilterFunctionReturn: - """Filter metadata for strains with invalid nucleotide content. +def filter_by_non_nucleotide(sequence_index) -> FilterFunctionReturn: + """Filter for strains with invalid nucleotide content. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name sequence_index : pandas.DataFrame Sequence index - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-01-02"}], index=["strain1", "strain2"]) - >>> sequence_index = pd.DataFrame([{"strain": "strain1", "invalid_nucleotides": 0}, {"strain": "strain2", "invalid_nucleotides": 1}]).set_index("strain") - >>> filter_by_non_nucleotide(metadata, sequence_index) - {'strain1'} - """ strains = set(metadata.index.values) filtered_sequence_index = sequence_index.loc[ @@ -514,34 +369,20 @@ def filter_by_non_nucleotide(metadata, sequence_index) -> FilterFunctionReturn: return set(filtered_sequence_index[no_invalid_nucleotides].index.values) -def force_include_strains(metadata, include_file) -> FilterFunctionReturn: - """Include strains in the given text file from the given metadata. +def force_include_strains(include_file) -> FilterFunctionReturn: + """Include strains in the given text file. Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name include_file : str - Filename with strain names to include from the given metadata - - Examples - -------- - >>> import os - >>> from tempfile import NamedTemporaryFile - >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) - >>> with NamedTemporaryFile(delete=False) as include_file: - ... 
characters_written = include_file.write(b'strain1') - >>> force_include_strains(metadata, include_file.name) - {'strain1'} - >>> os.unlink(include_file.name) - + Filename with strain names to include """ included_strains = read_strains(include_file) return set(metadata.index.values) & included_strains -def force_include_where(metadata, include_where) -> FilterFunctionReturn: - """Include all strains from the given metadata that match the given query. +def force_include_where(include_where) -> FilterFunctionReturn: + """Include all strains that match the given query. Unlike pandas query syntax, inclusion queries should follow the pattern of `"property=value"` or `"property!=value"`. Additionally, this filter treats @@ -550,26 +391,8 @@ def force_include_where(metadata, include_where) -> FilterFunctionReturn: Parameters ---------- - metadata : pandas.DataFrame - Metadata indexed by strain name include_where : str Filter query used to include strains - - Examples - -------- - >>> metadata = pd.DataFrame([{"region": "Africa"}, {"region": "Europe"}], index=["strain1", "strain2"]) - >>> force_include_where(metadata, "region!=Europe") - {'strain1'} - >>> force_include_where(metadata, "region=Europe") - {'strain2'} - >>> force_include_where(metadata, "region=europe") - {'strain2'} - - If the column referenced in the given query does not exist, skip the filter. - - >>> force_include_where(metadata, "missing_column=value") - set() - """ column, op, value = parse_filter_query(include_where) @@ -768,39 +591,6 @@ def apply_filters(metadata, exclude_by: List[FilterOption], include_by: List[Fil Strains to exclude along with the function that filtered them and the arguments used to run the function. list of dict : Strains to force-include along with the function that filtered them and the arguments used to run the function. - - - For example, filter data by minimum date, but force the include of strains - from Africa. - - Examples - -------- - >>> from augur.dates import numeric_date - >>> metadata = pd.DataFrame([{"region": "Africa", "date": "2020-01-01"}, {"region": "Europe", "date": "2020-10-02"}, {"region": "North America", "date": "2020-01-01"}], index=["strain1", "strain2", "strain3"]) - >>> exclude_by = [(filter_by_min_date, {"date_column": "date", "min_date": numeric_date("2020-04-01")})] - >>> include_by = [(force_include_where, {"include_where": "region=Africa"})] - >>> strains_to_keep, strains_to_exclude, strains_to_include = apply_filters(metadata, exclude_by, include_by) - >>> strains_to_keep - {'strain2'} - >>> sorted(strains_to_exclude, key=lambda record: record["strain"]) - [{'strain': 'strain1', 'filter': 'filter_by_min_date', 'kwargs': '[["date_column", "date"], ["min_date", 2020.25]]'}, {'strain': 'strain3', 'filter': 'filter_by_min_date', 'kwargs': '[["date_column", "date"], ["min_date", 2020.25]]'}] - >>> strains_to_include - [{'strain': 'strain1', 'filter': 'force_include_where', 'kwargs': '[["include_where", "region=Africa"]]'}] - - We also want to filter by characteristics of the sequence data that we've - annotated in a sequence index. 
- - >>> sequence_index = pd.DataFrame([{"strain": "strain1", "A": 7000, "C": 7000, "G": 7000, "T": 7000}, {"strain": "strain2", "A": 6500, "C": 6500, "G": 6500, "T": 6500}, {"strain": "strain3", "A": 1250, "C": 1250, "G": 1250, "T": 1250}]).set_index("strain") - >>> exclude_by = [(filter_by_min_length, {"sequence_index": sequence_index, "min_length": 27000})] - >>> include_by = [(force_include_where, {"include_where": "region=Europe"})] - >>> strains_to_keep, strains_to_exclude, strains_to_include = apply_filters(metadata, exclude_by, include_by) - >>> strains_to_keep - {'strain1'} - >>> sorted(strains_to_exclude, key=lambda record: record["strain"]) - [{'strain': 'strain2', 'filter': 'filter_by_min_length', 'kwargs': '[["min_length", 27000]]'}, {'strain': 'strain3', 'filter': 'filter_by_min_length', 'kwargs': '[["min_length", 27000]]'}] - >>> strains_to_include - [{'strain': 'strain2', 'filter': 'force_include_where', 'kwargs': '[["include_where", "region=Europe"]]'}] - """ strains_to_keep = set(metadata.index.values) strains_to_filter = [] From 237273b95ef9fc5a55cb4c5b40b9e18abea02734 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Thu, 23 Mar 2023 16:58:47 -0700 Subject: [PATCH 09/15] =?UTF-8?q?=F0=9F=9A=A7=20(2/3)=20filter:=20Rewrite?= =?UTF-8?q?=20using=20SQLite3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WARNING: This change won't work as-is. Broken out as a separate commit for easier review. Remove sequence index parameter in include_exclude_rules. These are not applicable in the SQLite-based implementation. It also removes the need to check for DataFrame instances in kwargs. --- augur/filter/include_exclude_rules.py | 48 ++++++--------------------- 1 file changed, 10 insertions(+), 38 deletions(-) diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py index 2f98afef9..feaa041a7 100644 --- a/augur/filter/include_exclude_rules.py +++ b/augur/filter/include_exclude_rules.py @@ -298,15 +298,10 @@ def filter_by_max_date(date_column, max_date) -> FilterFunctionReturn: return filtered -def filter_by_sequence_index(sequence_index) -> FilterFunctionReturn: +def filter_by_sequence_index() -> FilterFunctionReturn: """Filter by presence of corresponding entries in a given sequence index. This filter effectively intersects the strain ids in the metadata and sequence index. - - Parameters - ---------- - sequence_index : pandas.DataFrame - Sequence index """ metadata_strains = set(metadata.index.values) sequence_index_strains = set(sequence_index.index.values) @@ -314,13 +309,12 @@ def filter_by_sequence_index(sequence_index) -> FilterFunctionReturn: return metadata_strains & sequence_index_strains -def filter_by_min_length(metadata, sequence_index, min_length) -> FilterFunctionReturn: +# FIXME: remove metadata in previous commit +def filter_by_min_length(min_length) -> FilterFunctionReturn: """Filter by sequence length from a given sequence index. 
Parameters ---------- - sequence_index : pandas.DataFrame - Sequence index min_length : int Minimum number of standard nucleotide characters (A, C, G, or T) in each sequence """ @@ -333,13 +327,11 @@ def filter_by_min_length(metadata, sequence_index, min_length) -> FilterFunction return set(filtered_sequence_index[filtered_sequence_index["ACGT"] >= min_length].index.values) -def filter_by_max_length(sequence_index, max_length) -> FilterFunctionReturn: +def filter_by_max_length(max_length) -> FilterFunctionReturn: """Filter metadata by sequence length from a given sequence index. Parameters ---------- - sequence_index : pandas.DataFrame - Sequence index max_length : int Maximum number of standard nucleotide characters (A, C, G, or T) in each sequence """ @@ -352,13 +344,8 @@ def filter_by_max_length(sequence_index, max_length) -> FilterFunctionReturn: return set(filtered_sequence_index[filtered_sequence_index["ACGT"] <= max_length].index.values) -def filter_by_non_nucleotide(sequence_index) -> FilterFunctionReturn: +def filter_by_non_nucleotide() -> FilterFunctionReturn: """Filter for strains with invalid nucleotide content. - - Parameters - ---------- - sequence_index : pandas.DataFrame - Sequence index """ strains = set(metadata.index.values) filtered_sequence_index = sequence_index.loc[ @@ -451,12 +438,7 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: # Filter by sequence index. if constants.sequence_index is not None: - exclude_by.append(( - filter_by_sequence_index, - { - "sequence_index": constants.sequence_index, - }, - )) + exclude_by.append((filter_by_sequence_index, {})) # Remove strains explicitly excluded by name. if args.exclude: @@ -526,7 +508,6 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: exclude_by.append(( filter_by_min_length, { - "sequence_index": constants.sequence_index, "min_length": args.min_length, } )) @@ -537,19 +518,13 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: exclude_by.append(( filter_by_max_length, { - "sequence_index": sequence_index, "max_length": args.max_length, } )) # Exclude sequences with non-nucleotide characters. if args.non_nucleotide: - exclude_by.append(( - filter_by_non_nucleotide, - { - "sequence_index": constants.sequence_index, - } - )) + exclude_by.append((filter_by_non_nucleotide, {})) if args.group_by: # The order in which these are applied later should be broad → specific @@ -674,8 +649,7 @@ def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs): -------- >>> from augur.dates import numeric_date >>> from augur.filter.include_exclude_rules import filter_by_min_length, filter_by_min_date - >>> sequence_index = pd.DataFrame([{"strain": "strain1", "ACGT": 28000}, {"strain": "strain2", "ACGT": 26000}, {"strain": "strain3", "ACGT": 5000}]).set_index("strain") - >>> exclude_by = [(filter_by_min_length, {"sequence_index": sequence_index, "min_length": 27000})] + >>> exclude_by = [(filter_by_min_length, {"min_length": 27000})] >>> _filter_kwargs_to_str(exclude_by[0][1]) '[["min_length", 27000]]' >>> exclude_by = [(filter_by_min_date, {"date_column": "date", "min_date": numeric_date("2020-03-01")})] @@ -692,10 +666,8 @@ def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs): value = kwargs[key] # Handle special cases for data types that we want to represent - # differently from their defaults or not at all. - if isinstance(value, pd.DataFrame): - continue - elif isinstance(value, float): + # differently from their defaults. 
+        if isinstance(value, float):
             value = round(value, 2)

         kwarg_list.append((key, value))

From fdf131be7593287a8a922a64307ac736f44e5305 Mon Sep 17 00:00:00 2001
From: Victor Lin <13424970+victorlin@users.noreply.github.com>
Date: Fri, 2 Feb 2024 15:12:16 -0800
Subject: [PATCH 10/15] =?UTF-8?q?=F0=9F=9A=A7=20(3/3)=20filter:=20Rewrite?=
 =?UTF-8?q?=20using=20SQLite3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is the final change that replaces the Pandas-based implementation of
augur filter with a SQLite3-based implementation.

Breaking changes in behavior (see changes under tests/):

1. `--subsample-seed` is still deterministic but differs from the previous
   implementation.
2. `--include*`: Previously, both exclusion and force-inclusion would be shown
   in console output and `--output-log`. Now, only the force-inclusion is
   reported.

Implementation differences with no functional changes:

1. Tabular files are loaded into tables in a temporary SQLite3 database file
   on disk rather than Pandas DataFrames. This generally means less memory
   usage and more disk usage. Tables are indexed on strain.
2. Since chunked loading of metadata was introduced to avoid high memory
   usage¹, that is no longer necessary and all operations are now on the
   entire metadata (except for `--query`/`--query-pandas`).
3. For large datasets, the SQLite3 implementation is much faster than the
   Pandas implementation.
4. Instead of relying on continually updated variables (e.g. `valid_strains`),
   new tables in the database are created at various stages in the process.
   The "filter reason table" is used as a source of truth for all outputs (and
   is conveniently a ~direct representation of `--output-log`). This also
   allows the function `augur.filter._run.run()` to be further broken into
   smaller parts.
5. Exclusion/inclusion is done using WHERE expressions.
6. For subsampling, priority queues are no longer necessary, as the highest
   priority strains can be determined using an ORDER BY across all strains.
7. Date parsing has improved with caching and a min/max approach to resolving
   date ranges.

Note that sequence I/O remains unchanged.
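As a rough sketch of item 6 (hypothetical table and column names, not the
exact queries used in this patch), a per-group size limit can be applied with
an ORDER BY inside a window function instead of a Python priority queue:

    import sqlite3

    # Hypothetical, minimal schema for illustration: one table holding the
    # grouping value and priority for each strain. (The patch keeps these in
    # separate tables, but the idea is the same.)
    connection = sqlite3.connect(":memory:")
    connection.executescript("""
        CREATE TABLE strains (strain TEXT, group_value TEXT, priority REAL);
        INSERT INTO strains VALUES
            ('strain1', '2020-01', 0.9),
            ('strain2', '2020-01', 0.5),
            ('strain3', '2020-02', 0.7);
    """)

    # Rank strains within each group by priority and keep the top N.
    # (ROW_NUMBER() requires SQLite >= 3.25.)
    query = """
        SELECT strain FROM (
            SELECT
                strain,
                ROW_NUMBER() OVER (
                    PARTITION BY group_value
                    ORDER BY priority DESC
                ) AS rank_in_group
            FROM strains
        )
        WHERE rank_in_group <= :group_size_limit
    """
    kept = [row[0] for row in connection.execute(query, {"group_size_limit": 1})]
    assert sorted(kept) == ['strain1', 'strain3']

(The patch itself stores per-group limits in a dedicated table; see
GROUP_SIZE_LIMITS_TABLE in constants.py.)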
¹ 87ca73c49fa35a8e9d3a74f41a0a24ccab5941ea --- augur/dates/__init__.py | 7 - augur/filter/_run.py | 337 ++-------- augur/filter/constants.py | 93 ++- augur/filter/dates.py | 192 ++++++ augur/filter/include_exclude_rules.py | 599 +++++++++++------- augur/filter/io.py | 323 ++++++++-- augur/filter/report.py | 88 ++- augur/filter/subsample.py | 509 +++++++++------ docs/api/developer/augur.filter.dates.rst | 7 + docs/api/developer/augur.filter.rst | 1 + tests/filter/test_subsample.py | 117 +--- ...ilter-metadata-sequence-strains-mismatch.t | 5 +- .../filter/cram/filter-min-max-date-output.t | 2 +- .../filter/cram/filter-mismatched-sequences.t | 3 +- ...ter-query-and-exclude-ambiguous-dates-by.t | 2 +- .../filter/cram/filter-sequences-vcf.t | 2 +- .../filter-subsample-missing-date-parts.t | 2 +- ...nces-with-probabilistic-sampling-warning.t | 6 +- ...e-probabilistic-sampling-not-always-used.t | 2 +- .../subsample-probabilistic-sampling-output.t | 2 +- .../cram/subsample-skip-ambiguous-dates.t | 2 +- .../cram/subsample-weighted-and-uniform-mix.t | 2 +- 22 files changed, 1367 insertions(+), 936 deletions(-) create mode 100644 augur/filter/dates.py create mode 100644 docs/api/developer/augur.filter.dates.rst diff --git a/augur/dates/__init__.py b/augur/dates/__init__.py index 8517479bb..2b31e0cb7 100644 --- a/augur/dates/__init__.py +++ b/augur/dates/__init__.py @@ -150,10 +150,3 @@ def get_numerical_dates(metadata:pd.DataFrame, name_col = None, date_col='date', strains = metadata.index.values dates = metadata[date_col].astype(float) return dict(zip(strains, dates)) - -def get_year_month(year, month): - return f"{year}-{str(month).zfill(2)}" - -def get_year_week(year, month, day): - year, week = datetime.date(year, month, day).isocalendar()[:2] - return f"{year}-{str(week).zfill(2)}" diff --git a/augur/filter/_run.py b/augur/filter/_run.py index 172bd902a..9601e81f5 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -1,313 +1,54 @@ -from collections import defaultdict -import csv -import itertools -import numpy as np - +from argparse import Namespace +from tempfile import NamedTemporaryFile from augur.errors import AugurError + from augur.io.file import open_file -from augur.io.metadata import InvalidDelimiter, Metadata, read_metadata +from augur.io.metadata import Metadata from augur.io.print import print_err from augur.io.tabular_file import InvalidDelimiter from . import constants -from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs, import_sequence_index, read_and_output_sequences +from .dates import parse_dates +from .io import get_useful_metadata_columns, initialize_input_source_table, import_metadata, import_sequence_index, write_outputs from .include_exclude_rules import apply_filters, construct_filters from .report import print_report -from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling, get_weighted_group_sizes - - -def run(args): - import_sequence_index(args) - - ##################################### - #Filtering steps - ##################################### - - # Setup filters. - exclude_by, include_by = construct_filters(args) - - # Setup grouping. We handle the following major use cases: - # - # 1. group by and sequences per group defined -> use the given values by the - # user to identify the highest priority records from each group in a single - # pass through the metadata. - # - # 2. 
group by and maximum sequences defined -> use the first pass through - # the metadata to count the number of records in each group, calculate the - # sequences per group that satisfies the requested maximum, and use a second - # pass through the metadata to select that many sequences per group. - # - # 3. group by not defined but maximum sequences defined -> use a "dummy" - # group such that we select at most the requested maximum number of - # sequences in a single pass through the metadata. - # - # Each case relies on a priority queue to track the highest priority records - # per group. In the best case, we can track these records in a single pass - # through the metadata. In the worst case, we don't know how many sequences - # per group to use, so we need to calculate this number after the first pass - # and use a second pass to add records to the queue. - group_by = args.group_by - sequences_per_group = args.sequences_per_group - records_per_group = None - - if group_by and args.subsample_max_sequences: - # In this case, we need two passes through the metadata with the first - # pass used to count the number of records per group. - records_per_group = defaultdict(int) - elif not group_by and args.subsample_max_sequences: - group_by = ("_dummy",) - sequences_per_group = args.subsample_max_sequences - - # If we are grouping data, use queues to store the highest priority strains - # for each group. When no priorities are provided, they will be randomly - # generated. - queues_by_group = None - if group_by: - # Use user-defined priorities, if possible. Otherwise, setup a - # corresponding dictionary that returns a random float for each strain. - if args.priority: - priorities = read_priority_scores(args.priority) - else: - random_generator = np.random.default_rng(args.subsample_seed) - priorities = defaultdict(random_generator.random) - - # Setup logging. - output_log_context_manager = open_file(args.output_log, "w", newline='') - output_log_writer = None - if args.output_log: - # Log the names of strains that were filtered or force-included, so we - # can properly account for each strain (e.g., including those that were - # initially filtered for one reason and then included again for another - # reason). - output_log = output_log_context_manager.__enter__() - output_log_header = ("strain", "filter", "kwargs") - output_log_writer = csv.DictWriter( - output_log, - fieldnames=output_log_header, - delimiter="\t", - lineterminator="\n", - ) - output_log_writer.writeheader() - - # Load metadata. Metadata are the source of truth for which sequences we - # want to keep in filtered output. - constants.metadata_strains = set() - constants.valid_strains = set() # TODO: rename this more clearly - constants.all_sequences_to_include = set() - constants.filter_counts = defaultdict(int) - - try: - metadata_object = Metadata(args.metadata, args.metadata_id_columns, delimiters=args.metadata_delimiters) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." 
- ) - useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns) - - metadata_reader = read_metadata( - args.metadata, - delimiters=[metadata_object.delimiter], - columns=useful_metadata_columns, - id_columns=[metadata_object.id_column], - chunk_size=args.metadata_chunk_size, - dtype="string", - ) - for metadata in metadata_reader: - duplicate_strains = ( - set(metadata.index[metadata.index.duplicated()]) | - (set(metadata.index) & constants.metadata_strains) - ) - if len(duplicate_strains) > 0: - cleanup_outputs(args) - raise AugurError(f"The following strains are duplicated in '{args.metadata}':\n" + "\n".join(sorted(duplicate_strains))) - - # Maintain list of all strains seen. - constants.metadata_strains.update(set(metadata.index.values)) - - # Filter metadata. - seq_keep, sequences_to_filter, sequences_to_include = apply_filters( - metadata, - exclude_by, - include_by, - ) - constants.valid_strains.update(seq_keep) - - # Track distinct strains to include, so we can write their - # corresponding metadata, strains, or sequences later, as needed. - distinct_force_included_strains = { - record["strain"] - for record in sequences_to_include - } - constants.all_sequences_to_include.update(distinct_force_included_strains) - - # Track reasons for filtered or force-included strains, so we can - # report total numbers filtered and included at the end. Optionally, - # write out these reasons to a log file. - for filtered_strain in itertools.chain(sequences_to_filter, sequences_to_include): - constants.filter_counts[(filtered_strain["filter"], filtered_strain["kwargs"])] += 1 - - # Log the names of strains that were filtered or force-included, - # so we can properly account for each strain (e.g., including - # those that were initially filtered for one reason and then - # included again for another reason). - if args.output_log: - output_log_writer.writerow(filtered_strain) - - if group_by: - # Prevent force-included sequences from being included again during - # subsampling. - seq_keep = seq_keep - distinct_force_included_strains - - # If grouping, track the highest priority metadata records or - # count the number of records per group. First, we need to get - # the groups for the given records. - group_by_strain = get_groups_for_subsampling( - seq_keep, - metadata, - group_by, +from .subsample import apply_subsampling + + +def run(args: Namespace): + with NamedTemporaryFile() as file: + # Set the database file as a variable that can be easily accessed within + # functions deep in the call stack. It could be passed down by value, + # but that would be tedious and makes it harder to trace references back + # to the source. + constants.RUNTIME_DB_FILE = file.name + + initialize_input_source_table() + + try: + metadata = Metadata(args.metadata, id_columns=args.metadata_id_columns, delimiters=args.metadata_delimiters) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." ) + columns = get_useful_metadata_columns(args, metadata.id_column, metadata.columns) + import_metadata(metadata, columns) - if args.subsample_max_sequences and records_per_group is not None: - # Count the number of records per group. We will use this - # information to calculate the number of sequences per group - # for the given maximum number of requested sequences. 
- for group in group_by_strain.values(): - records_per_group[group] += 1 - else: - # Track the highest priority records, when we already - # know the number of sequences allowed per group. - if queues_by_group is None: - queues_by_group = {} - - for strain in sorted(group_by_strain.keys()): - # During this first pass, we do not know all possible - # groups will be, so we need to build each group's queue - # as we first encounter the group. - group = group_by_strain[strain] - if group not in queues_by_group: - queues_by_group[group] = PriorityQueue( - max_size=sequences_per_group, - ) - - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], - ) - - # In the worst case, we need to calculate sequences per group from the - # requested maximum number of sequences and the number of sequences per - # group. Then, we need to make a second pass through the metadata to find - # the requested number of records. - if args.subsample_max_sequences and records_per_group is not None: - if queues_by_group is None: - # We know all of the possible groups now from the first pass through - # the metadata, so we can create queues for all groups at once. - if args.group_by_weights: - print_err(f"Sampling with weights defined by {args.group_by_weights}.") - group_sizes = get_weighted_group_sizes( - records_per_group, - group_by, - args.group_by_weights, - args.subsample_max_sequences, - args.output_group_by_sizes, - args.subsample_seed, - ) - else: - # Calculate sequences per group. If there are more groups than maximum - # sequences requested, sequences per group will be a floating point - # value and subsampling will be probabilistic. - try: - sequences_per_group, probabilistic_used = calculate_sequences_per_group( - args.subsample_max_sequences, - records_per_group.values(), - args.probabilistic_sampling, - ) - except TooManyGroupsError as error: - raise AugurError(error) - - if (probabilistic_used): - print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") - group_sizes = get_probabilistic_group_sizes( - records_per_group.keys(), - sequences_per_group, - random_seed=args.subsample_seed, - ) - else: - print_err(f"Sampling at {sequences_per_group} per group.") - assert type(sequences_per_group) is int - group_sizes = {group: sequences_per_group for group in records_per_group.keys()} - queues_by_group = create_queues_by_group(group_sizes) - - # Make a second pass through the metadata, only considering records that - # have passed filters. - metadata_reader = read_metadata( - args.metadata, - delimiters=[metadata_object.delimiter], - columns=useful_metadata_columns, - id_columns=[metadata_object.id_column], - chunk_size=args.metadata_chunk_size, - dtype="string", - ) - for metadata in metadata_reader: - # Recalculate groups for subsampling as we loop through the - # metadata a second time. TODO: We could store these in memory - # during the first pass, but we want to minimize overall memory - # usage at the moment. - seq_keep = set(metadata.index.values) & constants.valid_strains - - # Prevent force-included strains from being considered in this - # second pass, as in the first pass. 
- seq_keep = seq_keep - constants.all_sequences_to_include - - group_by_strain = get_groups_for_subsampling( - seq_keep, - metadata, - group_by, - ) - - for strain in sorted(group_by_strain.keys()): - group = group_by_strain[strain] - queues_by_group[group].add( - metadata.loc[strain], - priorities[strain], - ) - - # If we have any records in queues, we have grouped results and need to - # stream the highest priority records to the requested outputs. - constants.num_excluded_subsamp = 0 - if queues_by_group: - # Populate the set of strains to keep from the records in queues. - subsampled_strains = set() - for group, queue in queues_by_group.items(): - records = [] - for record in queue.get_items(): - # Each record is a pandas.Series instance. Track the name of the - # record, so we can output its sequences later. - subsampled_strains.add(record.name) + import_sequence_index(args) - # Construct a data frame of records to simplify metadata output. - records.append(record) + parse_dates() - # Count and optionally log strains that were not included due to - # subsampling. - strains_filtered_by_subsampling = constants.valid_strains - subsampled_strains - constants.num_excluded_subsamp = len(strains_filtered_by_subsampling) - if output_log_writer: - for strain in strains_filtered_by_subsampling: - output_log_writer.writerow({ - "strain": strain, - "filter": "subsampling", - "kwargs": "", - }) + exclude_by, include_by = construct_filters(args) + apply_filters(exclude_by, include_by) - constants.valid_strains = subsampled_strains + if args.group_by or args.subsample_max_sequences: + apply_subsampling(args) - read_and_output_sequences(args) + write_outputs(args) - if args.output_metadata or args.output_strains: - write_metadata_based_outputs(args.metadata, args.metadata_delimiters, - args.metadata_id_columns, args.output_metadata, - args.output_strains, constants.valid_strains) + print_report(args) - print_report(args) + # TODO: The current implementation assumes the database file is hidden from + # the user. If this ever changes, clean the database of any + # tables/indexes/etc. diff --git a/augur/filter/constants.py b/augur/filter/constants.py index a028e4d17..66c107e86 100644 --- a/augur/filter/constants.py +++ b/augur/filter/constants.py @@ -1,17 +1,51 @@ -# Shared variables set at run time. -# TODO: Remove these with the database implementation. -sequence_index = None -sequence_strains = None -metadata_strains = None -valid_strains = None -all_sequences_to_include = None -filter_counts = None -num_excluded_subsamp = None - -# Generated date columns. +# Constants set at run time. +RUNTIME_DB_FILE: str = None + + +# ID column used for all tables defined internally. +ID_COLUMN = 'strain' + + +# Below are table names, column names, and constant values associated with the database. + +# A table representing the original metadata from the user. +METADATA_TABLE = '__augur_filter__metadata' + + +# A table representing the sequence index, either provided by the user or +# automatically generated. +SEQUENCE_INDEX_TABLE = '__augur_filter__sequence_index' + + +# A table representing information on where strains are present among multiple +# inputs. +INPUT_SOURCE_TABLE = '__augur_filter__input_source' + +# Columns that represent the multiple input sources. 
+STRAIN_IN_METADATA_COLUMN = '__augur_filter__strain_in_metadata' +STRAIN_IN_SEQUENCES_COLUMN = '__augur_filter__strain_in_sequences' +STRAIN_IN_SEQUENCE_INDEX_COLUMN = '__augur_filter__strain_in_sequence_index' + + +# A table representing strain priorities for subsampling, either provided by the +# user or automatically generated. +PRIORITIES_TABLE = '__augur_filter__priorities' + +# A column for priority scores. +PRIORITY_COLUMN = '__augur_filter__priority' + + +# A table representing information parsed from the date column in the original +# metadata. +DATE_TABLE = '__augur_filter__metadata_date_info' + +# Columns in the date table. DATE_YEAR_COLUMN = 'year' DATE_MONTH_COLUMN = 'month' +DATE_DAY_COLUMN = 'day' DATE_WEEK_COLUMN = 'week' +NUMERIC_DATE_MIN_COLUMN = 'date_min' +NUMERIC_DATE_MAX_COLUMN = 'date_max' # Generated columns available for --group-by. # Use sorted() for reproducible output. @@ -20,3 +54,40 @@ DATE_MONTH_COLUMN, DATE_WEEK_COLUMN, } + + +# A table with columns to indicate why and how each row is filtered. +FILTER_REASON_TABLE = '__augur_filter__filter_reason' + +# A column for boolean values to indicate excluded strains. +EXCLUDE_COLUMN = '__augur_filter__exclude' + +# A column for boolean values to indicate force-included strains. +INCLUDE_COLUMN = '__augur_filter__force_include' + +# A column for the filter reason. +FILTER_REASON_COLUMN = 'filter' + +# A column for the filter reason's keyword arguments. +FILTER_REASON_KWARGS_COLUMN = 'kwargs' + +# A value for the filter reason column to denote exclusion by subsampling. +SUBSAMPLE_FILTER_REASON = 'subsampling' + + +# A table used only during subsampling that contains the columns necessary for +# grouping. +GROUPING_TABLE = '__augur_filter__grouping' + +# A column used when --group-by is not provided to ensure all samples are +# effectively in the same group. +GROUP_BY_DUMMY_COLUMN = '__augur_filter__group_by_placeholder' +GROUP_BY_DUMMY_VALUE = '"dummy"' + + +# A table used only during subsampling that contains information on group size +# limits based on user-specified parameters. +GROUP_SIZE_LIMITS_TABLE = '__augur_filter__group_size_limits' + +# A column for group size limits. +GROUP_SIZE_LIMIT_COLUMN = '__augur_filter__group_size_limit' diff --git a/augur/filter/dates.py b/augur/filter/dates.py new file mode 100644 index 000000000..b29ec8830 --- /dev/null +++ b/augur/filter/dates.py @@ -0,0 +1,192 @@ +from functools import lru_cache +import treetime.utils +from augur.dates import get_numerical_date_from_value +from augur.dates.errors import InvalidDate +from augur.errors import AugurError +from augur.io.metadata import METADATA_DATE_COLUMN +from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier +from . import constants + + +def parse_dates(): + """Validate dates and create a date table.""" + # First, determine if there is a date column. + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_has_date_column = METADATA_DATE_COLUMN in db.columns(constants.METADATA_TABLE) + + if metadata_has_date_column: + # Check dates for errors in Python since error handling is non-trivial + # with SQLite3 user-defined functions. + _validate_dates() + + _create_date_table_from_metadata() + else: + # Create a placeholder table so later JOINs to this table will not break. 
+ _create_empty_date_table() + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.create_primary_index(constants.DATE_TABLE, constants.ID_COLUMN) + + +def _validate_dates(): + """Query metadata for dates and error upon any invalid dates.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT {METADATA_DATE_COLUMN} + FROM {constants.METADATA_TABLE} + """) + for row in result: + try: + get_numerical_date_from_value(str(row[METADATA_DATE_COLUMN]), fmt='%Y-%m-%d') + except InvalidDate as error: + raise AugurError(error) + + +def _create_date_table_from_metadata(): + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + # Register SQLite3 user-defined functions. + db.connection.create_function(get_year.__name__ , 1, get_year) + db.connection.create_function(get_month.__name__, 1, get_month) + db.connection.create_function(get_day.__name__ , 1, get_day) + db.connection.create_function(get_week.__name__ , 1, get_week) + db.connection.create_function(try_get_numeric_date_min.__name__, 1, try_get_numeric_date_min) + db.connection.create_function(try_get_numeric_date_max.__name__, 1, try_get_numeric_date_max) + + db.connection.execute(f"""CREATE TABLE {constants.DATE_TABLE} AS + SELECT + {sanitize_identifier(metadata_id_column)} AS {constants.ID_COLUMN}, + {get_year.__name__}({METADATA_DATE_COLUMN}) AS {constants.DATE_YEAR_COLUMN}, + {get_month.__name__}({METADATA_DATE_COLUMN}) AS {constants.DATE_MONTH_COLUMN}, + {get_day.__name__}({METADATA_DATE_COLUMN}) AS {constants.DATE_DAY_COLUMN}, + {get_week.__name__}({METADATA_DATE_COLUMN}) AS {constants.DATE_WEEK_COLUMN}, + {try_get_numeric_date_min.__name__}({METADATA_DATE_COLUMN}) AS {constants.NUMERIC_DATE_MIN_COLUMN}, + {try_get_numeric_date_max.__name__}({METADATA_DATE_COLUMN}) AS {constants.NUMERIC_DATE_MAX_COLUMN} + FROM {constants.METADATA_TABLE} + """) + + # Remove user-defined functions. + db.connection.create_function(get_year.__name__ , 1, None) + db.connection.create_function(get_month.__name__, 1, None) + db.connection.create_function(get_day.__name__ , 1, None) + db.connection.create_function(get_week.__name__ , 1, None) + db.connection.create_function(try_get_numeric_date_min.__name__, 1, None) + db.connection.create_function(try_get_numeric_date_max.__name__, 1, None) + + +def _create_empty_date_table(): + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.connection.execute(f"""CREATE TABLE {constants.DATE_TABLE} AS + SELECT + {sanitize_identifier(metadata_id_column)} AS {constants.ID_COLUMN}, + NULL AS {METADATA_DATE_COLUMN}, + NULL AS {constants.DATE_YEAR_COLUMN}, + NULL AS {constants.DATE_MONTH_COLUMN}, + NULL AS {constants.DATE_DAY_COLUMN}, + NULL AS {constants.DATE_WEEK_COLUMN}, + NULL AS {constants.NUMERIC_DATE_MIN_COLUMN}, + NULL AS {constants.NUMERIC_DATE_MAX_COLUMN} + FROM {constants.METADATA_TABLE} + """) + + +CACHE_SIZE = 8192 +# Some functions below use @lru_cache to minimize redundant operations on large +# datasets that are likely to have multiple entries with the same date value. + + +@lru_cache(maxsize=CACHE_SIZE) +def get_year(date): + """Get the year from a date. + This function is intended to be registered as a user-defined function in sqlite3. 
As such, it will not raise any errors. + """ + try: + date_min, date_max = datetime_range(date) + except: + return None + + if date_min.year == date_max.year: + return f"{date_min.year}" + return None + + +@lru_cache(maxsize=CACHE_SIZE) +def get_month(date): + """Get the month from a date. + This function is intended to be registered as a user-defined function in sqlite3. As such, it will not raise any errors. + """ + try: + date_min, date_max = datetime_range(date) + except: + return None + + if date_min.year == date_max.year and date_min.month == date_max.month: + return f"{date_min.year}-{str(date_min.month).zfill(2)}" + return None + + +# FIXME: remove this? +@lru_cache(maxsize=CACHE_SIZE) +def get_day(date): + """Get the day from a date. + This function is intended to be registered as a user-defined function in sqlite3. As such, it will not raise any errors. + """ + try: + date_min, date_max = datetime_range(date) + except: + return None + + if date_min == date_max: + return f"{date_min.day}" + return None + + +@lru_cache(maxsize=CACHE_SIZE) +def get_week(date): + """Get the year and week from a date. + This function is intended to be registered as a user-defined function in sqlite3. As such, it will not raise any errors. + """ + try: + date_min, date_max = datetime_range(date) + except: + return None + + if date_min == date_max: + year, week = date_min.isocalendar()[:2] + return f"{year}-{str(week).zfill(2)}" + return None + + +def datetime_range(date): + numeric_min = get_numerical_date_from_value(date, fmt='%Y-%m-%d', ambiguity_resolver='min') + numeric_max = get_numerical_date_from_value(date, fmt='%Y-%m-%d', ambiguity_resolver='max') + date_min = treetime.utils.datetime_from_numeric(numeric_min) + date_max = treetime.utils.datetime_from_numeric(numeric_max) + return (date_min, date_max) + + +@lru_cache(maxsize=CACHE_SIZE) +def try_get_numeric_date_min(date): + """Get the numeric date from any supported date format, taking the minimum possible value if ambiguous. + This function is intended to be registered as a user-defined function in sqlite3. As such, it will not raise any errors. + """ + try: + return get_numerical_date_from_value(date, fmt='%Y-%m-%d', ambiguity_resolver='min') + except: + return None + + +@lru_cache(maxsize=CACHE_SIZE) +def try_get_numeric_date_max(date): + """Get the numeric date from any supported date format, taking the maximum possible value if ambiguous. + This function is intended to be registered as a user-defined function in sqlite3. As such, it will not raise any errors. 
+ """ + try: + return get_numerical_date_from_value(date, fmt='%Y-%m-%d', ambiguity_resolver='max') + except: + return None diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py index feaa041a7..619743c66 100644 --- a/augur/filter/include_exclude_rules.py +++ b/augur/filter/include_exclude_rules.py @@ -1,16 +1,15 @@ import ast import json -import operator import re -import numpy as np import pandas as pd -from typing import Any, Callable, Dict, List, Optional, Set, Tuple - -from augur.dates import is_date_ambiguous, get_numerical_dates +import sqlite3 +from typing import Any, Callable, Dict, List, Optional, Tuple from augur.errors import AugurError +from augur.index import ID_COLUMN as SEQUENCE_INDEX_ID_COLUMN from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.print import print_err from augur.io.strains import read_strains +from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier from augur.io.vcf import is_vcf as filename_is_vcf from . import constants @@ -20,9 +19,14 @@ except AttributeError: PandasUndefinedVariableError = pd.core.computation.ops.UndefinedVariableError # type: ignore +# A SQL expression to represent strains that the filter applies to. +SqlExpression = str + +# Named parameters used in the SQL expression. +SqlParameters = Dict[str, Any] -# The strains to keep as a result of applying a filter function. -FilterFunctionReturn = Set[str] +# The return value of a filter function. +FilterFunctionReturn = Tuple[SqlExpression, SqlParameters] # A function to use for filtering. Parameters vary. FilterFunction = Callable[..., FilterFunctionReturn] @@ -40,7 +44,9 @@ def filter_by_exclude_all() -> FilterFunctionReturn: This is a placeholder function that can be called as part of a generalized loop through all possible functions. """ - return set() + expression = 'True' + parameters: SqlParameters = {} + return expression, parameters def filter_by_exclude(exclude_file) -> FilterFunctionReturn: @@ -52,7 +58,23 @@ def filter_by_exclude(exclude_file) -> FilterFunctionReturn: Filename with strain names to exclude """ excluded_strains = read_strains(exclude_file) - return set(metadata.index.values) - excluded_strains + return _filter_by_exclude_strains(excluded_strains) + + +def _filter_by_exclude_strains(strains) -> FilterFunctionReturn: + """Exclude the given strains. + + Parameters + ---------- + exclude_file : str + Filename with strain names to exclude + """ + quoted_strains = (f"'{strain}'" for strain in strains) + expression = f""" + {constants.ID_COLUMN} IN ({','.join(quoted_strains)}) + """ + parameters: SqlParameters = {} + return expression, parameters def parse_filter_query(query): @@ -68,23 +90,23 @@ def parse_filter_query(query): ------- str : Name of column to query - callable : - Operator function to test equality or non-equality of values + str : + Either '=' or '!=' to denote the operator for a SQLite3 WHERE expression. 
str : Value of column to query Examples -------- >>> parse_filter_query("property=value") - ('property', , 'value') + ('property', '=', 'value') >>> parse_filter_query("property!=value") - ('property', , 'value') + ('property', '!=', 'value') """ column, value = re.split(r'!?=', query) - op = operator.eq + op = '=' if "!=" in query: - op = operator.ne + op = '!=' return column, op, value @@ -102,40 +124,51 @@ def filter_by_exclude_where(exclude_where) -> FilterFunctionReturn: exclude_where : str Filter query used to exclude strains """ + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_columns = db.columns(constants.METADATA_TABLE) + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + column, op, value = parse_filter_query(exclude_where) - if column in metadata.columns: - # Apply a test operator (equality or inequality) to values from the - # column in the given query. This produces an array of boolean values we - # can index with. - excluded = op( - metadata[column].astype(str).str.lower(), - value.lower() - ) - # Negate the boolean index of excluded strains to get the index of - # strains that passed the filter. - included = ~excluded - filtered = set(metadata[included].index.values) + if column in metadata_columns: + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {sanitize_identifier(metadata_id_column)} + FROM {constants.METADATA_TABLE} + WHERE lower({constants.METADATA_TABLE}.{sanitize_identifier(column)}) {op} lower(:value) + ) + """ + parameters = {'value': value} else: # Skip the filter, if the requested column does not exist. - filtered = set(metadata.index.values) + expression = 'False' + parameters = {} - return filtered + return expression, parameters -def filter_by_query(query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: - """Filter metadata in the given pandas DataFrame with a query string and return - the strain names that pass the filter. +def filter_by_query(query: str, chunksize: int, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: + """Filter by a Pandas expression on the metadata. Parameters ---------- query : str - Query string for the dataframe. + Pandas query string used on a DataFrame representation of the metadata. column_types : str Dict mapping of data type + chunksize : int + Maximum number of metadata records to read into memory at a time. + Increasing this number can speed up filtering at the cost of more memory used. """ - # Create a copy to prevent modification of the original DataFrame. - metadata_copy = metadata.copy() + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + metadata_columns = set(db.columns(constants.METADATA_TABLE)) + + # TODO: select only the columns used in the query + metadata_chunks = pd.read_sql_query(f""" + SELECT * + FROM {constants.METADATA_TABLE} + """, db.connection, chunksize=chunksize) if column_types is None: column_types = {} @@ -143,56 +176,66 @@ def filter_by_query(query: str, column_types: Optional[Dict[str, str]] = None) - # Set columns for type conversion. variables = extract_variables(query) if variables is not None: - columns = variables.intersection(metadata_copy.columns) + columns = variables.intersection(metadata_columns) else: # Column extraction failed. Apply type conversion to all columns. - columns = metadata_copy.columns + columns = metadata_columns # If a type is not explicitly provided, try automatic conversion. 
for column in columns: column_types.setdefault(column, 'auto') - # Convert data types before applying the query. - # NOTE: This can behave differently between different chunks of metadata, - # but it's the best we can do. - for column, dtype in column_types.items(): - if dtype == 'auto': - # Try numeric conversion followed by boolean conversion. - try: - # pd.to_numeric supports nullable numeric columns unlike pd.read_csv's - # built-in data type inference. - metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise') - except: + excluded_strains = [] + + for metadata_chunk in metadata_chunks: + # Convert data types before applying the query. + # NOTE: This can behave differently between different chunks of metadata, + # but it's the best we can do. + for column, dtype in column_types.items(): + if dtype == 'auto': + # Try numeric conversion followed by boolean conversion. + try: + # pd.to_numeric supports nullable numeric columns unlike pd.read_csv's + # built-in data type inference. + metadata_chunk[column] = pd.to_numeric(metadata_chunk[column], errors='raise') + except: + try: + metadata_chunk[column] = metadata_chunk[column].map(_string_to_boolean) + except ValueError: + # If both conversions fail, column values are preserved as strings. + pass + + elif dtype == 'int': + try: + metadata_chunk[column] = pd.to_numeric(metadata_chunk[column], errors='raise', downcast='integer') + except ValueError as e: + raise AugurError(f"Failed to convert value in column {column!r} to int. {e}") + elif dtype == 'float': try: - metadata_copy[column] = metadata_copy[column].map(_string_to_boolean) - except ValueError: - # If both conversions fail, column values are preserved as strings. - pass + metadata_chunk[column] = pd.to_numeric(metadata_chunk[column], errors='raise', downcast='float') + except ValueError as e: + raise AugurError(f"Failed to convert value in column {column!r} to float. {e}") + elif dtype == 'bool': + try: + metadata_chunk[column] = metadata_chunk[column].map(_string_to_boolean) + except ValueError as e: + raise AugurError(f"Failed to convert value in column {column!r} to bool. {e}") + elif dtype == 'str': + metadata_chunk[column] = metadata_chunk[column].astype('str', errors='ignore') + + try: + matches = metadata_chunk.query(query).index + except Exception as e: + if isinstance(e, PandasUndefinedVariableError): + raise AugurError(f"Query contains a column that does not exist in metadata: {e}") from e + raise AugurError(f"Internal Pandas error when applying query:\n\t{e}\nEnsure the syntax is valid per .") from e + + # Exclude strains that do not match the query. + excluded_strains.extend( + metadata_chunk.drop(matches)[metadata_id_column].values + ) - elif dtype == 'int': - try: - metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='integer') - except ValueError as e: - raise AugurError(f"Failed to convert value in column {column!r} to int. {e}") - elif dtype == 'float': - try: - metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='float') - except ValueError as e: - raise AugurError(f"Failed to convert value in column {column!r} to float. {e}") - elif dtype == 'bool': - try: - metadata_copy[column] = metadata_copy[column].map(_string_to_boolean) - except ValueError as e: - raise AugurError(f"Failed to convert value in column {column!r} to bool. 
{e}") - elif dtype == 'str': - metadata_copy[column] = metadata_copy[column].astype('str', errors='ignore') - - try: - return set(metadata_copy.query(query).index.values) - except Exception as e: - if isinstance(e, PandasUndefinedVariableError): - raise AugurError(f"Query contains a column that does not exist in metadata: {e}") from e - raise AugurError(f"Internal Pandas error when applying query:\n\t{e}\nEnsure the syntax is valid per .") from e + return _filter_by_exclude_strains(excluded_strains) def _string_to_boolean(s: str): @@ -213,22 +256,58 @@ def _string_to_boolean(s: str): def filter_by_ambiguous_date(date_column, ambiguity) -> FilterFunctionReturn: """Filter where values in the given date column have a given level of ambiguity. + Determine ambiguity hierarchically such that, for example, an ambiguous + month implicates an ambiguous day even when day information is available. + Parameters ---------- date_column : str - Column in the dataframe with dates. + The date column is already parsed beforehand. However, this is used to + verify that the column exists and it is still beneficial to report in + the output log as a kwarg. ambiguity : str Level of date ambiguity to filter by """ - if date_column in metadata.columns: - date_is_ambiguous = metadata[date_column].apply( - lambda date: is_date_ambiguous(date, ambiguity) - ) - filtered = set(metadata[~date_is_ambiguous].index.values) - else: - filtered = set(metadata.index.values) - return filtered + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_has_date_column = date_column in db.columns(constants.METADATA_TABLE) + + if not metadata_has_date_column: + expression = 'False' + elif ambiguity == 'year': + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {constants.ID_COLUMN} + FROM {constants.DATE_TABLE} + WHERE {constants.DATE_YEAR_COLUMN} IS NULL + ) + """ + elif ambiguity == 'month': + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {constants.ID_COLUMN} + FROM {constants.DATE_TABLE} + WHERE ( + {constants.DATE_MONTH_COLUMN} IS NULL OR + {constants.DATE_YEAR_COLUMN} IS NULL + ) + ) + """ + else: + assert ambiguity == 'day' or ambiguity == 'any' + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {constants.ID_COLUMN} + FROM {constants.DATE_TABLE} + WHERE ( + {constants.DATE_DAY_COLUMN} IS NULL OR + {constants.DATE_MONTH_COLUMN} IS NULL OR + {constants.DATE_YEAR_COLUMN} IS NULL + ) + ) + """ + parameters: SqlParameters = {} + return expression, parameters def skip_group_by_with_ambiguous_year(date_column) -> FilterFunctionReturn: @@ -251,25 +330,32 @@ def filter_by_min_date(date_column, min_date) -> FilterFunctionReturn: Parameters ---------- - date_column : str - Column in the dataframe with dates. + date_column + The date column is already parsed beforehand. However, this is used to + verify that the column exists and it is still beneficial to report in + the output log as a kwarg. min_date : float Minimum date """ - strains = set(metadata.index.values) - - # Skip this filter if the date column does not exist. - if date_column not in metadata.columns: - return strains - dates = get_numerical_dates(metadata, date_col=date_column, fmt="%Y-%m-%d") + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_has_date_column = date_column in db.columns(constants.METADATA_TABLE) - filtered = {s for s in strains - if (dates[s] is not None - and (np.isscalar(dates[s]) or all(dates[s])) - and np.max(dates[s]) >= min_date)} + # Skip this filter if the date column does not exist. 
+ if not metadata_has_date_column: + expression = 'False' + parameters: SqlParameters = {} + else: + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {constants.ID_COLUMN} + FROM {constants.DATE_TABLE} + WHERE {constants.NUMERIC_DATE_MAX_COLUMN} < :min_date OR {constants.NUMERIC_DATE_MIN_COLUMN} IS NULL + ) + """ + parameters = {'min_date': min_date} - return filtered + return expression, parameters def filter_by_max_date(date_column, max_date) -> FilterFunctionReturn: @@ -278,24 +364,31 @@ def filter_by_max_date(date_column, max_date) -> FilterFunctionReturn: Parameters ---------- date_column : str - Column in the dataframe with dates. + The date column is already parsed beforehand. However, this is used to + verify that the column exists and it is still beneficial to report in + the output log as a kwarg. max_date : float Maximum date """ - strains = set(metadata.index.values) - # Skip this filter if the date column does not exist. - if date_column not in metadata.columns: - return strains + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_has_date_column = date_column in db.columns(constants.METADATA_TABLE) - dates = get_numerical_dates(metadata, date_col=date_column, fmt="%Y-%m-%d") - - filtered = {s for s in strains - if (dates[s] is not None - and (np.isscalar(dates[s]) or all(dates[s])) - and np.min(dates[s]) <= max_date)} + # Skip this filter if the date column does not exist. + if not metadata_has_date_column: + expression = 'False' + parameters: SqlParameters = {} + else: + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {constants.ID_COLUMN} + FROM {constants.DATE_TABLE} + WHERE {constants.NUMERIC_DATE_MIN_COLUMN} > :max_date OR {constants.NUMERIC_DATE_MAX_COLUMN} IS NULL + ) + """ + parameters = {'max_date': max_date} - return filtered + return expression, parameters def filter_by_sequence_index() -> FilterFunctionReturn: @@ -303,10 +396,14 @@ def filter_by_sequence_index() -> FilterFunctionReturn: index. This filter effectively intersects the strain ids in the metadata and sequence index. 
""" - metadata_strains = set(metadata.index.values) - sequence_index_strains = set(sequence_index.index.values) - - return metadata_strains & sequence_index_strains + expression = f""" + {constants.ID_COLUMN} NOT IN ( + SELECT {SEQUENCE_INDEX_ID_COLUMN} + FROM {constants.SEQUENCE_INDEX_TABLE} + ) + """ + parameters: SqlParameters = {} + return expression, parameters # FIXME: remove metadata in previous commit @@ -318,13 +415,15 @@ def filter_by_min_length(min_length) -> FilterFunctionReturn: min_length : int Minimum number of standard nucleotide characters (A, C, G, or T) in each sequence """ - strains = set(metadata.index.values) - filtered_sequence_index = sequence_index.loc[ - sequence_index.index.intersection(strains) - ] - filtered_sequence_index["ACGT"] = filtered_sequence_index.loc[:, ["A", "C", "G", "T"]].sum(axis=1) - - return set(filtered_sequence_index[filtered_sequence_index["ACGT"] >= min_length].index.values) + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {SEQUENCE_INDEX_ID_COLUMN} + FROM {constants.SEQUENCE_INDEX_TABLE} + WHERE A + C + G + T < :min_length + ) + """ + parameters = {'min_length': min_length} + return expression, parameters def filter_by_max_length(max_length) -> FilterFunctionReturn: @@ -335,25 +434,29 @@ def filter_by_max_length(max_length) -> FilterFunctionReturn: max_length : int Maximum number of standard nucleotide characters (A, C, G, or T) in each sequence """ - strains = set(metadata.index.values) - filtered_sequence_index = sequence_index.loc[ - sequence_index.index.intersection(strains) - ] - filtered_sequence_index["ACGT"] = filtered_sequence_index.loc[:, ["A", "C", "G", "T"]].sum(axis=1) - - return set(filtered_sequence_index[filtered_sequence_index["ACGT"] <= max_length].index.values) + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {SEQUENCE_INDEX_ID_COLUMN} + FROM {constants.SEQUENCE_INDEX_TABLE} + WHERE A + C + G + T < :max_length + ) + """ + parameters = {'max_length': max_length} + return expression, parameters def filter_by_non_nucleotide() -> FilterFunctionReturn: """Filter for strains with invalid nucleotide content. 
""" - strains = set(metadata.index.values) - filtered_sequence_index = sequence_index.loc[ - sequence_index.index.intersection(strains) - ] - no_invalid_nucleotides = filtered_sequence_index["invalid_nucleotides"] == 0 - - return set(filtered_sequence_index[no_invalid_nucleotides].index.values) + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {SEQUENCE_INDEX_ID_COLUMN} + FROM {constants.SEQUENCE_INDEX_TABLE} + WHERE invalid_nucleotides != 0 + ) + """ + parameters: SqlParameters = {} + return expression, parameters def force_include_strains(include_file) -> FilterFunctionReturn: @@ -364,8 +467,13 @@ def force_include_strains(include_file) -> FilterFunctionReturn: include_file : str Filename with strain names to include """ - included_strains = read_strains(include_file) - return set(metadata.index.values) & included_strains + strains = read_strains(include_file) + quoted_strains = (f"'{strain}'" for strain in strains) + expression = f""" + {constants.ID_COLUMN} IN ({','.join(quoted_strains)}) + """ + parameters: SqlParameters = {} + return expression, parameters def force_include_where(include_where) -> FilterFunctionReturn: @@ -381,22 +489,28 @@ def force_include_where(include_where) -> FilterFunctionReturn: include_where : str Filter query used to include strains """ + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_columns = db.columns(constants.METADATA_TABLE) + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + column, op, value = parse_filter_query(include_where) - if column in metadata.columns: - # Apply a test operator (equality or inequality) to values from the - # column in the given query. This produces an array of boolean values we - # can index with. - included_index = op( - metadata[column].astype(str).str.lower(), - value.lower() - ) - included = set(metadata[included_index].index.values) + if column in metadata_columns: + + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {sanitize_identifier(metadata_id_column)} + FROM {constants.METADATA_TABLE} + WHERE {constants.METADATA_TABLE}.{sanitize_identifier(column)} {op} :value + ) + """ + parameters = {'value': value} else: # Skip the inclusion filter if the requested column does not exist. - included = set() + expression = 'False' + parameters = {} - return included + return expression, parameters def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: @@ -437,8 +551,9 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: exclude_by.append((filter_by_exclude_all, {})) # Filter by sequence index. - if constants.sequence_index is not None: - exclude_by.append((filter_by_sequence_index, {})) + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + if constants.SEQUENCE_INDEX_TABLE in db.tables(): + exclude_by.append((filter_by_sequence_index, {})) # Remove strains explicitly excluded by name. if args.exclude: @@ -460,7 +575,10 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: # Exclude strains by metadata, using pandas querying. 
if args.query: - kwargs = {"query": args.query} + kwargs = { + "query": args.query, + "chunksize": args.metadata_chunk_size, + } if args.query_columns: kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns} @@ -549,83 +667,100 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: return exclude_by, include_by -def apply_filters(metadata, exclude_by: List[FilterOption], include_by: List[FilterOption]): - """Apply a list of filters to exclude or force-include records from the given - metadata and return the strains to keep, to exclude, and to force include. - - Parameters - ---------- - metadata : pandas.DataFrame - Metadata to filter - - Returns - ------- - set : - Strains to keep (those that passed all filters) - list of dict : - Strains to exclude along with the function that filtered them and the arguments used to run the function. - list of dict : - Strains to force-include along with the function that filtered them and the arguments used to run the function. - """ - strains_to_keep = set(metadata.index.values) - strains_to_filter = [] - strains_to_force_include = [] - distinct_strains_to_force_include: Set = set() - - # Track strains that should be included regardless of filters. - for include_function, include_kwargs in include_by: - passed = metadata.pipe( - include_function, - **include_kwargs, - ) - distinct_strains_to_force_include = distinct_strains_to_force_include | passed - - # Track the reason why strains were included. - if len(passed) > 0: - include_name = include_function.__name__ - include_kwargs_str = _filter_kwargs_to_str(include_kwargs) - for strain in passed: - strains_to_force_include.append({ - "strain": strain, - "filter": include_name, - "kwargs": include_kwargs_str, - }) - - for filter_function, filter_kwargs in exclude_by: - # Apply the current function with its given arguments. Each function - # returns a set of strains that passed the corresponding filter. - passed = metadata.pipe( - filter_function, - **filter_kwargs, - ) - - # Track the strains that failed this filter, so we can explain why later - # on and update the list of strains to keep to intersect with the - # strains that passed. - failed = strains_to_keep - passed - strains_to_keep = (strains_to_keep & passed) - - # Track the reason each strain was filtered for downstream reporting. - if len(failed) > 0: - # Use a human-readable name for each filter when reporting why a strain - # was excluded. - filter_name = filter_function.__name__ - filter_kwargs_str = _filter_kwargs_to_str(filter_kwargs) - for strain in failed: - strains_to_filter.append({ - "strain": strain, - "filter": filter_name, - "kwargs": filter_kwargs_str, - }) - - # Stop applying filters if no strains remain. 
- if len(strains_to_keep) == 0: - break - - return strains_to_keep, strains_to_filter, strains_to_force_include - - -def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs): +def apply_filters(exclude_by: List[FilterOption], include_by: List[FilterOption]): + """Apply exclusion and force-inclusion rules to filter strains from the metadata.""" + init_filter_reason_table() + apply_exclusions(exclude_by) + apply_force_inclusions(include_by) + + +def init_filter_reason_table(): + """Initialize the filter reason table with all strains as not being excluded nor force-included.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.connection.execute(f"""CREATE TABLE {constants.FILTER_REASON_TABLE} AS + SELECT + {sanitize_identifier(metadata_id_column)} AS {constants.ID_COLUMN}, + FALSE AS {constants.EXCLUDE_COLUMN}, + FALSE AS {constants.INCLUDE_COLUMN}, + NULL AS {constants.FILTER_REASON_COLUMN}, + NULL AS {constants.FILTER_REASON_KWARGS_COLUMN} + FROM {constants.METADATA_TABLE} + """) + db.create_primary_index(constants.FILTER_REASON_TABLE, constants.ID_COLUMN) + + +def apply_exclusions(exclude_by: List[FilterOption]): + """Update the filter reason table given the outcome of exclusion filters.""" + + # Reversed so that earlier entries in the original list will be have higher precedence. + # This is because later evaluations will overwrite any previously applied filter reasons. + for exclude_function, kwargs in reversed(exclude_by): + where_expression, where_parameters = None, None + + # Note: Consider using JOIN instead of subqueries if performance issues arise¹. + # ¹ https://stackoverflow.com/q/3856164 + where_expression, where_parameters = exclude_function(**kwargs) + + assert where_expression is not None + assert where_parameters is not None + + sql = f""" + UPDATE {constants.FILTER_REASON_TABLE} + SET + {constants.EXCLUDE_COLUMN} = TRUE, + {constants.FILTER_REASON_COLUMN} = :filter_reason, + {constants.FILTER_REASON_KWARGS_COLUMN} = :filter_reason_kwargs + WHERE {where_expression} + """ + + sql_parameters = { + 'filter_reason': exclude_function.__name__, + 'filter_reason_kwargs': filter_kwargs_to_str(kwargs) + } + + # Add parameters returned from the filter function. + sql_parameters = {**sql_parameters, **where_parameters} + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + try: + db.connection.execute(sql, sql_parameters) + except Exception as e: + if exclude_function is filter_by_sqlite_query: + if isinstance(e, sqlite3.OperationalError): + if "no such column" in str(e): + raise AugurError(f"Query contains a column that does not exist in metadata.") from e + raise AugurError(f"Error when applying query. Ensure the syntax is valid per .") from e + + +def apply_force_inclusions(include_by: List[FilterOption]): + """Update the filter reason table with force-inclusion rules.""" + for include_function, kwargs in include_by: + where_expression, where_parameters = include_function(**kwargs) + sql = f""" + UPDATE {constants.FILTER_REASON_TABLE} + SET + {constants.INCLUDE_COLUMN} = TRUE, + {constants.FILTER_REASON_COLUMN} = :filter_reason, + {constants.FILTER_REASON_KWARGS_COLUMN} = :filter_reason_kwargs + WHERE {where_expression} + """ + + sql_parameters = { + 'filter_reason': include_function.__name__, + 'filter_reason_kwargs': filter_kwargs_to_str(kwargs) + } + + # Add parameters returned from the filter function. 
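A self-contained sketch of why the loop above walks the rules in reverse: each UPDATE overwrites the stored reason, so the last write wins and the earliest rule in the original list is the one reported. Table and column names are simplified stand-ins for the filter reason table.

import sqlite3

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE filter_reason (strain TEXT, excluded INTEGER, reason TEXT)")
db.execute("INSERT INTO filter_reason VALUES ('s1', FALSE, NULL)")

# Two rules that both match strain s1; the first one listed should "win".
rules = [
    ("filter_by_exclude", "strain = 's1'"),
    ("filter_by_min_date", "strain = 's1'"),
]
for name, where_expression in reversed(rules):
    db.execute(
        f"UPDATE filter_reason SET excluded = TRUE, reason = :reason WHERE {where_expression}",
        {"reason": name},
    )

print(db.execute("SELECT reason FROM filter_reason").fetchone()[0])  # filter_by_exclude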
+ sql_parameters = {**sql_parameters, **where_parameters} + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.connection.execute(sql, sql_parameters) + + +def filter_kwargs_to_str(kwargs: FilterFunctionKwargs): """Convert a dictionary of kwargs to a JSON string for downstream reporting. This structured string can be converted back into a Python data structure @@ -650,10 +785,10 @@ def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs): >>> from augur.dates import numeric_date >>> from augur.filter.include_exclude_rules import filter_by_min_length, filter_by_min_date >>> exclude_by = [(filter_by_min_length, {"min_length": 27000})] - >>> _filter_kwargs_to_str(exclude_by[0][1]) + >>> filter_kwargs_to_str(exclude_by[0][1]) '[["min_length", 27000]]' >>> exclude_by = [(filter_by_min_date, {"date_column": "date", "min_date": numeric_date("2020-03-01")})] - >>> _filter_kwargs_to_str(exclude_by[0][1]) + >>> filter_kwargs_to_str(exclude_by[0][1]) '[["date_column", "date"], ["min_date", 2020.17]]' """ @@ -670,6 +805,10 @@ def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs): if isinstance(value, float): value = round(value, 2) + # Don't include chunksize since it does not affect end results. + if key == 'chunksize': + continue + kwarg_list.append((key, value)) return json.dumps(kwarg_list) diff --git a/augur/filter/io.py b/augur/filter/io.py index 5fe145743..a9acb97f6 100644 --- a/augur/filter/io.py +++ b/augur/filter/io.py @@ -4,11 +4,8 @@ import os import re from textwrap import dedent -from typing import Sequence, Set -import numpy as np -import pandas as pd +from typing import Sequence from tempfile import NamedTemporaryFile -from collections import defaultdict from xopen import xopen from augur.errors import AugurError @@ -22,6 +19,8 @@ from augur.io.metadata import Metadata, METADATA_DATE_COLUMN from augur.io.print import print_err from augur.io.sequences import read_sequences, write_sequences +from augur.io.sqlite3 import DuplicateError, Sqlite3Database, sanitize_identifier +from augur.io.tabular_file import InvalidDelimiter, TabularFile from augur.io.vcf import is_vcf, write_vcf from . import constants from .include_exclude_rules import extract_variables, parse_filter_query @@ -92,27 +91,46 @@ def get_useful_metadata_columns(args: Namespace, id_column: str, all_columns: Se return list(columns) -def read_priority_scores(fname): - def constant_factory(value): - return lambda: value - +def import_priorities_table(path): + """Import a priorities file into the database.""" try: - with open_file(fname) as pfile: - return defaultdict(constant_factory(-np.inf), { - elems[0]: float(elems[1]) - for elems in (line.strip().split('\t') if '\t' in line else line.strip().split() for line in pfile.readlines()) - }) - except Exception: - raise AugurError(f"missing or malformed priority scores file {fname}") - + priorities = TabularFile(path, delimiters=['\t'], header=False, + columns=[constants.ID_COLUMN, constants.PRIORITY_COLUMN]) + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + _import_tabular_file(priorities, db, constants.PRIORITIES_TABLE) + except (FileNotFoundError, InvalidDelimiter): + raise AugurError(f"missing or malformed priority scores file {path}") -def write_metadata_based_outputs(input_metadata_path: str, delimiters: Sequence[str], + try: + _validate_priorities_table() + except ValueError: + # TODO: Surface the underlying error message. 
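For reference, the priorities input handled here is expected to be a two-column, tab-delimited file of strain name and numeric score, and the validation step referenced above amounts to checking that every score parses as a float. A rough stand-alone equivalent (the file name is hypothetical):

import csv

def validate_priorities(path):
    """Raise if any priority score in a strain<TAB>score file is not numeric."""
    with open(path, newline="") as handle:
        for strain, priority in csv.reader(handle, delimiter="\t"):
            try:
                float(priority)
            except ValueError:
                raise ValueError(
                    f"Priority score for strain {strain!r} ({priority!r}) is not a valid number."
                )

# validate_priorities("priorities.tsv")  # hypothetical file name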
+ raise AugurError(f"missing or malformed priority scores file {path}") + + +def _validate_priorities_table(): + """Query the priorities table and error upon any invalid scores.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT {constants.ID_COLUMN}, {constants.PRIORITY_COLUMN} + FROM {constants.PRIORITIES_TABLE} + """) + for row in result: + try: + float(row[constants.PRIORITY_COLUMN]) + except ValueError: + raise ValueError(f"Priority score for strain '{row[constants.ID_COLUMN]}' ('{row[constants.PRIORITY_COLUMN]}') is not a valid number.") + + +def _write_metadata_based_outputs(input_metadata_path: str, delimiters: Sequence[str], id_columns: Sequence[str], output_metadata_path: str, - output_strains_path: str, ids_to_write: Set[str]): + output_strains_path: str): """ Write output metadata and/or strains file given input metadata information and a set of IDs to write. """ + ids_to_write = _get_valid_strains() + input_metadata = Metadata(input_metadata_path, id_columns, delimiters=delimiters) # Handle all outputs with one pass of metadata. This requires using @@ -168,6 +186,86 @@ def column_type_pair(input: str): return (column, dtype) +def initialize_input_source_table(): + """Create the input source table without any rows.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.connection.execute(f""" + CREATE TABLE {constants.INPUT_SOURCE_TABLE} ( + {constants.ID_COLUMN} TEXT, + {constants.STRAIN_IN_METADATA_COLUMN} INTEGER, + {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN} INTEGER, + {constants.STRAIN_IN_SEQUENCES_COLUMN} INTEGER + ) + """) + db.create_primary_index(constants.INPUT_SOURCE_TABLE, constants.ID_COLUMN) + + +def _import_tabular_file(file: TabularFile, db: Sqlite3Database, table: str, columns=None): + """Import a tabular file into a new table in an existing database. + + Parameters + ---------- + file + File to import from. + db + Database to import into. + table + Table name to import into. + columns + Columns to import. + """ + if columns is None: + columns = file.columns + else: + for column in list(columns): + if column not in file.columns: + # Ignore missing columns. Don't error since augur filter's + # --exclude-where allows invalid columns to be specified (they + # are just ignored). + print_err(f"WARNING: Column '{column}' does not exist in the metadata file. This may cause subsequent errors.") + columns.remove(column) + db.create_table(table, columns) + db.insert(table, columns, file.rows()) + + +def import_metadata(metadata: Metadata, columns): + """Import metadata into the database.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + _import_tabular_file(metadata, db, constants.METADATA_TABLE, columns) + + try: + db.create_primary_index(constants.METADATA_TABLE, metadata.id_column) + except DuplicateError as error: + duplicates = error.duplicated_values + raise AugurError(f"The following strains are duplicated in '{metadata.path}':\n" + "\n".join(sorted(duplicates))) + + # If the strain is already in the input source table, update it. + db.connection.execute(f""" + UPDATE {constants.INPUT_SOURCE_TABLE} + SET {constants.STRAIN_IN_METADATA_COLUMN} = TRUE + WHERE {constants.ID_COLUMN} IN ( + SELECT {sanitize_identifier(metadata.id_column)} + FROM {constants.METADATA_TABLE} + ) + """) + + # Otherwise, add an entry. 
+ db.connection.execute(f""" + INSERT OR IGNORE INTO {constants.INPUT_SOURCE_TABLE} ( + {constants.ID_COLUMN}, + {constants.STRAIN_IN_METADATA_COLUMN}, + {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN}, + {constants.STRAIN_IN_SEQUENCES_COLUMN} + ) + SELECT + {sanitize_identifier(metadata.id_column)} AS {constants.ID_COLUMN}, + TRUE AS {constants.STRAIN_IN_METADATA_COLUMN}, + FALSE AS {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN}, + FALSE AS {constants.STRAIN_IN_SEQUENCES_COLUMN} + FROM {constants.METADATA_TABLE} + """) + + def import_sequence_index(args): # Determine whether the sequence index exists or whether should be # generated. We need to generate an index if the input sequences are in a @@ -189,20 +287,48 @@ def import_sequence_index(args): # Load the sequence index, if a path exists. if sequence_index_path: - constants.sequence_index = pd.read_csv( - sequence_index_path, - sep=SEQUENCE_INDEX_DELIMITER, - index_col=SEQUENCE_INDEX_ID_COLUMN, - dtype={SEQUENCE_INDEX_ID_COLUMN: "string"}, - **PANDAS_READ_CSV_OPTIONS, - ) + try: + sequence_index = TabularFile(sequence_index_path, header=True, delimiters=[SEQUENCE_INDEX_DELIMITER]) + except InvalidDelimiter: + # This can happen for single-column files (e.g. VCF sequence indexes). + # If so, use a tab character as an arbitrary delimiter. + sequence_index = TabularFile(sequence_index_path, header=True, delimiter='\t') + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + # Import the sequence index. + _import_tabular_file(sequence_index, db, constants.SEQUENCE_INDEX_TABLE) + # FIXME: set type affinity of SEQUENCE_INDEX_ID_COLUMN to TEXT + db.create_primary_index(constants.SEQUENCE_INDEX_TABLE, SEQUENCE_INDEX_ID_COLUMN) + + # If the strain is already in the input source table, update it. + db.connection.execute(f""" + UPDATE {constants.INPUT_SOURCE_TABLE} + SET {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN} = TRUE + WHERE {constants.ID_COLUMN} IN ( + SELECT {SEQUENCE_INDEX_ID_COLUMN} + FROM {constants.SEQUENCE_INDEX_TABLE} + ) + """) + + # Otherwise, add an entry. + db.connection.execute(f""" + INSERT OR IGNORE INTO {constants.INPUT_SOURCE_TABLE} ( + {constants.ID_COLUMN}, + {constants.STRAIN_IN_METADATA_COLUMN}, + {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN}, + {constants.STRAIN_IN_SEQUENCES_COLUMN} + ) + SELECT + {SEQUENCE_INDEX_ID_COLUMN} AS {constants.ID_COLUMN}, + FALSE AS {constants.STRAIN_IN_METADATA_COLUMN}, + TRUE AS {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN}, + FALSE AS {constants.STRAIN_IN_SEQUENCES_COLUMN} + FROM {constants.SEQUENCE_INDEX_TABLE} + """) # Remove temporary index file, if it exists. if build_sequence_index: os.unlink(sequence_index_path) - constants.sequence_strains = set(constants.sequence_index.index.values) - def _generate_sequence_index(sequences_file): """Generate a sequence index file. @@ -228,11 +354,21 @@ def _generate_sequence_index(sequences_file): return sequence_index_path -def read_and_output_sequences(args): +def write_outputs(args): + """Write the output files that were requested.""" + + _read_and_output_sequences(args) + + _write_metadata_based_outputs(args.metadata, args.metadata_delimiters, args.metadata_id_columns, args.output_metadata, args.output_strains) + + if args.output_log: + _output_log(args.output_log) + + +def _read_and_output_sequences(args): """Read sequences and output all that passed filtering. """ - # Force inclusion of specific strains after filtering and subsampling. 
- constants.valid_strains = constants.valid_strains | constants.all_sequences_to_include + valid_strains = _get_valid_strains() # Write output starting with sequences, if they've been requested. It is # possible for the input sequences and sequence index to be out of sync @@ -242,7 +378,7 @@ def read_and_output_sequences(args): if is_vcf(args.sequences): if args.output: # Get the samples to be deleted, not to keep, for VCF - dropped_samps = list(constants.sequence_strains - constants.valid_strains) + dropped_samps = _get_strains_to_drop_from_vcf() write_vcf(args.sequences, args.output, dropped_samps) elif args.sequences: sequences = read_sequences(args.sequences) @@ -258,39 +394,106 @@ def read_and_output_sequences(args): for sequence in sequences: observed_sequence_strains.add(sequence.id) - if sequence.id in constants.valid_strains: + if sequence.id in valid_strains: write_sequences(sequence, output_handle, 'fasta') else: observed_sequence_strains = {sequence.id for sequence in sequences} - if constants.sequence_strains != observed_sequence_strains: - # Warn the user if the expected strains from the sequence index are - # not a superset of the observed strains. - if constants.sequence_strains is not None and observed_sequence_strains > constants.sequence_strains: - print_err( - "WARNING: The sequence index is out of sync with the provided sequences.", - "Metadata and strain output may not match sequence output." + # Update the input source table. + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + # If the strain is already in the input source table, update it. + quoted_strains = (f"'{strain}'" for strain in observed_sequence_strains) + db.connection.execute(f""" + UPDATE {constants.INPUT_SOURCE_TABLE} + SET {constants.STRAIN_IN_SEQUENCES_COLUMN} = TRUE + WHERE {constants.ID_COLUMN} IN ({','.join(quoted_strains)}) + """) + + # Otherwise, add an entry. + rows = ({'strain': strain} for strain in observed_sequence_strains) + db.connection.executemany(f""" + INSERT OR IGNORE INTO {constants.INPUT_SOURCE_TABLE} ( + {constants.ID_COLUMN}, + {constants.STRAIN_IN_METADATA_COLUMN}, + {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN}, + {constants.STRAIN_IN_SEQUENCES_COLUMN} ) - - # Update the set of available sequence strains. - constants.sequence_strains = observed_sequence_strains - - -def cleanup_outputs(args): - """Remove output files. Useful when terminating midway through a loop of metadata chunks.""" - if args.output: - _try_remove(args.output) - if args.output_metadata: - _try_remove(args.output_metadata) - if args.output_strains: - _try_remove(args.output_strains) - if args.output_log: - _try_remove(args.output_log) + VALUES ( + :strain, + FALSE, + FALSE, + TRUE + ) + """, rows) + + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + # Only run this if the sequence index table exists. + if constants.SEQUENCE_INDEX_TABLE in db.tables(): + result = db.connection.execute(f""" + SELECT COUNT(*) + FROM {constants.INPUT_SOURCE_TABLE} + WHERE {constants.STRAIN_IN_SEQUENCES_COLUMN} AND NOT {constants.STRAIN_IN_SEQUENCE_INDEX_COLUMN} + """) + sequences_missing_from_index = result.fetchone()[0] + + if sequences_missing_from_index > 0: + # Warn the user if the expected strains from the sequence index are + # not a superset of the observed strains. + print_err( + "WARNING: The sequence index is out of sync with the provided sequences.", + "Metadata and strain output may not match sequence output." 
+ ) + + +def _output_log(path): + """Write a file explaining the reason for excluded or force-included strains. + + This file has the following columns: + 1. Strain column + 2. Name of the filter function responsible for inclusion/exclusion + 3. Arguments given to the filter function + """ + query = f""" + SELECT + {constants.ID_COLUMN}, + {constants.FILTER_REASON_COLUMN}, + {constants.FILTER_REASON_KWARGS_COLUMN} + FROM {constants.FILTER_REASON_TABLE} + WHERE {constants.FILTER_REASON_COLUMN} IS NOT NULL + """ + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + db.query_to_file( + query=query, + path=path, + header=True, + ) -def _try_remove(filepath): - """Remove a file if it exists.""" - try: - os.remove(filepath) - except FileNotFoundError: - pass +def _get_valid_strains(): + """Returns the strains that pass all filter rules. + """ + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT {constants.ID_COLUMN} + FROM {constants.FILTER_REASON_TABLE} + WHERE NOT {constants.EXCLUDE_COLUMN} OR {constants.INCLUDE_COLUMN} + """) + return {str(row[constants.ID_COLUMN]) for row in result} + + +def _get_strains_to_drop_from_vcf(): + """Return a set of all strain names that are in the sequence index and did + not pass filtering and subsampling.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + # Query = all sequence index strains - (all strains that passed). + # This includes strains that are not present in the metadata. + result = db.connection.execute(f""" + SELECT {SEQUENCE_INDEX_ID_COLUMN} + FROM {constants.SEQUENCE_INDEX_TABLE} + WHERE {SEQUENCE_INDEX_ID_COLUMN} NOT IN ( + SELECT {constants.ID_COLUMN} + FROM {constants.FILTER_REASON_TABLE} + WHERE NOT {constants.EXCLUDE_COLUMN} OR {constants.INCLUDE_COLUMN} + ) + """) + return {str(row[SEQUENCE_INDEX_ID_COLUMN]) for row in result} diff --git a/augur/filter/report.py b/augur/filter/report.py index 125f9495b..04a3c61de 100644 --- a/augur/filter/report.py +++ b/augur/filter/report.py @@ -1,6 +1,7 @@ import json from augur.errors import AugurError from augur.io.print import print_err +from augur.io.sqlite3 import Sqlite3Database from augur.types import EmptyOutputReportingMethod from . import constants, include_exclude_rules @@ -9,13 +10,11 @@ def print_report(args): """Print a report of how many strains were dropped and reasoning.""" # Calculate the number of strains that don't exist in either metadata or # sequences. - num_excluded_by_lack_of_metadata = 0 - if constants.sequence_strains: - num_excluded_by_lack_of_metadata = len(constants.sequence_strains - constants.metadata_strains) + num_excluded_by_lack_of_metadata = _get_num_excluded_by_lack_of_metadata() # Calculate the number of strains passed and filtered. 
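The WHERE clause shared by the helpers above encodes a simple rule: a strain passes when it was never excluded, or when a force-include rule matched it, so force-inclusion overrides exclusion. A tiny truth-table check of that predicate:

for excluded in (False, True):
    for force_included in (False, True):
        passes = (not excluded) or force_included
        print(f"excluded={excluded!s:<5} force_included={force_included!s:<5} passes={passes}")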
- total_strains_passed = len(constants.valid_strains) - total_strains_filtered = len(constants.metadata_strains) + num_excluded_by_lack_of_metadata - total_strains_passed + total_strains_passed = _get_total_strains_passed() + total_strains_filtered = _get_num_metadata_strains() + num_excluded_by_lack_of_metadata - total_strains_passed print_err(f"{total_strains_filtered} {'strain was' if total_strains_filtered == 1 else 'strains were'} dropped during filtering") @@ -37,10 +36,11 @@ def print_report(args): include_exclude_rules.skip_group_by_with_ambiguous_year.__name__: "{count} {were} dropped during grouping due to ambiguous year information", include_exclude_rules.skip_group_by_with_ambiguous_month.__name__: "{count} {were} dropped during grouping due to ambiguous month information", include_exclude_rules.skip_group_by_with_ambiguous_day.__name__: "{count} {were} dropped during grouping due to ambiguous day information", - include_exclude_rules.force_include_strains.__name__: "{count} {were} added back because {they} {were} in {include_file}", - include_exclude_rules.force_include_where.__name__: "{count} {were} added back because of '{include_where}'", + include_exclude_rules.force_include_strains.__name__: "{count} {were} force-included because {they} {were} in {include_file}", + include_exclude_rules.force_include_where.__name__: "{count} {were} force-included because of '{include_where}'", } - for (filter_name, filter_kwargs), count in constants.filter_counts.items(): + + for filter_name, filter_kwargs, count in _get_filter_counts(): if filter_kwargs: parameters = dict(json.loads(filter_kwargs)) else: @@ -51,9 +51,11 @@ def print_report(args): parameters["they"] = "it" if count == 1 else "they" print_err("\t" + report_template_by_filter_name[filter_name].format(**parameters)) + # TODO: Add subsampling in the report template dict now that it's stored in the same table as other filters. + num_excluded_subsamp = _get_num_excluded_by_subsampling() if (args.group_by and args.sequences_per_group) or args.subsample_max_sequences: seed_txt = ", using seed {}".format(args.subsample_seed) if args.subsample_seed else "" - print_err(f"\t{constants.num_excluded_subsamp} {'was' if constants.num_excluded_subsamp == 1 else 'were'} dropped because of subsampling criteria{seed_txt}") + print_err(f"\t{num_excluded_subsamp} {'was' if num_excluded_subsamp == 1 else 'were'} dropped because of subsampling criteria{seed_txt}") if total_strains_passed == 0: empty_results_message = "All samples have been dropped! Check filter rules and metadata file format." 
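To make the reporting loop above concrete, here is how one line would be rendered: the stored kwargs JSON is turned back into a dict and merged with the count-dependent values before formatting. The template text below is illustrative rather than the exact wording used by augur.

import json

# Illustrative template; the real wording lives in report_template_by_filter_name.
template = "{count} {were} dropped because {they} {were} shorter than {min_length}"
stored_kwargs = '[["min_length", 27000]]'

parameters = dict(json.loads(stored_kwargs))
count = 1
parameters["count"] = count
parameters["were"] = "was" if count == 1 else "were"
parameters["they"] = "it" if count == 1 else "they"
print("\t" + template.format(**parameters))
# 	1 was dropped because it was shorter than 27000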
@@ -67,3 +69,71 @@ def print_report(args): raise ValueError(f"Encountered unhandled --empty-output-reporting method {args.empty_output_reporting!r}") print_err(f"{total_strains_passed} {'strain' if total_strains_passed == 1 else 'strains'} passed all filters") + + +def _get_num_excluded_by_lack_of_metadata(): + """Get number of strains present in other inputs but missing in metadata.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT COUNT(*) AS count + FROM {constants.INPUT_SOURCE_TABLE} + WHERE NOT {constants.STRAIN_IN_METADATA_COLUMN} + """) + return int(result.fetchone()["count"]) + + +def _get_num_metadata_strains(): + """Returns the number of strains in the original metadata.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT COUNT(*) AS count + FROM {constants.METADATA_TABLE} + """) + return int(result.fetchone()["count"]) + + +def _get_num_excluded_by_subsampling(): + """Returns the number of strains excluded by subsampling.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT COUNT(*) AS count + FROM {constants.FILTER_REASON_TABLE} + WHERE {constants.FILTER_REASON_COLUMN} = '{constants.SUBSAMPLE_FILTER_REASON}' + """) + return int(result.fetchone()["count"]) + + +# TODO: use _get_valid_strains +def _get_total_strains_passed(): + """Returns the number of strains that pass all filter rules.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT COUNT(*) AS count + FROM {constants.FILTER_REASON_TABLE} + WHERE NOT {constants.EXCLUDE_COLUMN} OR {constants.INCLUDE_COLUMN} + """) + return int(result.fetchone()["count"]) + + +def _get_filter_counts(): + """ + Returns a tuple for each filter with function name, kwargs, and number of strains included/excluded by it. + """ + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT + {constants.FILTER_REASON_COLUMN}, + {constants.FILTER_REASON_KWARGS_COLUMN}, + COUNT(*) AS count + FROM {constants.FILTER_REASON_TABLE} + WHERE {constants.FILTER_REASON_COLUMN} IS NOT NULL + AND {constants.FILTER_REASON_COLUMN} != '{constants.SUBSAMPLE_FILTER_REASON}' + GROUP BY {constants.FILTER_REASON_COLUMN}, {constants.FILTER_REASON_KWARGS_COLUMN} + ORDER BY {constants.FILTER_REASON_COLUMN}, {constants.FILTER_REASON_KWARGS_COLUMN} + """) + for row in result: + yield ( + str(row[constants.FILTER_REASON_COLUMN]), + str(row[constants.FILTER_REASON_KWARGS_COLUMN]), + int(row['count']), + ) diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py index 277398915..a4f493867 100644 --- a/augur/filter/subsample.py +++ b/augur/filter/subsample.py @@ -1,103 +1,46 @@ from collections import defaultdict -import heapq import itertools -import uuid import numpy as np import pandas as pd from textwrap import dedent -from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Collection, Dict, Iterable, List, Optional, Sequence, Set, Tuple -from augur.dates import get_year_month, get_year_week from augur.errors import AugurError from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.print import print_err +from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier from . import constants from .weights_file import WEIGHTS_COLUMN, COLUMN_VALUE_FOR_DEFAULT_WEIGHT, get_default_weight, get_weighted_columns, read_weights_file Group = Tuple[str, ...] 
"""Combination of grouping column values in tuple form.""" +from .io import import_priorities_table -def get_groups_for_subsampling(strains, metadata, group_by=None): - """Return a list of groups for each given strain based on the corresponding - metadata and group by column. +def get_valid_group_by_columns(metadata_columns: Set[str], group_by: List[str]): + """Perform validation on requested group-by columns and return the valid subset. Parameters ---------- - strains : list - A list of strains to get groups for. - metadata : pandas.DataFrame - Metadata to inspect for the given strains. - group_by : list + metadata_columns + All column names in metadata. + group_by A list of metadata (or generated) columns to group records by. Returns ------- - dict : - A mapping of strain names to tuples corresponding to the values of the strain's group. - - Examples - -------- - >>> strains = ["strain1", "strain2"] - >>> metadata = pd.DataFrame([{"strain": "strain1", "date": "2020-01-01", "region": "Africa"}, {"strain": "strain2", "date": "2020-02-01", "region": "Europe"}]).set_index("strain") - >>> group_by = ["region"] - >>> group_by_strain = get_groups_for_subsampling(strains, metadata, group_by) - >>> group_by_strain - {'strain1': ('Africa',), 'strain2': ('Europe',)} - - If we group by year or month, these groups are generated from the date - string. - - >>> group_by = ["year", "month"] - >>> group_by_strain = get_groups_for_subsampling(strains, metadata, group_by) - >>> group_by_strain - {'strain1': (2020, '2020-01'), 'strain2': (2020, '2020-02')} - - If we omit the grouping columns, the result will group by a dummy column. - - >>> group_by_strain = get_groups_for_subsampling(strains, metadata) - >>> group_by_strain - {'strain1': ('_dummy',), 'strain2': ('_dummy',)} - - If we try to group by columns that don't exist, we get an error. - - >>> group_by = ["missing_column"] - >>> get_groups_for_subsampling(strains, metadata, group_by) - Traceback (most recent call last): - ... - augur.errors.AugurError: The specified group-by categories (['missing_column']) were not found. - - If we try to group by some columns that exist and some that don't, we allow - grouping to continue and print a warning message to stderr. - - >>> group_by = ["year", "month", "missing_column"] - >>> group_by_strain = get_groups_for_subsampling(strains, metadata, group_by) - >>> group_by_strain - {'strain1': (2020, '2020-01', 'unknown'), 'strain2': (2020, '2020-02', 'unknown')} - - We can group metadata without any non-ID columns. - - >>> metadata = pd.DataFrame([{"strain": "strain1"}, {"strain": "strain2"}]).set_index("strain") - >>> get_groups_for_subsampling(strains, metadata, group_by=('_dummy',)) - {'strain1': ('_dummy',), 'strain2': ('_dummy',)} + list of str: + Valid group-by columns. """ - metadata = metadata.loc[list(strains)] - group_by_strain = {} - - if len(metadata) == 0: - return group_by_strain - - if not group_by or group_by == ('_dummy',): - group_by_strain = {strain: ('_dummy',) for strain in strains} - return group_by_strain - + # Create a set copy for faster existence checks. group_by_set = set(group_by) + generated_columns_requested = constants.GROUP_BY_GENERATED_COLUMNS & group_by_set # If we could not find any requested categories, we cannot complete subsampling. 
- if METADATA_DATE_COLUMN not in metadata and group_by_set <= constants.GROUP_BY_GENERATED_COLUMNS: + if METADATA_DATE_COLUMN not in metadata_columns and group_by_set <= constants.GROUP_BY_GENERATED_COLUMNS: raise AugurError(f"The specified group-by categories ({group_by}) were not found. Note that using any of {sorted(constants.GROUP_BY_GENERATED_COLUMNS)} requires a column called {METADATA_DATE_COLUMN!r}.") - if not group_by_set & (set(metadata.columns) | constants.GROUP_BY_GENERATED_COLUMNS): + if not group_by_set & (set(metadata_columns) | constants.GROUP_BY_GENERATED_COLUMNS): raise AugurError(f"The specified group-by categories ({group_by}) were not found.") # Warn/error based on other columns grouped with week. @@ -112,148 +55,17 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): if generated_columns_requested: - if METADATA_DATE_COLUMN not in metadata: - # Set generated columns to 'unknown'. + if METADATA_DATE_COLUMN not in metadata_columns: print_err(f"WARNING: A {METADATA_DATE_COLUMN!r} column could not be found to group-by {sorted(generated_columns_requested)}.") print_err(f"Filtering by group may behave differently than expected!") - df_dates = pd.DataFrame({col: 'unknown' for col in constants.GROUP_BY_GENERATED_COLUMNS}, index=metadata.index) - metadata = pd.concat([metadata, df_dates], axis=1) - else: - # Create a DataFrame with year/month/day columns as nullable ints. - # These columns are prefixed to note temporary usage. They are used - # to generate other columns, and will be discarded at the end. - temp_prefix = str(uuid.uuid4()) - temp_date_cols = [f'{temp_prefix}year', f'{temp_prefix}month', f'{temp_prefix}day'] - df_dates = metadata[METADATA_DATE_COLUMN].str.split('-', n=2, expand=True) - df_dates = df_dates.set_axis(temp_date_cols[:len(df_dates.columns)], axis=1) - missing_date_cols = set(temp_date_cols) - set(df_dates.columns) - for col in missing_date_cols: - df_dates[col] = pd.NA - for col in temp_date_cols: - df_dates[col] = pd.to_numeric(df_dates[col], errors='coerce').astype(pd.Int64Dtype()) - - # Extend metadata with generated date columns - # Drop the date column since it should not be used for grouping. - metadata = pd.concat([metadata.drop(METADATA_DATE_COLUMN, axis=1), df_dates], axis=1) - - # Check again if metadata is empty after dropping ambiguous dates. - if metadata.empty: - return group_by_strain - - # Generate columns. - if constants.DATE_YEAR_COLUMN in generated_columns_requested: - metadata[constants.DATE_YEAR_COLUMN] = metadata[f'{temp_prefix}year'] - if constants.DATE_MONTH_COLUMN in generated_columns_requested: - metadata[constants.DATE_MONTH_COLUMN] = metadata.apply(lambda row: get_year_month( - row[f'{temp_prefix}year'], - row[f'{temp_prefix}month'] - ), axis=1 - ) - if constants.DATE_WEEK_COLUMN in generated_columns_requested: - # Note that week = (year, week) from the date.isocalendar(). - # Do not combine the raw year with the ISO week number alone, - # since raw year ≠ ISO year. - metadata[constants.DATE_WEEK_COLUMN] = metadata.apply(lambda row: get_year_week( - row[f'{temp_prefix}year'], - row[f'{temp_prefix}month'], - row[f'{temp_prefix}day'] - ), axis=1 - ) - - # Drop the internally used columns. 
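The removed block's note about isocalendar() is easy to demonstrate: near a year boundary the ISO year and the calendar year disagree, which is why the week group pairs the ISO year with the ISO week instead of using the raw year.

import datetime

iso_year, iso_week, iso_weekday = datetime.date(2021, 1, 1).isocalendar()
print(iso_year, iso_week)  # 2020 53 -- calendar year 2021, but ISO year 2020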
- for col in temp_date_cols: - metadata.drop(col, axis=1, inplace=True) - unknown_groups = group_by_set - set(metadata.columns) + unknown_groups = group_by_set - metadata_columns - constants.GROUP_BY_GENERATED_COLUMNS if unknown_groups: print_err(f"WARNING: Some of the specified group-by categories couldn't be found: {', '.join(unknown_groups)}") print_err("Filtering by group may behave differently than expected!") for group in unknown_groups: - metadata[group] = 'unknown' - - # Finally, determine groups. - group_by_strain = dict(zip(metadata.index, metadata[group_by].apply(tuple, axis=1))) - return group_by_strain - - -class PriorityQueue: - """A priority queue implementation that automatically replaces lower priority - items in the heap with incoming higher priority items. - - Examples - -------- - - Add a single record to a heap with a maximum of 2 records. - - >>> queue = PriorityQueue(max_size=2) - >>> queue.add({"strain": "strain1"}, 0.5) - 1 - - Add another record with a higher priority. The queue should be at its maximum - size. - - >>> queue.add({"strain": "strain2"}, 1.0) - 2 - >>> queue.heap - [(0.5, 0, {'strain': 'strain1'}), (1.0, 1, {'strain': 'strain2'})] - >>> list(queue.get_items()) - [{'strain': 'strain1'}, {'strain': 'strain2'}] - - Add a higher priority record that causes the queue to exceed its maximum - size. The resulting queue should contain the two highest priority records - after the lowest priority record is removed. - - >>> queue.add({"strain": "strain3"}, 2.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain2'}, {'strain': 'strain3'}] - - Add a record with the same priority as another record, forcing the duplicate - to be resolved by removing the oldest entry. - - >>> queue.add({"strain": "strain4"}, 1.0) - 2 - >>> list(queue.get_items()) - [{'strain': 'strain4'}, {'strain': 'strain3'}] - - """ - def __init__(self, max_size): - """Create a fixed size heap (priority queue) - - """ - self.max_size = max_size - self.heap = [] - self.counter = itertools.count() - - def add(self, item, priority): - """Add an item to the queue with a given priority. - - If adding the item causes the queue to exceed its maximum size, replace - the lowest priority item with the given item. The queue stores items - with an additional heap id value (a count) to resolve ties between items - with equal priority (favoring the most recently added item). - - """ - heap_id = next(self.counter) - - if len(self.heap) >= self.max_size: - heapq.heappushpop(self.heap, (priority, heap_id, item)) - else: - heapq.heappush(self.heap, (priority, heap_id, item)) - - return len(self.heap) - - def get_items(self): - """Return each item in the queue in order. - - Yields - ------ - Any - Item stored in the queue. - - """ - for priority, heap_id, item in self.heap: - yield item + group_by.remove(group) + return group_by def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None): @@ -304,6 +116,8 @@ def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None): return max_sizes_per_group +# FIXME: read weighs file into sql table? 
+ TARGET_SIZE_COLUMN = '_augur_filter_target_size' INPUT_SIZE_COLUMN = '_augur_filter_input_size' OUTPUT_SIZE_COLUMN = '_augur_filter_subsampling_output_size' @@ -358,6 +172,7 @@ def get_weighted_group_sizes( print_err(f"WARNING: Targeted {row[TARGET_SIZE_COLUMN]} {sequences} for group {group} but only {row[INPUT_SIZE_COLUMN]} {are} available.") if output_sizes_file: + # FIXME: make the order of rows deterministic weights.to_csv(output_sizes_file, index=False, sep='\t') return dict(zip(weights[group_by].apply(tuple, axis=1), weights[TARGET_SIZE_COLUMN])) @@ -631,3 +446,285 @@ def _calculate_sequences_per_group( return int(hi) else: return int(lo) + + +def apply_subsampling(args): + """Apply subsampling to update the filter reason table. + + We handle the following major use cases: + + 1. group by and sequences per group defined -> use the given values by the + user to identify the highest priority records from each group. + + 2. group by and maximum sequences defined -> count the group sizes, calculate the + sequences per group that satisfies the requested maximum, and select that many sequences per group. + + 3. group by not defined but maximum sequences defined -> use a "dummy" + group such that we select at most the requested maximum number of + sequences. + """ + + # Each strain has a score to determine priority during subsampling. + # When no priorities are provided, they will be randomly generated. + create_priorities_table(args) + + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_columns = set(db.columns(constants.METADATA_TABLE)) + + # FIXME: optimize conditions + + valid_group_by_columns = [] + if args.group_by: + valid_group_by_columns = get_valid_group_by_columns(metadata_columns, args.group_by) + + create_grouping_table(valid_group_by_columns, metadata_columns) + + if not args.group_by: + valid_group_by_columns = [constants.GROUP_BY_DUMMY_COLUMN] + target_group_sizes = {(constants.GROUP_BY_DUMMY_VALUE, ): _get_filtered_strains_count()} + + records_per_group = get_records_per_group(valid_group_by_columns) + + if args.subsample_max_sequences: + if args.group_by_weights: + print_err(f"Sampling with weights defined by {args.group_by_weights}.") + target_group_sizes = get_weighted_group_sizes( + records_per_group, + args.group_by, + args.group_by_weights, + args.subsample_max_sequences, + args.output_group_by_sizes, + args.subsample_seed, + ) + else: + # Calculate sequences per group. If there are more groups than maximum + # sequences requested, sequences per group will be a floating point + # value and subsampling will be probabilistic. 
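A simplified stand-in for the sequences-per-group calculation referenced here, only to show the two regimes: pick the largest whole number per group that keeps the total under the cap, and fall back to a fractional (probabilistic) rate when even one sequence per group is too many. The real calculate_sequences_per_group is more careful than this sketch.

def sequences_per_group(max_sequences, group_sizes):
    """Return (sequences per group, probabilistic?) for a total cap."""
    def total(n):
        return sum(min(size, n) for size in group_sizes)

    if total(1) > max_sequences:
        # More groups than requested sequences: sample probabilistically.
        return max_sequences / len(group_sizes), True

    best = 1
    for n in range(2, max(group_sizes) + 1):
        if total(n) <= max_sequences:
            best = n
    return best, False

print(sequences_per_group(10, [8, 5, 2]))  # (4, False): 4 + 4 + 2 = 10
print(sequences_per_group(2, [8, 5, 2]))   # (0.666..., True)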
+ try: + sequences_per_group, probabilistic_used = calculate_sequences_per_group( + args.subsample_max_sequences, + records_per_group.values(), + args.probabilistic_sampling, + ) + except TooManyGroupsError as error: + raise AugurError(error) + + if (probabilistic_used): + print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.") + target_group_sizes = get_probabilistic_group_sizes( + records_per_group.keys(), + sequences_per_group, + random_seed=args.subsample_seed, + ) + else: + print_err(f"Sampling at {sequences_per_group} per group.") + assert type(sequences_per_group) is int + target_group_sizes = {group: sequences_per_group for group in records_per_group.keys()} + else: + assert args.sequences_per_group + target_group_sizes = {group: args.sequences_per_group for group in records_per_group.keys()} + + create_group_size_limits_table(valid_group_by_columns, target_group_sizes) + update_filter_reason_table(valid_group_by_columns) + + +def create_priorities_table(args): + """Import or generate the priorities table.""" + if args.priority: + import_priorities_table(args.priority) + else: + generate_priorities_table(args.subsample_seed) + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.create_primary_index(constants.PRIORITIES_TABLE, constants.ID_COLUMN) + + +def generate_priorities_table(random_seed: int = None): + """Generate a priorities table with random priorities. + + It is not possible to seed the SQLite built-in RANDOM(). As an alternative, + use a Python function registered as a user-defined function. + + The generated priorities are random floats in the half-open interval [0.0, 1.0). + """ + rng = np.random.default_rng(random_seed) + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + # Register SQLite3 user-defined function. + db.connection.create_function(rng.random.__name__, 0, rng.random) + + db.connection.execute(f"""CREATE TABLE {constants.PRIORITIES_TABLE} AS + SELECT + {constants.ID_COLUMN}, + {rng.random.__name__}() AS {constants.PRIORITY_COLUMN} + FROM {constants.FILTER_REASON_TABLE} + WHERE NOT {constants.EXCLUDE_COLUMN} OR {constants.INCLUDE_COLUMN} + """) + + # Remove user-defined function. + db.connection.create_function(rng.random.__name__, 0, None) + + +def create_grouping_table(group_by_columns: Iterable[str], metadata_columns: Set[str]): + """Create a table with columns for grouping.""" + + # For both of these, start with an empty string in case it isn't needed. + generated_group_by_columns_sql = '' + metadata_group_by_columns_sql = '' + + if group_by_columns: + group_by_columns_set = set(group_by_columns) + + generated_group_by_columns = constants.GROUP_BY_GENERATED_COLUMNS & group_by_columns_set + + if generated_group_by_columns: + generated_group_by_columns_sql = ( + # Prefix columns with the table alias defined in the SQL query further down. + ','.join(f'd.{column}' for column in generated_group_by_columns) + # Add an extra comma for valid SQL. + + ',') + + metadata_group_by_columns = group_by_columns_set - generated_group_by_columns + + if metadata_group_by_columns: + metadata_group_by_columns_sql = ( + # Prefix columns with the table alias defined in the SQL query further down. + ','.join(f'm.{sanitize_identifier(column)}' for column in metadata_group_by_columns) + # Add an extra comma for valid SQL. 
+ + ',') + + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + + # Create a new table with rows as filtered metadata, with the following columns: + # - Metadata ID column + # - Group-by columns + # - Generated date columns + # - Priority score column + # - Placeholder column + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.connection.execute(f"""CREATE TABLE {constants.GROUPING_TABLE} AS + SELECT + f.{constants.ID_COLUMN}, + {metadata_group_by_columns_sql} + {generated_group_by_columns_sql} + p.{constants.PRIORITY_COLUMN}, + {constants.GROUP_BY_DUMMY_VALUE} AS {constants.GROUP_BY_DUMMY_COLUMN} + FROM {constants.FILTER_REASON_TABLE} AS f + JOIN {constants.METADATA_TABLE} AS m + ON (f.{constants.ID_COLUMN} = m.{sanitize_identifier(metadata_id_column)}) + JOIN {constants.DATE_TABLE} AS d + USING ({constants.ID_COLUMN}) + LEFT OUTER JOIN {constants.PRIORITIES_TABLE} AS p + USING ({constants.ID_COLUMN}) + WHERE + NOT f.{constants.EXCLUDE_COLUMN} OR f.{constants.INCLUDE_COLUMN} + """) + # Note: The last JOIN is a LEFT OUTER JOIN since a default INNER JOIN would + # drop strains without a priority. + + +def get_records_per_group(group_by_columns: Sequence[str]) -> Dict[Group, int]: + group_by_columns_sql = ','.join(sanitize_identifier(column) for column in group_by_columns) + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT + COUNT(*) AS count, + {group_by_columns_sql} + FROM {constants.GROUPING_TABLE} + GROUP BY {group_by_columns_sql} + """) + return {tuple(row[c] for c in group_by_columns) : int(row['count']) for row in result} + + +# FIXME: use group_sizes instead of iterator +def create_group_size_limits_table(group_by_columns: Sequence[str], target_group_sizes: Dict[Group, float]): + """Create a table for group size limits.""" + + # Create a function to return target group size for a given group + def group_size(*group): + return target_group_sizes[group] + + group_by_columns_sql = ','.join(sanitize_identifier(column) for column in group_by_columns) + + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + # Register SQLite3 user-defined function. + db.connection.create_function(group_size.__name__, -1, group_size) + + db.connection.execute(f"""CREATE TABLE {constants.GROUP_SIZE_LIMITS_TABLE} AS + SELECT + {group_by_columns_sql}, + {group_size.__name__}({group_by_columns_sql}) AS {constants.GROUP_SIZE_LIMIT_COLUMN} + FROM {constants.GROUPING_TABLE} + GROUP BY {group_by_columns_sql} + """) + + # Remove user-defined function. + db.connection.create_function(group_size.__name__, 0, None) + + +def update_filter_reason_table(group_by_columns: Iterable[str]): + """Subsample filtered metadata and update the filter reason table.""" + group_by_columns_sql = ','.join(sanitize_identifier(column) for column in group_by_columns) + + # First, select the strain column, group-by columns, and a `group_rank` + # variable from the grouping table. `group_rank` represents an incremental + # number ordered by priority within each group (i.e. the highest priority + # strain per group gets group_rank=0). 
+ strains_with_group_rank = f""" + SELECT + {constants.ID_COLUMN}, + {group_by_columns_sql}, + ROW_NUMBER() OVER ( + PARTITION BY {group_by_columns_sql} + ORDER BY + (CASE WHEN {constants.PRIORITY_COLUMN} IS NULL THEN 1 ELSE 0 END), + CAST({constants.PRIORITY_COLUMN} AS REAL) + DESC + ) AS group_rank + FROM {constants.GROUPING_TABLE} + """ + # Notes: + # 1. Although the name is similar to --group-by, the GROUP BY clause does not + # apply here. That command is used for aggregation commands such as getting + # the sizes of each group, which is done elsewhere. + # 2. To treat rows without priorities as lowest priority, `ORDER BY … NULLS LAST` + # would be ideal. However, that syntax is unsupported on SQLite <3.30.0¹ so + # `CASE … IS NULL …` is a more widely compatible equivalent. + # ¹ https://www.sqlite.org/changes.html + + # Combine the above with the group size limits table to select the highest + # priority strains. + query_for_subsampled_strains = f""" + SELECT {constants.ID_COLUMN} + FROM ({strains_with_group_rank}) + JOIN {constants.GROUP_SIZE_LIMITS_TABLE} USING ({group_by_columns_sql}) + WHERE group_rank <= {constants.GROUP_SIZE_LIMIT_COLUMN} + """ + + # Exclude strains that didn't pass subsampling. + # Note that the exclude column was already considered when creating the + # grouping table earlier. The condition here is only in place to not + # overwrite any existing reason for exclusion. + with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: + db.connection.execute(f""" + UPDATE {constants.FILTER_REASON_TABLE} + SET + {constants.EXCLUDE_COLUMN} = TRUE, + {constants.FILTER_REASON_COLUMN} = '{constants.SUBSAMPLE_FILTER_REASON}' + WHERE ( + NOT {constants.EXCLUDE_COLUMN} + AND {constants.ID_COLUMN} NOT IN ({query_for_subsampled_strains}) + ) + """) + + +def _get_filtered_strains_count(): + """Returns the number of metadata strains that pass all filter rules.""" + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT COUNT(*) AS count + FROM {constants.FILTER_REASON_TABLE} + WHERE NOT {constants.EXCLUDE_COLUMN} OR {constants.INCLUDE_COLUMN} + """) + return int(result.fetchone()["count"]) diff --git a/docs/api/developer/augur.filter.dates.rst b/docs/api/developer/augur.filter.dates.rst new file mode 100644 index 000000000..2458e9e63 --- /dev/null +++ b/docs/api/developer/augur.filter.dates.rst @@ -0,0 +1,7 @@ +augur.filter.dates module +========================= + +.. 
automodule:: augur.filter.dates + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/developer/augur.filter.rst b/docs/api/developer/augur.filter.rst index 59c2b7058..0e12aa9fd 100644 --- a/docs/api/developer/augur.filter.rst +++ b/docs/api/developer/augur.filter.rst @@ -13,6 +13,7 @@ Submodules :maxdepth: 4 augur.filter.constants + augur.filter.dates augur.filter.include_exclude_rules augur.filter.io augur.filter.report diff --git a/tests/filter/test_subsample.py b/tests/filter/test_subsample.py index b8e427f7e..d398c15e1 100644 --- a/tests/filter/test_subsample.py +++ b/tests/filter/test_subsample.py @@ -34,125 +34,44 @@ def test_sequences_per_group(self, target_max_value, counts_per_group, expected_ class TestFilterGroupBy: - def test_filter_groupby_strain_subset(self, valid_metadata: pd.DataFrame): - metadata = valid_metadata.copy() - strains = ['SEQ_1', 'SEQ_3', 'SEQ_5'] - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata) - assert group_by_strain == { - 'SEQ_1': ('_dummy',), - 'SEQ_3': ('_dummy',), - 'SEQ_5': ('_dummy',) - } - - def test_filter_groupby_dummy(self, valid_metadata: pd.DataFrame): - metadata = valid_metadata.copy() - strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata) - assert group_by_strain == { - 'SEQ_1': ('_dummy',), - 'SEQ_2': ('_dummy',), - 'SEQ_3': ('_dummy',), - 'SEQ_4': ('_dummy',), - 'SEQ_5': ('_dummy',) - } - - def test_filter_groupby_invalid_error(self, valid_metadata: pd.DataFrame): + def test_filter_groupby_invalid_error(self): groups = ['invalid'] - metadata = valid_metadata.copy() - strains = metadata.index.tolist() + metadata_columns = {'strain', 'date', 'country'} with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.get_valid_group_by_columns(metadata_columns, groups) assert str(e_info.value) == "The specified group-by categories (['invalid']) were not found." 
- def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys): + def test_filter_groupby_invalid_warn(self, capsys): groups = ['country', 'year', 'month', 'invalid'] - metadata = valid_metadata.copy() - strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) - assert group_by_strain == { - 'SEQ_1': ('A', 2020, '2020-01', 'unknown'), - 'SEQ_2': ('A', 2020, '2020-02', 'unknown'), - 'SEQ_3': ('B', 2020, '2020-03', 'unknown'), - 'SEQ_4': ('B', 2020, '2020-04', 'unknown'), - 'SEQ_5': ('B', 2020, '2020-05', 'unknown') - } + metadata_columns = {'strain', 'date', 'country'} + augur.filter.subsample.get_valid_group_by_columns(metadata_columns, groups) captured = capsys.readouterr() assert captured.err == "WARNING: Some of the specified group-by categories couldn't be found: invalid\nFiltering by group may behave differently than expected!\n" - def test_filter_groupby_missing_year_error(self, valid_metadata: pd.DataFrame): + def test_filter_groupby_missing_year_error(self): groups = ['year'] - metadata = valid_metadata.copy() - metadata = metadata.drop('date', axis='columns') - strains = metadata.index.tolist() + metadata_columns = {'strain', 'country'} with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.get_valid_group_by_columns(metadata_columns, groups) assert str(e_info.value) == "The specified group-by categories (['year']) were not found. Note that using any of ['month', 'week', 'year'] requires a column called 'date'." - def test_filter_groupby_missing_month_error(self, valid_metadata: pd.DataFrame): + def test_filter_groupby_missing_month_error(self): groups = ['month'] - metadata = valid_metadata.copy() - metadata = metadata.drop('date', axis='columns') - strains = metadata.index.tolist() + metadata_columns = {'strain', 'country'} with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.get_valid_group_by_columns(metadata_columns, groups) assert str(e_info.value) == "The specified group-by categories (['month']) were not found. Note that using any of ['month', 'week', 'year'] requires a column called 'date'." - def test_filter_groupby_missing_year_and_month_error(self, valid_metadata: pd.DataFrame): + def test_filter_groupby_missing_year_and_month_error(self): groups = ['year', 'month'] - metadata = valid_metadata.copy() - metadata = metadata.drop('date', axis='columns') - strains = metadata.index.tolist() + metadata_columns = {'strain', 'country'} with pytest.raises(AugurError) as e_info: - augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) + augur.filter.subsample.get_valid_group_by_columns(metadata_columns, groups) assert str(e_info.value) == "The specified group-by categories (['year', 'month']) were not found. Note that using any of ['month', 'week', 'year'] requires a column called 'date'." 
- def test_filter_groupby_missing_date_warn(self, valid_metadata: pd.DataFrame, capsys): + def test_filter_groupby_missing_date_warn(self, capsys): groups = ['country', 'year', 'month'] - metadata = valid_metadata.copy() - metadata = metadata.drop('date', axis='columns') - strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) - assert group_by_strain == { - 'SEQ_1': ('A', 'unknown', 'unknown'), - 'SEQ_2': ('A', 'unknown', 'unknown'), - 'SEQ_3': ('B', 'unknown', 'unknown'), - 'SEQ_4': ('B', 'unknown', 'unknown'), - 'SEQ_5': ('B', 'unknown', 'unknown') - } + metadata_columns = {'strain', 'country'} + augur.filter.subsample.get_valid_group_by_columns(metadata_columns, groups) captured = capsys.readouterr() assert captured.err == "WARNING: A 'date' column could not be found to group-by ['month', 'year'].\nFiltering by group may behave differently than expected!\n" - - def test_filter_groupby_no_strains(self, valid_metadata: pd.DataFrame): - groups = ['country', 'year', 'month'] - metadata = valid_metadata.copy() - strains = [] - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) - assert group_by_strain == {} - - def test_filter_groupby_only_year_provided(self, valid_metadata: pd.DataFrame): - groups = ['country', 'year'] - metadata = valid_metadata.copy() - metadata['date'] = '2020' - strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) - assert group_by_strain == { - 'SEQ_1': ('A', 2020), - 'SEQ_2': ('A', 2020), - 'SEQ_3': ('B', 2020), - 'SEQ_4': ('B', 2020), - 'SEQ_5': ('B', 2020) - } - - def test_filter_groupby_only_year_month_provided(self, valid_metadata: pd.DataFrame): - groups = ['country', 'year', 'month'] - metadata = valid_metadata.copy() - metadata['date'] = '2020-01' - strains = metadata.index.tolist() - group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups) - assert group_by_strain == { - 'SEQ_1': ('A', 2020, '2020-01'), - 'SEQ_2': ('A', 2020, '2020-01'), - 'SEQ_3': ('B', 2020, '2020-01'), - 'SEQ_4': ('B', 2020, '2020-01'), - 'SEQ_5': ('B', 2020, '2020-01') - } diff --git a/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t b/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t index d8fbd62e8..0ed962aad 100644 --- a/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t +++ b/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t @@ -19,14 +19,13 @@ The query initially filters 3 strains from Colombia, one of which is added back > --output-log filtered_log.tsv 4 strains were dropped during filtering 1 had no metadata + 2 were filtered out by the query: "country != 'Colombia'" 1 had no sequence data - 3 were filtered out by the query: "country != 'Colombia'" - \\t1 was added back because it was in .*include\.txt.* (re) + \\t1 was force-included because it was in .*include\.txt.* (re) 9 strains passed all filters $ head -n 1 filtered_log.tsv; tail -n +2 filtered_log.tsv | sort -k 1,1 strain filter kwargs - COL/FLR_00008/2015 filter_by_query "[[""query"", ""country != 'Colombia'""]]" COL/FLR_00008/2015\tforce_include_strains\t"[[""include_file"", ""*/data/include.txt""]]" (esc) (glob) COL/FLR_00024/2015 filter_by_query "[[""query"", ""country != 'Colombia'""]]" Colombia/2016/ZC204Se filter_by_query "[[""query"", ""country != 
'Colombia'""]]" diff --git a/tests/functional/filter/cram/filter-min-max-date-output.t b/tests/functional/filter/cram/filter-min-max-date-output.t index 5ce60854c..6bce49cdc 100644 --- a/tests/functional/filter/cram/filter-min-max-date-output.t +++ b/tests/functional/filter/cram/filter-min-max-date-output.t @@ -10,6 +10,6 @@ Check output of min/max date filters. > --max-date 2016-02-01 \ > --output-metadata filtered_metadata.tsv 8 strains were dropped during filtering - 1 was dropped because it was earlier than 2015.0 or missing a date 7 were dropped because they were later than 2016.09 or missing a date + 1 was dropped because it was earlier than 2015.0 or missing a date 4 strains passed all filters diff --git a/tests/functional/filter/cram/filter-mismatched-sequences.t b/tests/functional/filter/cram/filter-mismatched-sequences.t index f2aea5dfd..ddee95077 100644 --- a/tests/functional/filter/cram/filter-mismatched-sequences.t +++ b/tests/functional/filter/cram/filter-mismatched-sequences.t @@ -34,8 +34,7 @@ because --include takes precedence. > --include metadata-ids.txt \ > --output-strains filtered_strains.txt 0 strains were dropped during filtering - 3 were dropped by `--exclude-all` - 3 were added back because they were in metadata-ids.txt + 3 were force-included because they were in metadata-ids.txt 3 strains passed all filters $ wc -l filtered_strains.txt diff --git a/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t b/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t index 018fdbde5..be4a04061 100644 --- a/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t +++ b/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t @@ -21,6 +21,6 @@ Confirm that `--exclude-ambiguous-dates-by` works for all year only ambiguous da > --empty-output-reporting silent \ > --output-strains filtered_strains.txt 4 strains were dropped during filtering - 1 was filtered out by the query: "region=="Asia"" 3 were dropped because of their ambiguous date in any + 1 was filtered out by the query: "region=="Asia"" 0 strains passed all filters diff --git a/tests/functional/filter/cram/filter-sequences-vcf.t b/tests/functional/filter/cram/filter-sequences-vcf.t index f6ddaa6cb..48fd56f0e 100644 --- a/tests/functional/filter/cram/filter-sequences-vcf.t +++ b/tests/functional/filter/cram/filter-sequences-vcf.t @@ -12,8 +12,8 @@ Filter TB strains from VCF and save as a list of filtered strains. > --output-strains filtered_strains.txt > /dev/null Note: You did not provide a sequence index, so Augur will generate one. You can generate your own index ahead of time with `augur index` and pass it with `augur filter --sequence-index`. 162 strains were dropped during filtering - 155 had no sequence data 7 were dropped because they were earlier than 2012.0 or missing a date + 155 had no sequence data 3 strains passed all filters $ wc -l filtered_strains.txt \s*3 .* (re) diff --git a/tests/functional/filter/cram/filter-subsample-missing-date-parts.t b/tests/functional/filter/cram/filter-subsample-missing-date-parts.t index c70e62213..0e212e137 100644 --- a/tests/functional/filter/cram/filter-subsample-missing-date-parts.t +++ b/tests/functional/filter/cram/filter-subsample-missing-date-parts.t @@ -42,8 +42,8 @@ month information in their date fields. 
> --output-log log.txt \ > --output-strains filtered_strains.txt > /dev/null 2 strains were dropped during filtering - 1 was dropped during grouping due to ambiguous year information 1 was dropped during grouping due to ambiguous month information + 1 was dropped during grouping due to ambiguous year information 0 were dropped because of subsampling criteria 1 strain passed all filters $ cat log.txt diff --git a/tests/functional/filter/cram/subsample-max-sequences-with-probabilistic-sampling-warning.t b/tests/functional/filter/cram/subsample-max-sequences-with-probabilistic-sampling-warning.t index 7939bbc0c..abb613e57 100644 --- a/tests/functional/filter/cram/subsample-max-sequences-with-probabilistic-sampling-warning.t +++ b/tests/functional/filter/cram/subsample-max-sequences-with-probabilistic-sampling-warning.t @@ -18,8 +18,8 @@ Explicitly use probabilistic subsampling to handle the case when there are more Sampling probabilistically at 0.6250 sequences per group, meaning it is possible to have more than the requested maximum of 5 sequences after filtering. 10 strains were dropped during filtering 1 had no metadata - 1 had no sequence data 1 was dropped because it was earlier than 2012.0 or missing a date + 1 had no sequence data 1 was dropped during grouping due to ambiguous month information 6 were dropped because of subsampling criteria, using seed 314159 3 strains passed all filters @@ -39,8 +39,8 @@ Using the default probabilistic subsampling, should work the same as the previou Sampling probabilistically at 0.6250 sequences per group, meaning it is possible to have more than the requested maximum of 5 sequences after filtering. 10 strains were dropped during filtering 1 had no metadata - 1 had no sequence data - 1 was dropped because it was earlier than 2012.0 or missing a date + \t1 was dropped because it was earlier than 2012.0 or missing a date (esc) + \t1 had no sequence data (esc) 1 was dropped during grouping due to ambiguous month information 6 were dropped because of subsampling criteria, using seed 314159 3 strains passed all filters diff --git a/tests/functional/filter/cram/subsample-probabilistic-sampling-not-always-used.t b/tests/functional/filter/cram/subsample-probabilistic-sampling-not-always-used.t index cc4393053..352d6a45e 100644 --- a/tests/functional/filter/cram/subsample-probabilistic-sampling-not-always-used.t +++ b/tests/functional/filter/cram/subsample-probabilistic-sampling-not-always-used.t @@ -13,7 +13,7 @@ Ensure probabilistic sampling is not used when unnecessary. > --output-metadata filtered_metadata.tsv Sampling at 10 per group. 2 strains were dropped during filtering - 1 was dropped during grouping due to ambiguous year information 1 was dropped during grouping due to ambiguous month information + 1 was dropped during grouping due to ambiguous year information 0 were dropped because of subsampling criteria, using seed 314159 10 strains passed all filters diff --git a/tests/functional/filter/cram/subsample-probabilistic-sampling-output.t b/tests/functional/filter/cram/subsample-probabilistic-sampling-output.t index 865204107..1a13dad6e 100644 --- a/tests/functional/filter/cram/subsample-probabilistic-sampling-output.t +++ b/tests/functional/filter/cram/subsample-probabilistic-sampling-output.t @@ -14,7 +14,7 @@ Check output of probabilistic sampling. WARNING: Asked to provide at most 3 sequences, but there are 8 groups. 
Sampling probabilistically at 0.3750 sequences per group, meaning it is possible to have more than the requested maximum of 3 sequences after filtering. 10 strains were dropped during filtering - 1 was dropped during grouping due to ambiguous year information 1 was dropped during grouping due to ambiguous month information + 1 was dropped during grouping due to ambiguous year information 8 were dropped because of subsampling criteria, using seed 314159 2 strains passed all filters diff --git a/tests/functional/filter/cram/subsample-skip-ambiguous-dates.t b/tests/functional/filter/cram/subsample-skip-ambiguous-dates.t index dec6bedad..339621947 100644 --- a/tests/functional/filter/cram/subsample-skip-ambiguous-dates.t +++ b/tests/functional/filter/cram/subsample-skip-ambiguous-dates.t @@ -15,8 +15,8 @@ Strains with ambiguous years or months should be dropped and logged. WARNING: Asked to provide at most 5 sequences, but there are 6 groups. Sampling probabilistically at 0.8333 sequences per group, meaning it is possible to have more than the requested maximum of 5 sequences after filtering. 8 strains were dropped during filtering - 1 was dropped during grouping due to ambiguous year information 1 was dropped during grouping due to ambiguous month information + 1 was dropped during grouping due to ambiguous year information 6 were dropped because of subsampling criteria 4 strains passed all filters $ grep "SG_018" filtered_log.tsv | cut -f 1-2 diff --git a/tests/functional/filter/cram/subsample-weighted-and-uniform-mix.t b/tests/functional/filter/cram/subsample-weighted-and-uniform-mix.t index b11028c24..97d19eb49 100644 --- a/tests/functional/filter/cram/subsample-weighted-and-uniform-mix.t +++ b/tests/functional/filter/cram/subsample-weighted-and-uniform-mix.t @@ -104,7 +104,7 @@ requested 17, so the total number of sequences outputted is lower than requested > --output-strains strains.txt Sampling with weights defined by weights-A1B1.tsv. NOTE: Weights were not provided for the column 'year'. Using equal weights across values in that column. - WARNING: Targeted 17 sequences for group ['year=2002', "location='A'"] but only 1 is available. + WARNING: Targeted 17 sequences for group ["year='2002'", "location='A'"] but only 1 is available. 168 strains were dropped during filtering 168 were dropped because of subsampling criteria 83 strains passed all filters From deb00d6b0c18dd50a475090b58712d7c62e46c16 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Mon, 27 Mar 2023 14:07:03 -0700 Subject: [PATCH 11/15] dates: Remove unused is_date_ambiguous() The re-implementation of augur filter accounts for this directly in filter_by_ambiguous_date(). Remove old tests and add equivalent tests. Note that some date strings have been changed since the previous values would fail date format validation. 
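
For reference, the behaviour covered by the new tests can also be
exercised directly from the command line; a minimal sketch (file names
are illustrative):

    augur filter \
        --metadata metadata.tsv \
        --exclude-ambiguous-dates-by month \
        --output-strains strains.txt

Strains whose date is ambiguous at the requested level (here, month)
are dropped.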
--- augur/dates/__init__.py | 31 ---------- tests/dates/test_dates.py | 29 --------- tests/filter/__init__.py | 16 +++++ .../filter/test_exclude_ambiguous_dates_by.py | 60 +++++++++++++++++++ tests/filter/test_relative_dates.py | 16 +---- 5 files changed, 77 insertions(+), 75 deletions(-) create mode 100644 tests/filter/__init__.py create mode 100644 tests/filter/test_exclude_ambiguous_dates_by.py diff --git a/augur/dates/__init__.py b/augur/dates/__init__.py index 2b31e0cb7..a8ccd0c22 100644 --- a/augur/dates/__init__.py +++ b/augur/dates/__init__.py @@ -76,37 +76,6 @@ def numeric_date_type(date): except InvalidDate as error: raise argparse.ArgumentTypeError(str(error)) from error -def is_date_ambiguous(date, ambiguous_by): - """ - Returns whether a given date string in the format of YYYY-MM-DD is ambiguous by a given part of the date (e.g., day, month, year, or any parts). - - Parameters - ---------- - date : str - Date string in the format of YYYY-MM-DD - ambiguous_by : str - Field of the date string to test for ambiguity ("day", "month", "year", "any") - """ - date_components = date.split('-', 2) - - if len(date_components) == 3: - year, month, day = date_components - elif len(date_components) == 2: - year, month = date_components - day = "XX" - else: - year = date_components[0] if date_components[0] else 'X' - month = "XX" - day = "XX" - - # Determine ambiguity hierarchically such that, for example, an ambiguous - # month implicates an ambiguous day even when day information is available. - return any(( - "X" in year, - "X" in month and ambiguous_by in ("any", "month", "day"), - "X" in day and ambiguous_by in ("any", "day") - )) - def get_numerical_date_from_value(value, fmt=None, min_max_year=None, ambiguity_resolver='both'): value = str(value) if re.match(r'^-*\d+\.\d+$', value): diff --git a/tests/dates/test_dates.py b/tests/dates/test_dates.py index 1427ea3fa..c7ae1dcec 100644 --- a/tests/dates/test_dates.py +++ b/tests/dates/test_dates.py @@ -60,35 +60,6 @@ def test_get_numerical_date_from_value_current_day_limit(self): == pytest.approx(2000.138, abs=1e-3) ) - def test_is_date_ambiguous(self): - """is_date_ambiguous should return true for ambiguous dates and false for valid dates.""" - # Test complete date strings with ambiguous values. - assert dates.is_date_ambiguous("2019-0X-0X", "any") - assert dates.is_date_ambiguous("2019-XX-09", "month") - assert dates.is_date_ambiguous("2019-03-XX", "day") - assert dates.is_date_ambiguous("201X-03-09", "year") - assert dates.is_date_ambiguous("20XX-01-09", "month") - assert dates.is_date_ambiguous("2019-XX-03", "day") - assert dates.is_date_ambiguous("20XX-01-03", "day") - - # Test incomplete date strings with ambiguous values. - assert dates.is_date_ambiguous("2019", "any") - assert dates.is_date_ambiguous("201X", "year") - assert dates.is_date_ambiguous("2019-XX", "month") - assert dates.is_date_ambiguous("2019-10", "day") - assert dates.is_date_ambiguous("2019-XX", "any") - assert dates.is_date_ambiguous("2019-XX", "day") - - # Test complete date strings without ambiguous dates for the requested field. - assert not dates.is_date_ambiguous("2019-09-03", "any") - assert not dates.is_date_ambiguous("2019-03-XX", "month") - assert not dates.is_date_ambiguous("2019-09-03", "day") - assert not dates.is_date_ambiguous("2019-XX-XX", "year") - - # Test incomplete date strings without ambiguous dates for the requested fields. 
- assert not dates.is_date_ambiguous("2019", "year") - assert not dates.is_date_ambiguous("2019-10", "month") - def test_get_numerical_dates_dict_error(self): """Using get_numerical_dates with metadata represented as a dict should raise an error.""" metadata = { diff --git a/tests/filter/__init__.py b/tests/filter/__init__.py new file mode 100644 index 000000000..0d004779d --- /dev/null +++ b/tests/filter/__init__.py @@ -0,0 +1,16 @@ +import argparse +import shlex +from augur.filter import register_arguments + + +def parse_args(args: str): + parser = argparse.ArgumentParser() + register_arguments(parser) + return parser.parse_args(shlex.split(args)) + + +def write_metadata(tmpdir, metadata): + fn = str(tmpdir / "metadata.tsv") + with open(fn, "w") as fh: + fh.write("\n".join(("\t".join(md) for md in metadata))) + return fn diff --git a/tests/filter/test_exclude_ambiguous_dates_by.py b/tests/filter/test_exclude_ambiguous_dates_by.py new file mode 100644 index 000000000..788e0138c --- /dev/null +++ b/tests/filter/test_exclude_ambiguous_dates_by.py @@ -0,0 +1,60 @@ +# This file contains functional tests that would normally be written as +# Cram-style tests. However, pytest is nice here since it is easy to use with +# parameterized inputs/outputs (not straightforward to set up for Cram tests¹). +# ¹ https://github.com/nextstrain/augur/pull/1183#discussion_r1142687476 + +import pytest + +from augur.errors import AugurError +from augur.filter._run import run + +from . import parse_args, write_metadata + + +@pytest.mark.parametrize( + "date, ambiguity", + [ + # Test complete date strings with ambiguous values. + ("2019-0X-0X", "any"), + ("2019-XX-XX", "month"), + ("2019-XX-XX", "day"), + ("2019-03-XX", "day"), + ("201X-XX-XX", "year"), + ("201X-XX-XX", "month"), + ("201X-XX-XX", "day"), + + # Test incomplete date strings with ambiguous values. + ("2019", "month"), + ("2019", "day"), + ("2019", "any"), + ("201X", "year"), + ("201X", "month"), + ("201X", "day"), + ("201X", "any"), + ], +) +def test_date_is_dropped(tmpdir, date, ambiguity): + metadata = write_metadata(tmpdir, (("strain","date"), + ("SEQ1" , date))) + args = parse_args(f'--metadata {metadata} --exclude-ambiguous-dates-by {ambiguity}') + with pytest.raises(AugurError, match="All samples have been dropped"): + run(args) + +@pytest.mark.parametrize( + "date, ambiguity", + [ + # Test complete date strings without the specified level of ambiguity. + ("2019-09-03", "any"), + ("2019-03-XX", "month"), + ("2019-09-03", "day"), + ("2019-XX-XX", "year"), + + # Test incomplete date strings without the specified level of ambiguity. + ("2019", "year"), + ], +) +def test_date_is_not_dropped(tmpdir, date, ambiguity): + metadata = write_metadata(tmpdir, (("strain","date"), + ("SEQ1" , date))) + args = parse_args(f'--metadata {metadata} --exclude-ambiguous-dates-by {ambiguity}') + run(args) diff --git a/tests/filter/test_relative_dates.py b/tests/filter/test_relative_dates.py index 39ea34a04..7d9d1f78f 100644 --- a/tests/filter/test_relative_dates.py +++ b/tests/filter/test_relative_dates.py @@ -4,26 +4,12 @@ # straightforward to set up for Cram tests¹). 
# ¹ https://github.com/nextstrain/augur/pull/1183#discussion_r1142687476 -import argparse from freezegun import freeze_time import pytest -import shlex -from augur.filter import register_arguments from augur.filter._run import run - -def parse_args(args): - parser = argparse.ArgumentParser() - register_arguments(parser) - return parser.parse_args(shlex.split(args)) - - -def write_metadata(tmpdir, metadata): - fn = str(tmpdir / "metadata.tsv") - with open(fn, "w") as fh: - fh.write("\n".join(("\t".join(md) for md in metadata))) - return fn +from . import parse_args, write_metadata @freeze_time("2020-03-25") From 0cd748ee24cf87521b9053473043c148236e64ae Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 7 Jul 2023 11:50:58 -0700 Subject: [PATCH 12/15] =?UTF-8?q?=F0=9F=9A=A7=20filter:=20Add=20a=20--debu?= =?UTF-8?q?g=20option?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Now that disk space usage is uncapped (compared to previous pandas implementation using in-memory chunks), it can be useful to know how large the database file was. However, I'll have to think more about how useful this is once database files are passed in by the user. Ideas: - Mark this as an experimental feature for `augur filter`, to be changed or removed with any version. - Add it to the `augur db` interface, e.g. output of `augur db inspect`. However, it can still be useful to know the sizes of "intermediate" tables. It'd also be useful to add runtime information here. Ideas: - Print run times of each "major" function in real-time. This can probably be achieved by some sort of decorator function. --- augur/filter/__init__.py | 1 + augur/filter/_run.py | 8 +++++- augur/filter/constants.py | 1 + augur/filter/debug.py | 6 ++++ augur/filter/io.py | 35 +++++++++++++++++++++++ docs/api/developer/augur.filter.debug.rst | 7 +++++ docs/api/developer/augur.filter.rst | 1 + 7 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 augur/filter/debug.py create mode 100644 docs/api/developer/augur.filter.debug.rst diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index 2c6c9d7db..3848629ae 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -107,6 +107,7 @@ def register_arguments(parser): output_group.add_argument('--output-strains', help="list of strains that passed filters (no header)") output_group.add_argument('--output-log', help="tab-delimited file with one row for each filtered strain and the reason it was filtered. Keyword arguments used for a given filter are reported in JSON format in a `kwargs` column.") output_group.add_argument('--output-group-by-sizes', help="tab-delimited file one row per group with target size.") + output_group.add_argument('--debug', action='store_true', help="Run in debug mode.") output_group.add_argument( '--empty-output-reporting', type=EmptyOutputReportingMethod.argtype, diff --git a/augur/filter/_run.py b/augur/filter/_run.py index 9601e81f5..3e825de8c 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -8,7 +8,8 @@ from augur.io.tabular_file import InvalidDelimiter from . 
import constants from .dates import parse_dates -from .io import get_useful_metadata_columns, initialize_input_source_table, import_metadata, import_sequence_index, write_outputs +from .debug import print_debug +from .io import get_useful_metadata_columns, initialize_input_source_table, import_metadata, import_sequence_index, print_db_report, write_outputs from .include_exclude_rules import apply_filters, construct_filters from .report import print_report from .subsample import apply_subsampling @@ -21,6 +22,9 @@ def run(args: Namespace): # but that would be tedious and makes it harder to trace references back # to the source. constants.RUNTIME_DB_FILE = file.name + constants.RUNTIME_DEBUG = args.debug + + print_debug(f"Temporary database file: {constants.RUNTIME_DB_FILE!r}") initialize_input_source_table() @@ -49,6 +53,8 @@ def run(args: Namespace): print_report(args) + print_db_report() + # TODO: The current implementation assumes the database file is hidden from # the user. If this ever changes, clean the database of any # tables/indexes/etc. diff --git a/augur/filter/constants.py b/augur/filter/constants.py index 66c107e86..2ad77113d 100644 --- a/augur/filter/constants.py +++ b/augur/filter/constants.py @@ -1,5 +1,6 @@ # Constants set at run time. RUNTIME_DB_FILE: str = None +RUNTIME_DEBUG: bool = False # ID column used for all tables defined internally. diff --git a/augur/filter/debug.py b/augur/filter/debug.py new file mode 100644 index 000000000..cef830955 --- /dev/null +++ b/augur/filter/debug.py @@ -0,0 +1,6 @@ +from . import constants + + +def print_debug(message): + if constants.RUNTIME_DEBUG: + print(f"DEBUG: {message}") diff --git a/augur/filter/io.py b/augur/filter/io.py index a9acb97f6..41db143cf 100644 --- a/augur/filter/io.py +++ b/augur/filter/io.py @@ -23,6 +23,7 @@ from augur.io.tabular_file import InvalidDelimiter, TabularFile from augur.io.vcf import is_vcf, write_vcf from . import constants +from .debug import print_debug from .include_exclude_rules import extract_variables, parse_filter_query @@ -497,3 +498,37 @@ def _get_strains_to_drop_from_vcf(): ) """) return {str(row[SEQUENCE_INDEX_ID_COLUMN]) for row in result} + + +def print_db_report(): + if not constants.RUNTIME_DEBUG: + return + + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + result = db.connection.execute(f""" + SELECT + name, + SUM(pgsize) AS size + FROM dbstat + GROUP BY name; + """) + rows = result.fetchall() + + print_debug(f'The total size of the database was {_human_readable_size(sum(int(row["size"]) for row in rows))}. Breakdown:') + + for row in sorted(rows, key=lambda row: int(row["size"]), reverse=True): + print_debug(f'{_human_readable_size(row["size"]): >10} {row["name"]}') + + +def _human_readable_size(bytes: int, decimal_places=1): + """Return size in bytes as a human-readable string using larger units. + + Adapted from https://stackoverflow.com/a/43690506 + """ + size = float(bytes) + units = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB'] + for unit in units: + if size < 1024.0 or unit == units[-1]: + break + size /= 1024.0 + return f"{size:.{decimal_places}f} {unit}" diff --git a/docs/api/developer/augur.filter.debug.rst b/docs/api/developer/augur.filter.debug.rst new file mode 100644 index 000000000..5f0eee535 --- /dev/null +++ b/docs/api/developer/augur.filter.debug.rst @@ -0,0 +1,7 @@ +augur.filter.debug module +========================= + +.. 
automodule:: augur.filter.debug + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/developer/augur.filter.rst b/docs/api/developer/augur.filter.rst index 0e12aa9fd..329e3710d 100644 --- a/docs/api/developer/augur.filter.rst +++ b/docs/api/developer/augur.filter.rst @@ -14,6 +14,7 @@ Submodules augur.filter.constants augur.filter.dates + augur.filter.debug augur.filter.include_exclude_rules augur.filter.io augur.filter.report From 260343d725e0204c96cda3fb0dd681e73f1ac639 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 7 Jul 2023 13:51:22 -0700 Subject: [PATCH 13/15] =?UTF-8?q?=F0=9F=9A=A7=20Add=20timing=20output=20to?= =?UTF-8?q?=20--debug=20for=20long-running=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- augur/filter/dates.py | 2 ++ augur/filter/debug.py | 20 ++++++++++++++++++++ augur/filter/include_exclude_rules.py | 2 ++ augur/filter/io.py | 5 ++++- augur/filter/subsample.py | 2 ++ 5 files changed, 30 insertions(+), 1 deletion(-) diff --git a/augur/filter/dates.py b/augur/filter/dates.py index b29ec8830..8fe479235 100644 --- a/augur/filter/dates.py +++ b/augur/filter/dates.py @@ -3,11 +3,13 @@ from augur.dates import get_numerical_date_from_value from augur.dates.errors import InvalidDate from augur.errors import AugurError +from augur.filter.debug import add_debugging from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier from . import constants +@add_debugging def parse_dates(): """Validate dates and create a date table.""" # First, determine if there is a date column. diff --git a/augur/filter/debug.py b/augur/filter/debug.py index cef830955..868b4891e 100644 --- a/augur/filter/debug.py +++ b/augur/filter/debug.py @@ -1,6 +1,26 @@ +from functools import wraps +import time +from typing import Callable + from . import constants def print_debug(message): if constants.RUNTIME_DEBUG: print(f"DEBUG: {message}") + + +def add_debugging(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + if constants.RUNTIME_DEBUG: + start_time = time.perf_counter() + print_debug(f'Starting {func.__name__}.') + result = func(*args, **kwargs) + end_time = time.perf_counter() + total_time = end_time - start_time + print_debug(f'Function {func.__name__} finished in {total_time:.4f} seconds.') + return result + else: + return func(*args, **kwargs) + return wrapper diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py index 619743c66..05c50239c 100644 --- a/augur/filter/include_exclude_rules.py +++ b/augur/filter/include_exclude_rules.py @@ -12,6 +12,7 @@ from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier from augur.io.vcf import is_vcf as filename_is_vcf from . import constants +from .debug import add_debugging try: # pandas ≥1.5.0 only @@ -667,6 +668,7 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: return exclude_by, include_by +@add_debugging def apply_filters(exclude_by: List[FilterOption], include_by: List[FilterOption]): """Apply exclusion and force-inclusion rules to filter strains from the metadata.""" init_filter_reason_table() diff --git a/augur/filter/io.py b/augur/filter/io.py index 41db143cf..ac5f9f3f0 100644 --- a/augur/filter/io.py +++ b/augur/filter/io.py @@ -23,7 +23,7 @@ from augur.io.tabular_file import InvalidDelimiter, TabularFile from augur.io.vcf import is_vcf, write_vcf from . 
import constants -from .debug import print_debug +from .debug import print_debug, add_debugging from .include_exclude_rules import extract_variables, parse_filter_query @@ -123,6 +123,7 @@ def _validate_priorities_table(): raise ValueError(f"Priority score for strain '{row[constants.ID_COLUMN]}' ('{row[constants.PRIORITY_COLUMN]}') is not a valid number.") +@add_debugging def _write_metadata_based_outputs(input_metadata_path: str, delimiters: Sequence[str], id_columns: Sequence[str], output_metadata_path: str, output_strains_path: str): @@ -229,6 +230,7 @@ def _import_tabular_file(file: TabularFile, db: Sqlite3Database, table: str, col db.insert(table, columns, file.rows()) +@add_debugging def import_metadata(metadata: Metadata, columns): """Import metadata into the database.""" with Sqlite3Database(constants.RUNTIME_DB_FILE, mode="rw") as db: @@ -366,6 +368,7 @@ def write_outputs(args): _output_log(args.output_log) +@add_debugging def _read_and_output_sequences(args): """Read sequences and output all that passed filtering. """ diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py index a4f493867..0b9c80e5d 100644 --- a/augur/filter/subsample.py +++ b/augur/filter/subsample.py @@ -6,6 +6,7 @@ from typing import Collection, Dict, Iterable, List, Optional, Sequence, Set, Tuple from augur.errors import AugurError +from augur.filter.debug import add_debugging from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.print import print_err from augur.io.sqlite3 import Sqlite3Database, sanitize_identifier @@ -448,6 +449,7 @@ def _calculate_sequences_per_group( return int(lo) +@add_debugging def apply_subsampling(args): """Apply subsampling to update the filter reason table. From a0a8e81f238fb9b71ff8373dd8388cc3e97bcecc Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 2 Feb 2024 15:12:37 -0800 Subject: [PATCH 14/15] =?UTF-8?q?=F0=9F=9A=A7=20filter:=20Add=20--query-sq?= =?UTF-8?q?lite=20option?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds a new flag to query the SQLite database natively. `--query`/`--query-pandas` will still behave as expected. All Pandas-based query functions are renamed to be Pandas-specific. To avoid breaking changes, alias `--query` to `--query-pandas`. --- augur/filter/__init__.py | 10 +- augur/filter/include_exclude_rules.py | 145 ++++++++++++++++-- augur/filter/io.py | 19 ++- augur/filter/report.py | 3 +- mypy.ini | 3 + setup.py | 1 + ...ilter-metadata-sequence-strains-mismatch.t | 6 +- ...ter-query-and-exclude-ambiguous-dates-by.t | 2 +- .../filter/cram/filter-query-columns.t | 6 +- .../filter/cram/filter-query-errors.t | 45 ++++-- .../filter/cram/filter-query-numerical.t | 32 +++- 11 files changed, 237 insertions(+), 35 deletions(-) diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index 3848629ae..b02f56e7b 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -19,17 +19,23 @@ def register_arguments(parser): input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata") input_group.add_argument('--sequences', '-s', help="sequences in FASTA or VCF format") input_group.add_argument('--sequence-index', help="sequence composition report generated by augur index. If not provided, an index will be created on the fly.") - input_group.add_argument('--metadata-chunk-size', type=int, default=100000, help="maximum number of metadata records to read into memory at a time. 
Increasing this number can speed up filtering at the cost of more memory used.") + input_group.add_argument('--metadata-chunk-size', type=int, default=100000, help="maximum number of metadata records to read into memory at a time. Increasing this number can speed up filtering at the cost of more memory used. NOTE: this only applies to --query/--query-pandas.") input_group.add_argument('--metadata-id-columns', default=DEFAULT_ID_COLUMNS, nargs="+", action=ExtendOverwriteDefault, help="names of possible metadata columns containing identifier information, ordered by priority. Only one ID column will be inferred.") input_group.add_argument('--metadata-delimiters', default=DEFAULT_DELIMITERS, nargs="+", action=ExtendOverwriteDefault, help="delimiters to accept when reading a metadata file. Only one delimiter will be inferred.") metadata_filter_group = parser.add_argument_group("metadata filters", "filters to apply to metadata") metadata_filter_group.add_argument( - '--query', + '--query-pandas', '--query', help="""Filter samples by attribute. Uses Pandas Dataframe querying, see https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query for syntax. (e.g., --query "country == 'Colombia'" or --query "(country == 'USA' & (division == 'Washington'))")""" ) + metadata_filter_group.add_argument( + '--query-sqlite', + help="""Filter samples by attribute. + Uses SQL WHERE clause querying, see https://www.sqlite.org/lang_expr.html for syntax. + (e.g., --query "country = 'Colombia'" or --query "(country = 'USA' AND division = 'Washington')")""" + ) metadata_filter_group.add_argument('--query-columns', type=column_type_pair, nargs="+", action="extend", help=f""" Use alongside --query to specify columns and data types in the format 'column:type', where type is one of ({','.join(ACCEPTED_TYPES)}). Automatic type inference will be attempted on all unspecified columns used in the query. diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py index 05c50239c..3d539e6f5 100644 --- a/augur/filter/include_exclude_rules.py +++ b/augur/filter/include_exclude_rules.py @@ -3,6 +3,7 @@ import re import pandas as pd import sqlite3 +import sqlparse from typing import Any, Callable, Dict, List, Optional, Tuple from augur.errors import AugurError from augur.index import ID_COLUMN as SEQUENCE_INDEX_ID_COLUMN @@ -148,9 +149,66 @@ def filter_by_exclude_where(exclude_where) -> FilterFunctionReturn: return expression, parameters -def filter_by_query(query: str, chunksize: int, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: +def filter_by_sqlite_query(query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: + """Filter by any valid SQLite expression on the metadata. + + Strains that do *not* match the query will be excluded. + + Parameters + ---------- + query : str + SQL expression used to exclude strains + column_types : str + Dict mapping of data type + """ + with Sqlite3Database(constants.RUNTIME_DB_FILE) as db: + metadata_id_column = db.get_primary_index(constants.METADATA_TABLE) + metadata_columns = set(db.columns(constants.METADATA_TABLE)) + + if column_types is None: + column_types = {} + + # Set columns for type conversion. + variables = extract_potential_sqlite_variables(query) + if variables is not None: + columns = variables.intersection(metadata_columns) + else: + # Column extraction failed. Apply type conversion to all columns. 
+ columns = metadata_columns + + # If a type is not explicitly provided, try converting the column to numeric. + # This should cover most use cases, since one common problem is that the + # built-in data type inference when loading the DataFrame does not + # support nullable numeric columns, so numeric comparisons won't work on + # those columns. pd.to_numeric does proper conversion on those columns, + # and will not make any changes to columns with other values. + for column in columns: + column_types.setdefault(column, 'numeric') + + # FIXME: Apply column_types. + # It's not easy to change the type on the table schema.¹ + # Maybe using CAST? But that always takes place even if the conversion is lossy + # and irreversible (i.e. no error handling options like pd.to_numeric). + # ¹ + # ² + + expression = f""" + {constants.ID_COLUMN} IN ( + SELECT {sanitize_identifier(metadata_id_column)} + FROM {constants.METADATA_TABLE} + WHERE NOT ({query}) + ) + """ + parameters: SqlParameters = {} + return expression, parameters + + +def filter_by_pandas_query(query: str, chunksize: int, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: """Filter by a Pandas expression on the metadata. + Note that this is inefficient compared to native SQLite queries, and is in place + for backwards compatibility. + Parameters ---------- query : str @@ -175,7 +233,7 @@ def filter_by_query(query: str, chunksize: int, column_types: Optional[Dict[str, column_types = {} # Set columns for type conversion. - variables = extract_variables(query) + variables = extract_pandas_query_variables(query) if variables is not None: columns = variables.intersection(metadata_columns) else: @@ -574,17 +632,26 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]: {"exclude_where": exclude_where} )) - # Exclude strains by metadata, using pandas querying. - if args.query: + # Exclude strains by metadata. + if args.query_pandas: kwargs = { - "query": args.query, + "query": args.query_pandas, "chunksize": args.metadata_chunk_size, } if args.query_columns: kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns} exclude_by.append(( - filter_by_query, + filter_by_pandas_query, + kwargs + )) + if args.query_sqlite: + kwargs = {"query": args.query_sqlite} + if args.query_columns: + kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns} + + exclude_by.append(( + filter_by_sqlite_query, kwargs )) @@ -816,20 +883,20 @@ def filter_kwargs_to_str(kwargs: FilterFunctionKwargs): return json.dumps(kwarg_list) -def extract_variables(pandas_query: str): +def extract_pandas_query_variables(pandas_query: str): """Try extracting all variable names used in a pandas query string. If successful, return the variable names as a set. Otherwise, nothing is returned. Examples -------- - >>> extract_variables("var1 == 'value'") + >>> extract_pandas_query_variables("var1 == 'value'") {'var1'} - >>> sorted(extract_variables("var1 == 'value' & var2 == 10")) + >>> sorted(extract_pandas_query_variables("var1 == 'value' & var2 == 10")) ['var1', 'var2'] - >>> extract_variables("var1.str.startswith('prefix')") + >>> extract_pandas_query_variables("var1.str.startswith('prefix')") {'var1'} - >>> extract_variables("this query is invalid") + >>> extract_pandas_query_variables("this query is invalid") Backtick quoting is also supported. 
@@ -882,3 +949,59 @@ def replace(match: re.Match): modified_query = re.sub(pattern, replace, pandas_query) return modified_query, replacements + + +def extract_potential_sqlite_variables(sqlite_expression: str): + """Try extracting all variable names used in a SQLite expression. + + If successful, return the variable names as a set. Otherwise, nothing is returned. + + Examples + -------- + >>> extract_potential_sqlite_variables("var1 = 'value'") + {'var1'} + >>> sorted(extract_potential_sqlite_variables("var1 = 'value' AND var2 = 10")) + ['var1', 'var2'] + >>> extract_potential_sqlite_variables("var1 LIKE 'prefix%'") + {'var1'} + >>> sorted(extract_potential_sqlite_variables("this query is invalid")) + ['invalid', 'this query'] + """ + # This seems to be more difficult than Pandas query parsing. + # + try: + query = f"SELECT * FROM table WHERE {sqlite_expression}" + where = [x for x in sqlparse.parse(query)[0] if isinstance(x, sqlparse.sql.Where)][0] + variables = set(_get_identifiers(where)) or None + return variables + except: + return None + + +def _get_identifiers(token: sqlparse.sql.Token): + """Yield identifiers from a token's children. + + Inspired by ast.walk. + """ + from collections import deque + todo = deque([token]) + while todo: + node = todo.popleft() + + # Limit to comparisons to avoid false positives. + # I chose not to use this because it also comes with false negatives. + # + # if isinstance(node, sqlparse.sql.Comparison): + # if isinstance(node.left, sqlparse.sql.Identifier): + # yield str(node.left) + # elif hasattr(node.left, 'tokens'): + # todo.extend(node.left.tokens) + # if isinstance(node.right, sqlparse.sql.Identifier): + # yield str(node.right) + # elif hasattr(node.right, 'tokens'): + # todo.extend(node.right.tokens) + + if isinstance(node, sqlparse.sql.Identifier): + yield str(node) + elif hasattr(node, 'tokens'): + todo.extend(node.tokens) diff --git a/augur/filter/io.py b/augur/filter/io.py index ac5f9f3f0..64583f62f 100644 --- a/augur/filter/io.py +++ b/augur/filter/io.py @@ -24,7 +24,7 @@ from augur.io.vcf import is_vcf, write_vcf from . import constants from .debug import print_debug, add_debugging -from .include_exclude_rules import extract_variables, parse_filter_query +from .include_exclude_rules import extract_pandas_query_variables, extract_potential_sqlite_variables, parse_filter_query def get_useful_metadata_columns(args: Namespace, id_column: str, all_columns: Sequence[str]): @@ -67,14 +67,14 @@ def get_useful_metadata_columns(args: Namespace, id_column: str, all_columns: Se columns.add(column) # Add columns used in Pandas queries. - if args.query: + if args.query_pandas: if args.query_columns: # Use column names explicitly specified by the user. for column, dtype in args.query_columns: columns.add(column) # Attempt to automatically extract columns from the query. - variables = extract_variables(args.query) + variables = extract_pandas_query_variables(args.query_pandas) if variables is None and not args.query_columns: print_err(dedent(f"""\ WARNING: Could not infer columns from the pandas query. Reading all metadata columns, @@ -88,6 +88,19 @@ def get_useful_metadata_columns(args: Namespace, id_column: str, all_columns: Se columns.update(all_columns) else: columns.update(variables) + + if args.query_sqlite: + if args.query_columns: + # Use column names explicitly specified by the user. + for column, dtype in args.query_columns: + columns.add(column) + + # Attempt to automatically extract columns from the query. 
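+        # Unlike the pandas path above, a failure to extract columns here is
+        # treated as an error rather than falling back to reading all metadata
+        # columns.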
+ variables = extract_potential_sqlite_variables(args.query_sqlite) + if variables is None and not args.query_columns: + raise AugurError("Could not infer columns from the SQLite query. If the query is valid, please specify columns using --query-columns.") + else: + columns.update(variables) return list(columns) diff --git a/augur/filter/report.py b/augur/filter/report.py index 04a3c61de..b46cfd46b 100644 --- a/augur/filter/report.py +++ b/augur/filter/report.py @@ -26,7 +26,8 @@ def print_report(args): include_exclude_rules.filter_by_exclude_all.__name__: "{count} {were} dropped by `--exclude-all`", include_exclude_rules.filter_by_exclude.__name__: "{count} {were} dropped because {they} {were} in {exclude_file}", include_exclude_rules.filter_by_exclude_where.__name__: "{count} {were} dropped because of '{exclude_where}'", - include_exclude_rules.filter_by_query.__name__: "{count} {were} filtered out by the query: \"{query}\"", + include_exclude_rules.filter_by_sqlite_query.__name__: "{count} {were} filtered out by the SQLite query: \"{query}\"", + include_exclude_rules.filter_by_pandas_query.__name__: "{count} {were} filtered out by the Pandas query: \"{query}\"", include_exclude_rules.filter_by_ambiguous_date.__name__: "{count} {were} dropped because of their ambiguous date in {ambiguity}", include_exclude_rules.filter_by_min_date.__name__: "{count} {were} dropped because {they} {were} earlier than {min_date} or missing a date", include_exclude_rules.filter_by_max_date.__name__: "{count} {were} dropped because {they} {were} later than {max_date} or missing a date", diff --git a/mypy.ini b/mypy.ini index 56de27ade..5fff65bc5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -50,3 +50,6 @@ ignore_missing_imports = True [mypy-scipy.*] ignore_missing_imports = True + +[mypy-sqlparse.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index a924fe534..b01beee9a 100644 --- a/setup.py +++ b/setup.py @@ -66,6 +66,7 @@ "pyfastx >=1.0.0, <3.0", "python_calamine >=0.2.0", "scipy ==1.*", + "sqlparse ==0.4.*", "xopen[zstd] >=1.7.0, <3" # TODO: Deprecated, remove v1 support around November 2024 ], extras_require = { diff --git a/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t b/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t index 0ed962aad..7c790325b 100644 --- a/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t +++ b/tests/functional/filter/cram/filter-metadata-sequence-strains-mismatch.t @@ -19,7 +19,7 @@ The query initially filters 3 strains from Colombia, one of which is added back > --output-log filtered_log.tsv 4 strains were dropped during filtering 1 had no metadata - 2 were filtered out by the query: "country != 'Colombia'" + 2 were filtered out by the Pandas query: "country != 'Colombia'" 1 had no sequence data \\t1 was force-included because it was in .*include\.txt.* (re) 9 strains passed all filters @@ -27,6 +27,6 @@ The query initially filters 3 strains from Colombia, one of which is added back $ head -n 1 filtered_log.tsv; tail -n +2 filtered_log.tsv | sort -k 1,1 strain filter kwargs COL/FLR_00008/2015\tforce_include_strains\t"[[""include_file"", ""*/data/include.txt""]]" (esc) (glob) - COL/FLR_00024/2015 filter_by_query "[[""query"", ""country != 'Colombia'""]]" - Colombia/2016/ZC204Se filter_by_query "[[""query"", ""country != 'Colombia'""]]" + COL/FLR_00024/2015 filter_by_pandas_query "[[""query"", ""country != 'Colombia'""]]" + Colombia/2016/ZC204Se filter_by_pandas_query "[[""query"", ""country != 
'Colombia'""]]" HND/2016/HU_ME59 filter_by_sequence_index [] diff --git a/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t b/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t index be4a04061..573fcfe08 100644 --- a/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t +++ b/tests/functional/filter/cram/filter-query-and-exclude-ambiguous-dates-by.t @@ -22,5 +22,5 @@ Confirm that `--exclude-ambiguous-dates-by` works for all year only ambiguous da > --output-strains filtered_strains.txt 4 strains were dropped during filtering 3 were dropped because of their ambiguous date in any - 1 was filtered out by the query: "region=="Asia"" + 1 was filtered out by the Pandas query: "region=="Asia"" 0 strains passed all filters diff --git a/tests/functional/filter/cram/filter-query-columns.t b/tests/functional/filter/cram/filter-query-columns.t index 807456d22..15b6d89bc 100644 --- a/tests/functional/filter/cram/filter-query-columns.t +++ b/tests/functional/filter/cram/filter-query-columns.t @@ -19,7 +19,7 @@ Automatic inference works. > --query "coverage >= 0.95 & category == 'B'" \ > --output-strains filtered_strains.txt 3 strains were dropped during filtering - 3 were filtered out by the query: "coverage >= 0.95 & category == 'B'" + 3 were filtered out by the Pandas query: "coverage >= 0.95 & category == 'B'" 1 strain passed all filters Specifying coverage:float explicitly also works. @@ -30,7 +30,7 @@ Specifying coverage:float explicitly also works. > --query-columns coverage:float \ > --output-strains filtered_strains.txt 3 strains were dropped during filtering - 3 were filtered out by the query: "coverage >= 0.95 & category == 'B'" + 3 were filtered out by the Pandas query: "coverage >= 0.95 & category == 'B'" 1 strain passed all filters Specifying coverage:float category:str also works. @@ -41,7 +41,7 @@ Specifying coverage:float category:str also works. > --query-columns coverage:float category:str \ > --output-strains filtered_strains.txt 3 strains were dropped during filtering - \t3 were filtered out by the query: "coverage >= 0.95 & category == 'B'" (esc) + \t3 were filtered out by the Pandas query: "coverage >= 0.95 & category == 'B'" (esc) 1 strain passed all filters Specifying category:float does not work. diff --git a/tests/functional/filter/cram/filter-query-errors.t b/tests/functional/filter/cram/filter-query-errors.t index e1f35b731..0592576d8 100644 --- a/tests/functional/filter/cram/filter-query-errors.t +++ b/tests/functional/filter/cram/filter-query-errors.t @@ -2,44 +2,63 @@ Setup $ source "$TESTDIR"/_setup.sh -Using a pandas query with a nonexistent column results in a specific error. +Using a query with a nonexistent column results in a specific error. $ ${AUGUR} filter \ > --metadata "$TESTDIR/../data/metadata.tsv" \ - > --query "invalid == 'value'" \ + > --query-pandas "invalid == 'value'" \ + > --output-strains filtered_strains.txt > /dev/null + WARNING: Column 'invalid' does not exist in the metadata file. Ignoring it. + ERROR: Query contains a column that does not exist in metadata. + [2] + + $ ${AUGUR} filter \ + > --metadata "$TESTDIR/../data/metadata.tsv" \ + > --query-sqlite "invalid = 'value'" \ > --output-strains filtered_strains.txt > /dev/null WARNING: Column 'invalid' does not exist in the metadata file. This may cause subsequent errors. 
ERROR: Query contains a column that does not exist in metadata: name 'invalid' is not defined [2] -Using pandas queries with bad syntax results in meaningful errors. - -Some error messages from Pandas may be useful, so they are exposed: +Using bad syntax in some queries results in meaningful errors, so they are exposed: $ ${AUGUR} filter \ > --metadata "$TESTDIR/../data/metadata.tsv" \ - > --query "region >= 0.50" \ + > --query-pandas "region >= 0.50" \ > --output-strains filtered_strains.txt > /dev/null ERROR: Internal Pandas error when applying query: '>=' not supported between instances of 'str' and 'float' Ensure the syntax is valid per . [2] -However, other Pandas errors are not so helpful, so a link is provided for users to learn more about query syntax. +FIXME: SQLite is not strongly typed, so this does not result in a syntax error: $ ${AUGUR} filter \ > --metadata "$TESTDIR/../data/metadata.tsv" \ - > --query "country = 'value'" \ + > --query-sqlite "region >= 0.50" \ + > --output-strains filtered_strains.txt + 0 strains were dropped during filtering + 12 strains passed all filters + +However, other errors are not so helpful, so a link is provided for users to learn more about query syntax. + +Unlike SQLite, Pandas does not understand '='. + + $ ${AUGUR} filter \ + > --metadata "$TESTDIR/../data/metadata.tsv" \ + > --query-pandas "virus = 'zika'" \ > --output-strains filtered_strains.txt > /dev/null ERROR: Internal Pandas error when applying query: cannot assign without a target object Ensure the syntax is valid per . [2] +Nonsensical queries behave the same for both Pandas and SQLite. + $ ${AUGUR} filter \ > --metadata "$TESTDIR/../data/metadata.tsv" \ - > --query "some bad syntax" \ + > --query-pandas "some bad syntax" \ > --output-strains filtered_strains.txt > /dev/null WARNING: Could not infer columns from the pandas query. Reading all metadata columns, which may impact execution time. If the query is valid, please open a new issue: @@ -53,3 +72,11 @@ However, other Pandas errors are not so helpful, so a link is provided for users invalid syntax (, line 1) Ensure the syntax is valid per . [2] + + $ ${AUGUR} filter \ + > --metadata "$TESTDIR/../data/metadata.tsv" \ + > --query-sqlite "some bad syntax" \ + > --output-strains filtered_strains.txt > /dev/null + WARNING: Column 'bad syntax' does not exist in the metadata file. Ignoring it. + ERROR: Error when applying query. Ensure the syntax is valid per . + [2] diff --git a/tests/functional/filter/cram/filter-query-numerical.t b/tests/functional/filter/cram/filter-query-numerical.t index 875c04802..a7cbf1bb9 100644 --- a/tests/functional/filter/cram/filter-query-numerical.t +++ b/tests/functional/filter/cram/filter-query-numerical.t @@ -72,12 +72,40 @@ comparing strings, it's likely that SEQ3 will be dropped or errors arise. $ ${AUGUR} filter \ > --metadata metadata.tsv \ - > --query "metric1 > 4 & metric1 < metric2" \ + > --query-pandas "metric1 > 4 & metric1 < metric2" \ > --output-strains filtered_strains.txt 1 strain was dropped during filtering - 1 was filtered out by the query: "metric1 > 4 & metric1 < metric2" + 1 was filtered out by the Pandas query: "metric1 > 4 & metric1 < metric2" 2 strains passed all filters $ sort filtered_strains.txt SEQ2 SEQ3 + +Do the same with a SQLite query. +Currently this does not work as expected because the type affinities are +hardcoded to TEXT, and type conversion to NUMERIC only happens when one of the +operands are NUMERIC. 
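+
+(A possible workaround, not exercised here, would be to cast both operands
+explicitly, e.g. --query-sqlite "CAST(metric1 AS NUMERIC) < CAST(metric2 AS NUMERIC)",
+since an explicit cast should make SQLite compare the values numerically.)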
+ + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query-sqlite "metric1 > 4 AND metric1 < metric2" \ + > --output-strains filtered_strains.txt + 2 strains were dropped during filtering + 2 were filtered out by the SQLite query: "metric1 > 4 AND metric1 < metric2" + 1 strain passed all filters + $ cat filtered_strains.txt + SEQ2 + +However, a numerical comparison between a column and a numeric literal works as expected. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query-sqlite "metric1 > 4" \ + > --output-strains filtered_strains.txt + 1 strain was dropped during filtering + 1 was filtered out by the SQLite query: "metric1 > 4" + 2 strains passed all filters + $ cat filtered_strains.txt + SEQ2 + SEQ3 From 99c5a0d3def67a09bf2dcadddaf7f56cea4a0b54 Mon Sep 17 00:00:00 2001 From: Victor Lin <13424970+victorlin@users.noreply.github.com> Date: Fri, 6 Sep 2024 16:49:22 -0700 Subject: [PATCH 15/15] =?UTF-8?q?=F0=9F=9A=A7=20support=20db=20file=20as?= =?UTF-8?q?=20metadata=20index?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- augur/filter/__init__.py | 2 ++ augur/filter/_run.py | 34 +++++++++++++++++++++------------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index b02f56e7b..0a7d107ba 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -18,6 +18,7 @@ def register_arguments(parser): input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered") input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata") input_group.add_argument('--sequences', '-s', help="sequences in FASTA or VCF format") + input_group.add_argument('--metadata-index', metavar="FILE", help="SQLite3 database file with metadata preloaded") input_group.add_argument('--sequence-index', help="sequence composition report generated by augur index. If not provided, an index will be created on the fly.") input_group.add_argument('--metadata-chunk-size', type=int, default=100000, help="maximum number of metadata records to read into memory at a time. Increasing this number can speed up filtering at the cost of more memory used. NOTE: this only applies to --query/--query-pandas.") input_group.add_argument('--metadata-id-columns', default=DEFAULT_ID_COLUMNS, nargs="+", action=ExtendOverwriteDefault, help="names of possible metadata columns containing identifier information, ordered by priority. Only one ID column will be inferred.") @@ -110,6 +111,7 @@ def register_arguments(parser): output_group = parser.add_argument_group("outputs", "options related to outputs, at least one of the possible representations of filtered data (--output, --output-metadata, --output-strains) is required") output_group.add_argument('--output', '--output-sequences', '-o', help="filtered sequences in FASTA format") output_group.add_argument('--output-metadata', help="metadata for strains that passed filters") + output_group.add_argument('--output-metadata-index', help="SQLite3 database file with metadata preloaded") output_group.add_argument('--output-strains', help="list of strains that passed filters (no header)") output_group.add_argument('--output-log', help="tab-delimited file with one row for each filtered strain and the reason it was filtered. 
Keyword arguments used for a given filter are reported in JSON format in a `kwargs` column.") output_group.add_argument('--output-group-by-sizes', help="tab-delimited file one row per group with target size.") diff --git a/augur/filter/_run.py b/augur/filter/_run.py index 3e825de8c..e08db6451 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -1,4 +1,5 @@ from argparse import Namespace +import shutil from tempfile import NamedTemporaryFile from augur.errors import AugurError @@ -26,22 +27,29 @@ def run(args: Namespace): print_debug(f"Temporary database file: {constants.RUNTIME_DB_FILE!r}") - initialize_input_source_table() + if args.metadata_index: + shutil.copyfile(args.metadata_index, file.name) + else: + initialize_input_source_table() - try: - metadata = Metadata(args.metadata, id_columns=args.metadata_id_columns, delimiters=args.metadata_delimiters) - except InvalidDelimiter: - raise AugurError( - f"Could not determine the delimiter of {args.metadata!r}. " - f"Valid delimiters are: {args.metadata_delimiters!r}. " - "This can be changed with --metadata-delimiters." - ) - columns = get_useful_metadata_columns(args, metadata.id_column, metadata.columns) - import_metadata(metadata, columns) + try: + metadata = Metadata(args.metadata, id_columns=args.metadata_id_columns, delimiters=args.metadata_delimiters) + except InvalidDelimiter: + raise AugurError( + f"Could not determine the delimiter of {args.metadata!r}. " + f"Valid delimiters are: {args.metadata_delimiters!r}. " + "This can be changed with --metadata-delimiters." + ) + columns = get_useful_metadata_columns(args, metadata.id_column, metadata.columns) + import_metadata(metadata, columns) - import_sequence_index(args) + import_sequence_index(args) - parse_dates() + parse_dates() + + if args.output_metadata_index: + print(f"Saving database file to {args.output_metadata_index!r}") + shutil.copyfile(file.name, args.output_metadata_index) exclude_by, include_by = construct_filters(args) apply_filters(exclude_by, include_by)
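
A minimal sketch of how these two work-in-progress options are intended to fit
together, based on the argument definitions and run() changes above. File names are
illustrative, and --metadata is still required by the argument parser even when
--metadata-index is supplied:

    # First run: import the metadata once and save the populated SQLite database.
    augur filter \
        --metadata metadata.tsv \
        --output-metadata-index metadata.sqlite3 \
        --output-strains all_strains.txt

    # Subsequent runs: reuse the saved database instead of re-importing the metadata.
    augur filter \
        --metadata metadata.tsv \
        --metadata-index metadata.sqlite3 \
        --query-sqlite "region = 'Asia'" \
        --output-strains asia_strains.txt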