Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor subsampling into its own function #746

Merged
merged 1 commit into from
Jul 16, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
341 changes: 194 additions & 147 deletions augur/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,188 @@ def include_by_query(metadata, include_where):
return set(metadata[included].index.values)


def subsample(metadata,
              strains_to_keep,
              group_by=None,
              sequences_per_group=None,
              max_sequences=None,
              probabilistic_sampling=False,
              priority=None,
              random_seed=None):
    """Subsample metadata into a fixed number of strains per group.

    If the user specifies a maximum number of subsampled strains, calculate the
    corresponding sequences per group for the available groups. If no group is
    defined, use a dummy group.

    Optionally, sorts strains by a given priority score instead of returning
    random strains for each group.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Metadata to subsample. Rows are looked up with ``metadata.loc[strain]``,
        so the frame must be indexed by strain name.
    strains_to_keep : set[str]
        Strain names to consider for subsampling from the given metadata
    group_by : list[str]
        Column(s) to group metadata by prior to subsampling. The special
        values ``"year"`` and ``"month"`` are derived from a ``date`` column
        when no column of that name exists. When omitted, subsampling uses a
        "_dummy" group.
    sequences_per_group : int
        Number of sequences to sample per group. When given, probabilistic
        sampling is disabled.
    max_sequences : int
        Maximum number of sequences to sample total. Mutually exclusive of
        ``sequences_per_group``.
    probabilistic_sampling : bool
        Enable probabilistic subsampling (per-group sizes are drawn from a
        Poisson distribution).
    priority : str
        Name of a tab-delimited file containing priorities assigned to strains
        such that higher numbers indicate higher priority.
    random_seed : int or str, optional
        Seed passed to ``random.seed`` for subsampling. Any value other than
        ``None`` (including ``0``) seeds the generator.

    Returns
    -------
    set[str]:
        Strains that pass the filter. When none of the requested group-by
        categories can be found, ``strains_to_keep`` is returned unchanged.
        Strains whose year cannot be parsed while grouping by year or month
        are skipped (with a warning) and never returned by the sampling path.

    Raises
    ------
    SystemExit
        When the per-group calculation raises ``TooManyGroupsError``.

    """
    # Set the random seed, for more reproducible results. Compare against None
    # so that an explicit seed of 0 is honored instead of silently ignored.
    if random_seed is not None:
        random.seed(random_seed)

    # Disable probabilistic sampling when users request a specific number of
    # sequences per group. In this case, users expect deterministic behavior and
    # probabilistic behavior is surprising.
    if sequences_per_group:
        probabilistic_sampling = False

    groups = group_by if group_by else ["_dummy"]

    spg = sequences_per_group
    seq_names_by_group = defaultdict(list)

    for seq_name in strains_to_keep:
        group = []
        m = metadata.loc[seq_name].to_dict()
        # Collect group specifiers. The ``else`` clause of this loop runs only
        # when no ``break`` happened, i.e. when the strain was not skipped.
        for c in groups:
            if c == "_dummy":
                group.append(c)
            elif c in m:
                group.append(m[c])
            elif c in ('month', 'year') and 'date' in m:
                try:
                    year = int(m["date"].split('-')[0])
                except (AttributeError, TypeError, ValueError):
                    # Non-string dates (e.g. missing values) or non-numeric
                    # years: drop the whole strain, as the warning promises.
                    print("WARNING: no valid year, skipping", seq_name, m["date"])
                    break
                if c == 'month':
                    try:
                        month = int(m["date"].split('-')[1])
                    except (IndexError, ValueError):
                        # Missing or ambiguous month: assign one at random.
                        month = random.randint(1, 12)
                    group.append((year, month))
                else:
                    group.append(year)
            else:
                group.append('unknown')
        else:
            seq_names_by_group[tuple(group)].append(seq_name)

    # If we didn't find any of the categories specified, all seqs will be in
    # 'unknown' - but don't sample this!
    if len(seq_names_by_group) == 1 and ('unknown' in seq_names_by_group or ('unknown',) in seq_names_by_group):
        print("WARNING: The specified group-by categories (%s) were not found." % (groups,),
              "No sequences-per-group sampling will be done.")
        if any(x in groups for x in ('year', 'month')):
            print("Note that using 'year' or 'year month' requires a column called 'date'.")
        print("\n")
        return strains_to_keep

    # Check to see if some categories are missing to warn the user. Year and
    # month are both derived from the 'date' column.
    required_columns = {'date' if cat in ('year', 'month') else cat
                        for cat in groups}
    missing_cats = [cat for cat in required_columns
                    if cat not in metadata.columns and cat != "_dummy"]
    if missing_cats:
        print("WARNING:")
        if any(cat != 'date' for cat in missing_cats):
            print("\tSome of the specified group-by categories couldn't be found: ",
                  ", ".join([str(cat) for cat in missing_cats if cat != 'date']))
        if any(cat == 'date' for cat in missing_cats):
            print("\tA 'date' column could not be found to group-by year or month.")
        print("\tFiltering by group may behave differently than expected!\n")

    # Nothing to sample from (no strains given, or every strain was skipped):
    # avoid computing per-group sizes for zero groups.
    if not seq_names_by_group:
        return set()

    if priority:  # read priorities
        priorities = read_priority_scores(priority)

    if spg is None:
        # This is only possible if we have imposed a maximum number of samples
        # to produce, so derive the sequences per group from that maximum.
        try:
            length_of_sequences_per_group = [
                len(sequences_in_group)
                for sequences_in_group in seq_names_by_group.values()
            ]

            if probabilistic_sampling:
                spg = _calculate_fractional_sequences_per_group(
                    max_sequences,
                    length_of_sequences_per_group
                )
            else:
                spg = _calculate_sequences_per_group(
                    max_sequences,
                    length_of_sequences_per_group
                )
        except TooManyGroupsError as ex:
            print(f"ERROR: {ex}", file=sys.stderr)
            sys.exit(1)
        print("sampling at {} per group.".format(spg))

    if probabilistic_sampling:
        # NOTE(review): this generator is not seeded from ``random_seed``, so
        # the Poisson draws below differ between runs even with a fixed seed.
        random_generator = np.random.default_rng()

    # Subsample each group, either by taking the spg highest priority strains
    # or sampling at random from the sequences in the group.
    seq_subsample = set()
    subsampling_attempts = 0

    # Attempt to subsample with the given constraints for a fixed number
    # of times. For small values of maximum sequences, subsampling can
    # randomly select zero sequences to keep. When this happens, we can
    # usually find a non-zero number of samples by repeating the
    # process.
    while len(seq_subsample) == 0 and subsampling_attempts < MAX_NUMBER_OF_PROBABILISTIC_SAMPLING_ATTEMPTS:
        subsampling_attempts += 1

        for sequences_in_group in seq_names_by_group.values():
            if probabilistic_sampling:
                tmp_spg = random_generator.poisson(spg)
            else:
                tmp_spg = spg

            if tmp_spg == 0:
                continue

            if priority:
                # Sort descending by priority and keep the top entries.
                seq_subsample.update(
                    sorted(
                        sequences_in_group,
                        key=lambda x: priorities[x],
                        reverse=True
                    )[:tmp_spg]
                )
            elif len(sequences_in_group) <= tmp_spg:
                # Group is no larger than the quota: keep all of it.
                seq_subsample.update(sequences_in_group)
            else:
                seq_subsample.update(random.sample(sequences_in_group, tmp_spg))

    return seq_subsample


def register_arguments(parser):
input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered")
input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata, as CSV or TSV")
Expand Down Expand Up @@ -654,155 +836,20 @@ def run(args):
num_excluded_by_nuc = len(seq_keep - filtered)
seq_keep = filtered

# subsampling. This will sort sequences into groups by meta data fields
# specified in --group-by and then take at most --sequences-per-group
# from each group. Within each group, sequences are optionally sorted
# by a priority score specified in a file --priority
# Fix seed for the RNG if specified
if args.subsample_seed:
random.seed(args.subsample_seed)
num_excluded_subsamp = 0

# Disable probabilistic sampling when user's request a specific number of
# sequences per group. In this case, users expect deterministic behavior and
# probabilistic behavior is surprising.
probabilistic_sampling = args.probabilistic_sampling
if args.sequences_per_group:
probabilistic_sampling = False

if args.subsample_max_sequences or (args.group_by and args.sequences_per_group):

#set groups to group_by values
if args.group_by:
groups = args.group_by
#if group_by not specified use dummy category
else:
groups = ["_dummy"]

spg = args.sequences_per_group
seq_names_by_group = defaultdict(list)

for seq_name in seq_keep:
group = []
m = metadata.loc[seq_name].to_dict()
# collect group specifiers
for c in groups:
if c == "_dummy":
group.append(c)
elif c in m:
group.append(m[c])
elif c in ['month', 'year'] and 'date' in m:
try:
year = int(m["date"].split('-')[0])
except:
print("WARNING: no valid year, skipping",seq_name, m["date"])
continue
if c=='month':
try:
month = int(m["date"].split('-')[1])
except:
month = random.randint(1,12)
group.append((year, month))
else:
group.append(year)
else:
group.append('unknown')
seq_names_by_group[tuple(group)].append(seq_name)

#If didnt find any categories specified, all seqs will be in 'unknown' - but don't sample this!
if len(seq_names_by_group)==1 and ('unknown' in seq_names_by_group or ('unknown',) in seq_names_by_group):
print("WARNING: The specified group-by categories (%s) were not found."%groups,
"No sequences-per-group sampling will be done.")
if any([x in groups for x in ['year','month']]):
print("Note that using 'year' or 'year month' requires a column called 'date'.")
print("\n")
else:
# Check to see if some categories are missing to warn the user
group_by = set(['date' if cat in ['year','month'] else cat
for cat in groups])
missing_cats = [cat for cat in group_by if cat not in meta_columns and cat != "_dummy"]
if missing_cats:
print("WARNING:")
if any([cat != 'date' for cat in missing_cats]):
print("\tSome of the specified group-by categories couldn't be found: ",
", ".join([str(cat) for cat in missing_cats if cat != 'date']))
if any([cat == 'date' for cat in missing_cats]):
print("\tA 'date' column could not be found to group-by year or month.")
print("\tFiltering by group may behave differently than expected!\n")

if args.priority: # read priorities
priorities = read_priority_scores(args.priority)

if spg is None:
# this is only possible if we have imposed a maximum number of samples
# to produce. we need binary search until we have the correct spg.
try:
length_of_sequences_per_group = [
len(sequences_in_group)
for sequences_in_group in seq_names_by_group.values()
]

if probabilistic_sampling:
spg = _calculate_fractional_sequences_per_group(
args.subsample_max_sequences,
length_of_sequences_per_group
)
else:
spg = _calculate_sequences_per_group(
args.subsample_max_sequences,
length_of_sequences_per_group
)
except TooManyGroupsError as ex:
print(f"ERROR: {ex}", file=sys.stderr)
sys.exit(1)
print("sampling at {} per group.".format(spg))

if probabilistic_sampling:
random_generator = np.random.default_rng()

# subsample each groups, either by taking the spg highest priority strains or
# sampling at random from the sequences in the group
seq_subsample = set()
subsampling_attempts = 0

# Attempt to subsample with the given constraints for a fixed number
# of times. For small values of maximum sequences, subsampling can
# randomly select zero sequences to keep. When this happens, we can
# usually find a non-zero number of samples by repeating the
# process.
while len(seq_subsample) == 0 and subsampling_attempts < MAX_NUMBER_OF_PROBABILISTIC_SAMPLING_ATTEMPTS:
subsampling_attempts += 1

for group, sequences_in_group in seq_names_by_group.items():
if probabilistic_sampling:
tmp_spg = random_generator.poisson(spg)
else:
tmp_spg = spg

if tmp_spg == 0:
continue

if args.priority: #sort descending by priority
seq_subsample.update(
set(
sorted(
sequences_in_group,
key=lambda x: priorities[x],
reverse=True
)[:tmp_spg]
)
)
else:
seq_subsample.update(
set(
sequences_in_group
if len(sequences_in_group)<=tmp_spg
else random.sample(sequences_in_group, tmp_spg)
)
)

num_excluded_subsamp = len(seq_keep) - len(seq_subsample)
seq_keep = seq_subsample
seq_subsample = subsample(
metadata,
seq_keep,
args.group_by,
args.sequences_per_group,
args.subsample_max_sequences,
args.probabilistic_sampling,
args.priority,
args.subsample_seed,
)
num_excluded_subsamp = len(seq_keep) - len(seq_subsample)
seq_keep = seq_subsample

# force include sequences specified in file.
# Note that this might re-add previously excluded sequences
Expand Down