Skip to content

Commit

Permalink
Support AudioSet training with weighted sampler (#1727)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoyang1998 authored Aug 22, 2024
1 parent 5952972 commit 3fc06cc
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 33 deletions.
36 changes: 30 additions & 6 deletions egs/audioset/AT/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,40 @@ python zipformer/train.py \
--master-port 13455
```

We recommend that you train the model with weighted sampler, as the model converges
faster with better performance:

| Model | mAP |
| ------ | ------- |
| Zipformer-AT, train with weighted sampler | 46.6 |

The evaluation command is:

```bash
python zipformer/evaluate.py \
--epoch 32 \
--avg 8 \
--exp-dir zipformer/exp_at_as_full \
--max-duration 500
export CUDA_VISIBLE_DEVICES="4,5,6,7"
subset=full
weighted_sampler=1
bucket_sampler=0
lr_epochs=15

python zipformer/train.py \
--world-size 4 \
--audioset-subset $subset \
--num-epochs 120 \
--start-epoch 1 \
--use-fp16 1 \
--num-events 527 \
--lr-epochs $lr_epochs \
--exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \
--weighted-sampler $weighted_sampler \
--bucketing-sampler $bucket_sampler \
--max-duration 1000 \
--enable-musan True \
--master-port 13452
```

The command for evaluation is the same. The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler


#### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M

Expand Down Expand Up @@ -92,4 +116,4 @@ python zipformer/evaluate.py \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--exp-dir zipformer/exp_small_at_as_full \
--max-duration 500
```
```
73 changes: 73 additions & 0 deletions egs/audioset/AT/local/compute_weight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file generates the manifest and computes the fbank features for AudioSet
dataset. The generated manifests and features are stored in data/fbank.
"""

import argparse

import lhotse
from lhotse import load_manifest


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz"
)

parser.add_argument(
"--output",
type=str,
required=True,
)
return parser


def main():
# Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py
parser = get_parser()
args = parser.parse_args()

cuts = load_manifest(args.input_manifest)

print(f"A total of {len(cuts)} cuts.")

label_count = [0] * 527 # a total of 527 classes
for c in cuts:
audio_event = c.supervisions[0].audio_event
labels = list(map(int, audio_event.split(";")))
for label in labels:
label_count[label] += 1

with open(args.output, "w") as f:
for c in cuts:
audio_event = c.supervisions[0].audio_event
labels = list(map(int, audio_event.split(";")))
weight = 0
for label in labels:
weight += 1000 / (label_count[label] + 0.01)
f.write(f"{c.id} {weight}\n")


if __name__ == "__main__":
main()
13 changes: 12 additions & 1 deletion egs/audioset/AT/prepare.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ stage=-1
stop_stage=4

dl_dir=$PWD/download
fbank_dir=data/fbank

# we assume that you have your downloaded the AudioSet and placed
# it under $dl_dir/audioset, the folder structure should look like
Expand Down Expand Up @@ -49,7 +50,6 @@ fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set"
fbank_dir=data/fbank
if [! -e $fbank_dir/.balanced.done]; then
python local/generate_audioset_manifest.py \
--dataset-dir $dl_dir/audioset \
Expand Down Expand Up @@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
touch data/fbank/.musan.done
fi
fi

# The following stages are required to do weighted-sampling training
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare for weighted-sampling training"
if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then
lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz
fi
python ./local/compute_weight.py \
--input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \
--output $fbank_dir/sampling_weights_full.txt
fi
107 changes: 82 additions & 25 deletions egs/audioset/AT/zipformer/at_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
WeightedSimpleCutSampler,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
Expand Down Expand Up @@ -99,6 +100,20 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--weighted-sampler",
type=str2bool,
default=False,
help="When enabled, samples are drawn from by their weights. "
"It cannot be used together with bucketing sampler",
)
group.add_argument(
"--num-samples",
type=int,
default=200000,
help="The number of samples to be drawn in each epoch. Only be used"
"for weighed sampler",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
Expand Down Expand Up @@ -295,6 +310,9 @@ def train_dataloaders(
)

if self.args.bucketing_sampler:
assert (
not self.args.weighted_sampler
), "weighted sampling is not supported in bucket sampler"
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
Expand All @@ -304,13 +322,26 @@ def train_dataloaders(
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
drop_last=self.args.drop_last,
)
if self.args.weighted_sampler:
# assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset"
logging.info("Using weighted SimpleCutSampler")
weights = self.audioset_sampling_weights()
train_sampler = WeightedSimpleCutSampler(
cuts_train,
weights,
num_samples=self.args.num_samples,
max_duration=self.args.max_duration,
shuffle=False, # do not support shuffle
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
drop_last=self.args.drop_last,
)
logging.info("About to create train dataloader")

if sampler_state_dict is not None:
Expand Down Expand Up @@ -373,11 +404,9 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = AudioTaggingDataset(
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
Expand All @@ -397,21 +426,30 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
@lru_cache()
def audioset_train_cuts(self) -> CutSet:
logging.info("About to get the audioset training cuts.")
balanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
)
if self.args.audioset_subset == "full":
unbalanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
)
cuts = CutSet.mux(
balanced_cuts,
unbalanced_cuts,
weights=[20000, 2000000],
stop_early=True,
if not self.args.weighted_sampler:
balanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz"
)
if self.args.audioset_subset == "full":
unbalanced_cuts = load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz"
)
cuts = CutSet.mux(
balanced_cuts,
unbalanced_cuts,
weights=[20000, 2000000],
stop_early=True,
)
else:
cuts = balanced_cuts
else:
cuts = balanced_cuts
# assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet"
cuts = load_manifest(
self.args.manifest_dir
/ f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz"
)
logging.info(f"Get {len(cuts)} cuts in total.")

return cuts

@lru_cache()
Expand All @@ -420,3 +458,22 @@ def audioset_eval_cuts(self) -> CutSet:
return load_manifest_lazy(
self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz"
)

@lru_cache()
def audioset_sampling_weights(self):
logging.info(
f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet"
)
weights = []
with open(
self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt",
"r",
) as f:
while True:
line = f.readline()
if not line:
break
weight = float(line.split()[1])
weights.append(weight)
logging.info(f"Get the sampling weight for {len(weights)} cuts")
return weights
11 changes: 10 additions & 1 deletion egs/audioset/AT/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,12 +789,14 @@ def save_bad_model(suffix: str = ""):
rank=0,
)

num_samples = 0
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))

params.batch_idx_train += 1
batch_size = batch["inputs"].size(0)
num_samples += batch_size

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
Expand Down Expand Up @@ -919,6 +921,12 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/valid_", params.batch_idx_train
)

if num_samples > params.num_samples:
logging.info(
f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch"
)
break

loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
Expand Down Expand Up @@ -1032,7 +1040,8 @@ def remove_short_and_long_utt(c: Cut):

return True

train_cuts = train_cuts.filter(remove_short_and_long_utt)
if not params.weighted_sampler:
train_cuts = train_cuts.filter(remove_short_and_long_utt)

if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
Expand Down

0 comments on commit 3fc06cc

Please sign in to comment.