Skip to content

Commit

Permalink
Refactor subsampling into its own function
Browse files Browse the repository at this point in the history
Adds a new function `subsample` with the logic originally defined in
`run`. This new function allows us to reorganize the `run` function more
substantially and confidently split the subsampling logic into more
manageable components.
  • Loading branch information
huddlej committed Jul 10, 2021
1 parent dbf000f commit 86ef7a1
Showing 1 changed file with 194 additions and 147 deletions.
341 changes: 194 additions & 147 deletions augur/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,188 @@ def include_by_query(metadata, include_where):
return set(metadata[included].index.values)


def subsample(metadata,
strains_to_keep,
group_by=None,
sequences_per_group=None,
max_sequences=None,
probabilistic_sampling=False,
priority=None,
random_seed=None):
"""Subsample metadata into a fixed number of strains per group. If the user
specifies a maximum number of subsampled strains, calculate the
corresponding sequences per group for the available groups. If no group is
defined, use a dummy group.
Optionally, sorts strains by a given priority score instead of returning
random strains for each group.
Parameters
----------
metadata : pandas.DataFrame
Metadata to subsample
strains_to_keep : set[str]
Strain names to consider for subsampling from the given metadata
group_by : list[str]
Column(s) to group metadata by prior to subsampling. When omitted, subsampling uses a "_dummy" group.
sequences_per_group : int
Number of sequences to sample per group
max_sequences : int
Maximum number of sequences to sample total. Mutually exclusive of ``sequences_per_group``.
probabilistic_sampling : bool
Enable probabilistic subsampling
priority : str
Name of a tab-delimited file containing priorities assigned to strains such that higher numbers indicate higher priority.
random_seed : str
Random seed for subsampling
Returns
-------
set[str]:
Strains that pass the filter
"""
# Set the random seed, for more reproducible results.
if random_seed:
random.seed(random_seed)

# Disable probabilistic sampling when user's request a specific number of
# sequences per group. In this case, users expect deterministic behavior and
# probabilistic behavior is surprising.
if sequences_per_group:
probabilistic_sampling = False

if group_by:
groups = group_by
else:
groups = ["_dummy"]

spg = sequences_per_group
seq_names_by_group = defaultdict(list)

for seq_name in strains_to_keep:
group = []
m = metadata.loc[seq_name].to_dict()
# collect group specifiers
for c in groups:
if c == "_dummy":
group.append(c)
elif c in m:
group.append(m[c])
elif c in ['month', 'year'] and 'date' in m:
try:
year = int(m["date"].split('-')[0])
except:
print("WARNING: no valid year, skipping",seq_name, m["date"])
continue
if c=='month':
try:
month = int(m["date"].split('-')[1])
except:
month = random.randint(1,12)
group.append((year, month))
else:
group.append(year)
else:
group.append('unknown')
seq_names_by_group[tuple(group)].append(seq_name)

#If didnt find any categories specified, all seqs will be in 'unknown' - but don't sample this!
if len(seq_names_by_group)==1 and ('unknown' in seq_names_by_group or ('unknown',) in seq_names_by_group):
print("WARNING: The specified group-by categories (%s) were not found."%groups,
"No sequences-per-group sampling will be done.")
if any([x in groups for x in ['year','month']]):
print("Note that using 'year' or 'year month' requires a column called 'date'.")
print("\n")
return strains_to_keep

# Check to see if some categories are missing to warn the user
group_by = set(['date' if cat in ['year','month'] else cat
for cat in groups])
missing_cats = [cat for cat in group_by if cat not in metadata.columns.values and cat != "_dummy"]
if missing_cats:
print("WARNING:")
if any([cat != 'date' for cat in missing_cats]):
print("\tSome of the specified group-by categories couldn't be found: ",
", ".join([str(cat) for cat in missing_cats if cat != 'date']))
if any([cat == 'date' for cat in missing_cats]):
print("\tA 'date' column could not be found to group-by year or month.")
print("\tFiltering by group may behave differently than expected!\n")

if priority: # read priorities
priorities = read_priority_scores(priority)

if spg is None:
# this is only possible if we have imposed a maximum number of samples
# to produce. we need binary search until we have the correct spg.
try:
length_of_sequences_per_group = [
len(sequences_in_group)
for sequences_in_group in seq_names_by_group.values()
]

if probabilistic_sampling:
spg = _calculate_fractional_sequences_per_group(
max_sequences,
length_of_sequences_per_group
)
else:
spg = _calculate_sequences_per_group(
max_sequences,
length_of_sequences_per_group
)
except TooManyGroupsError as ex:
print(f"ERROR: {ex}", file=sys.stderr)
sys.exit(1)
print("sampling at {} per group.".format(spg))

if probabilistic_sampling:
random_generator = np.random.default_rng()

# subsample each groups, either by taking the spg highest priority strains or
# sampling at random from the sequences in the group
seq_subsample = set()
subsampling_attempts = 0

# Attempt to subsample with the given constraints for a fixed number
# of times. For small values of maximum sequences, subsampling can
# randomly select zero sequences to keep. When this happens, we can
# usually find a non-zero number of samples by repeating the
# process.
while len(seq_subsample) == 0 and subsampling_attempts < MAX_NUMBER_OF_PROBABILISTIC_SAMPLING_ATTEMPTS:
subsampling_attempts += 1

for group, sequences_in_group in seq_names_by_group.items():
if probabilistic_sampling:
tmp_spg = random_generator.poisson(spg)
else:
tmp_spg = spg

if tmp_spg == 0:
continue

if priority: #sort descending by priority
seq_subsample.update(
set(
sorted(
sequences_in_group,
key=lambda x: priorities[x],
reverse=True
)[:tmp_spg]
)
)
else:
seq_subsample.update(
set(
sequences_in_group
if len(sequences_in_group)<=tmp_spg
else random.sample(sequences_in_group, tmp_spg)
)
)

return seq_subsample


def register_arguments(parser):
input_group = parser.add_argument_group("inputs", "metadata and sequences to be filtered")
input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata, as CSV or TSV")
Expand Down Expand Up @@ -654,155 +836,20 @@ def run(args):
num_excluded_by_nuc = len(seq_keep - filtered)
seq_keep = filtered

# subsampling. This will sort sequences into groups by meta data fields
# specified in --group-by and then take at most --sequences-per-group
# from each group. Within each group, sequences are optionally sorted
# by a priority score specified in a file --priority
# Fix seed for the RNG if specified
if args.subsample_seed:
random.seed(args.subsample_seed)
num_excluded_subsamp = 0

# Disable probabilistic sampling when user's request a specific number of
# sequences per group. In this case, users expect deterministic behavior and
# probabilistic behavior is surprising.
probabilistic_sampling = args.probabilistic_sampling
if args.sequences_per_group:
probabilistic_sampling = False

if args.subsample_max_sequences or (args.group_by and args.sequences_per_group):

#set groups to group_by values
if args.group_by:
groups = args.group_by
#if group_by not specified use dummy category
else:
groups = ["_dummy"]

spg = args.sequences_per_group
seq_names_by_group = defaultdict(list)

for seq_name in seq_keep:
group = []
m = metadata.loc[seq_name].to_dict()
# collect group specifiers
for c in groups:
if c == "_dummy":
group.append(c)
elif c in m:
group.append(m[c])
elif c in ['month', 'year'] and 'date' in m:
try:
year = int(m["date"].split('-')[0])
except:
print("WARNING: no valid year, skipping",seq_name, m["date"])
continue
if c=='month':
try:
month = int(m["date"].split('-')[1])
except:
month = random.randint(1,12)
group.append((year, month))
else:
group.append(year)
else:
group.append('unknown')
seq_names_by_group[tuple(group)].append(seq_name)

#If didnt find any categories specified, all seqs will be in 'unknown' - but don't sample this!
if len(seq_names_by_group)==1 and ('unknown' in seq_names_by_group or ('unknown',) in seq_names_by_group):
print("WARNING: The specified group-by categories (%s) were not found."%groups,
"No sequences-per-group sampling will be done.")
if any([x in groups for x in ['year','month']]):
print("Note that using 'year' or 'year month' requires a column called 'date'.")
print("\n")
else:
# Check to see if some categories are missing to warn the user
group_by = set(['date' if cat in ['year','month'] else cat
for cat in groups])
missing_cats = [cat for cat in group_by if cat not in meta_columns and cat != "_dummy"]
if missing_cats:
print("WARNING:")
if any([cat != 'date' for cat in missing_cats]):
print("\tSome of the specified group-by categories couldn't be found: ",
", ".join([str(cat) for cat in missing_cats if cat != 'date']))
if any([cat == 'date' for cat in missing_cats]):
print("\tA 'date' column could not be found to group-by year or month.")
print("\tFiltering by group may behave differently than expected!\n")

if args.priority: # read priorities
priorities = read_priority_scores(args.priority)

if spg is None:
# this is only possible if we have imposed a maximum number of samples
# to produce. we need binary search until we have the correct spg.
try:
length_of_sequences_per_group = [
len(sequences_in_group)
for sequences_in_group in seq_names_by_group.values()
]

if probabilistic_sampling:
spg = _calculate_fractional_sequences_per_group(
args.subsample_max_sequences,
length_of_sequences_per_group
)
else:
spg = _calculate_sequences_per_group(
args.subsample_max_sequences,
length_of_sequences_per_group
)
except TooManyGroupsError as ex:
print(f"ERROR: {ex}", file=sys.stderr)
sys.exit(1)
print("sampling at {} per group.".format(spg))

if probabilistic_sampling:
random_generator = np.random.default_rng()

# subsample each groups, either by taking the spg highest priority strains or
# sampling at random from the sequences in the group
seq_subsample = set()
subsampling_attempts = 0

# Attempt to subsample with the given constraints for a fixed number
# of times. For small values of maximum sequences, subsampling can
# randomly select zero sequences to keep. When this happens, we can
# usually find a non-zero number of samples by repeating the
# process.
while len(seq_subsample) == 0 and subsampling_attempts < MAX_NUMBER_OF_PROBABILISTIC_SAMPLING_ATTEMPTS:
subsampling_attempts += 1

for group, sequences_in_group in seq_names_by_group.items():
if probabilistic_sampling:
tmp_spg = random_generator.poisson(spg)
else:
tmp_spg = spg

if tmp_spg == 0:
continue

if args.priority: #sort descending by priority
seq_subsample.update(
set(
sorted(
sequences_in_group,
key=lambda x: priorities[x],
reverse=True
)[:tmp_spg]
)
)
else:
seq_subsample.update(
set(
sequences_in_group
if len(sequences_in_group)<=tmp_spg
else random.sample(sequences_in_group, tmp_spg)
)
)

num_excluded_subsamp = len(seq_keep) - len(seq_subsample)
seq_keep = seq_subsample
seq_subsample = subsample(
metadata,
seq_keep,
args.group_by,
args.sequences_per_group,
args.subsample_max_sequences,
args.probabilistic_sampling,
args.priority,
args.subsample_seed,
)
num_excluded_subsamp = len(seq_keep) - len(seq_subsample)
seq_keep = seq_subsample

# force include sequences specified in file.
# Note that this might re-add previously excluded sequences
Expand Down

0 comments on commit 86ef7a1

Please sign in to comment.