
BucketingSampler more randomness? #364

Closed
danpovey opened this issue Aug 10, 2021 · 25 comments

@danpovey
Collaborator

@pzelasko we are still seeing sawtooth patterns in losses when we use BucketingSampler, even with fewer buckets.
I think it's because the individual buckets are sorted by length. Is it possible to shuffle somehow within the buckets, or, say, to randomly pick each batch from either the front or the back of the bucket? Or perhaps the individual buckets could be either reversed or not reversed, randomly or alternately.
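Roughly what I have in mind, as a sketch only (plain Python, not lhotse's actual API; the bucket and epoch arguments are made up for illustration):

    import random

    def reshuffle_buckets(buckets, epoch, seed=42):
        """Shuffle the cut order inside each duration-sorted bucket so that
        consecutive batches are no longer monotonically ordered by length.
        `buckets` is assumed to be a list of lists of cuts."""
        rng = random.Random(seed + epoch)
        for bucket in buckets:
            rng.shuffle(bucket)  # in-place shuffle within this bucket
        return buckets

Reversing every other bucket (or reversing each bucket with probability 0.5) would be an even cheaper way to break the global short-to-long ordering.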

@csukuangfj
Contributor

csukuangfj commented Aug 10, 2021

https://tensorboard.dev/experiment/h0VXvnqfR8aabI3HJxHxYA/
The above TensorBoard log is from a run using the bucketing sampler with the default number of buckets (30).
You can see that the attention loss jumps at the beginning of each epoch.

@pzelasko
Collaborator

OK, just a sanity check: did you pass shuffle=True to the constructor and call sampler.set_epoch(epoch) at each epoch?

@csukuangfj
Contributor

csukuangfj commented Aug 10, 2021

did you pass shuffle=True to the constructor and call sampler.set_epoch(epoch) at each epoch?

Yes. The related code is shown below:

https://github.com/k2-fsa/icefall/blob/master/icefall/dataset/asr_datamodule.py#L172

        if self.args.bucketing_sampler:
            logging.info("Using BucketingSampler.")
            train_sampler = BucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=True,
                num_buckets=self.args.num_buckets,
            )

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/train.py#L639

    for epoch in range(params.start_epoch, params.num_epochs):
        train_dl.sampler.set_epoch(epoch)

The training options from the log file are:

2021-08-08 16:39:50,977 INFO [train.py:598] (0/3) 
{'exp_dir': PosixPath('conformer_ctc/exp_new'), 
'lang_dir': PosixPath('data/lang_bpe'), 
'feature_dim': 80,
'weight_decay': 1e-06, 'subsampling_factor': 4, 'start_epoch': 1, 
'num_epochs': 50, 'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 
'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 10, 
'reset_interval': 200, 'valid_interval': 3000, 'beam_size': 10, 'reduction': 'sum', 
'use_double_scores': True, 'accum_grad': 1, 'att_rate': 0.7, 'attention_dim': 512, 'nhead': 8, 
'num_decoder_layers': 6, 'is_espnet_structure': True, 'mmi_loss': False, 
'use_feat_batchnorm': True, 'lr_factor': 5.0, 'warm_step': 80000, 'world_size': 3, 
'master_port': 12354, 'tensorboard': True, 
'feature_dir': PosixPath('data/fbank'), 'max_duration': 200, 
'bucketing_sampler': True, 'num_buckets': 30, 
'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 
'on_the_fly_feats': False, 'full_libri': True}

@pzelasko
Collaborator

In that case, I will double-check if the shuffling works as intended in each bucket of the BucketingSampler.

BTW I didn’t notice the drop_last=True option for BucketingSampler in what you posted — was it used?

@danpovey
Collaborator Author

There is another factor that could cause this sawtooth pattern: the buckets contain the same number of samples, but you consume them in batches whose size depends on duration, so the buckets with shorter cuts will be consumed first. It might be better to compute the cumulative sum of the durations and split at percentiles of that, if that is not already what split() does there.
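In code, the idea would be roughly the following (just a sketch; not a claim about what lhotse's split() currently does):

    def split_by_equal_duration(cuts, num_buckets):
        """Split duration-sorted cuts into buckets that hold roughly equal
        total duration, rather than equal numbers of cuts (sketch only)."""
        cuts = sorted(cuts, key=lambda c: c.duration)
        target = sum(c.duration for c in cuts) / num_buckets
        buckets, current, acc = [], [], 0.0
        for cut in cuts:
            current.append(cut)
            acc += cut.duration
            if acc >= target and len(buckets) < num_buckets - 1:
                buckets.append(current)
                current, acc = [], 0.0
        buckets.append(current)  # whatever remains goes into the last bucket
        return buckets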

@pzelasko
Collaborator

It doesn’t do that. I can make a PR with that option later and we’ll see then.

@csukuangfj
Contributor

BTW I didn’t notice the drop_last=True option for BucketingSampler in what you posted — was it used?

I think it is unused and left at its default value, which is False.
By the way, drop_last is not exposed by lhotse, I think.

@pzelasko
Collaborator

pzelasko commented Aug 10, 2021 via email

@pzelasko
Collaborator

Please check out both options: drop_last=True (#357) and bucket_method='equal_duration' (#365). Let me know if either of them helps.
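Adapting the snippet you posted earlier, the call would look roughly like this (assuming a lhotse version that exposes both options):

    train_sampler = BucketingSampler(
        cuts_train,
        max_duration=self.args.max_duration,
        shuffle=True,
        num_buckets=self.args.num_buckets,
        drop_last=True,                   # from #357
        bucket_method="equal_duration",   # from #365
    )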

@danpovey
Collaborator Author

BTW, if each bucket keeps track of the amount of data it has left (however we define that, e.g. duration or samples), a relatively easy way to implement approximate proportional sampling would be: whenever the next batch is requested, choose 2 nonempty buckets a and b, and pick from bucket a with probability a.dur() / (a.dur() + b.dur()), else bucket b. That would ensure that all the buckets become empty at exactly the same time. Otherwise, some of the buckets will start to deplete approximately [sqrt(num_batches_in_a_bucket) * num_buckets] batches before the end of the epoch, even if they were balanced at the start.
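A sketch of the selection step (plain Python; a bucket object with a dur() method returning its remaining total duration is hypothetical here):

    import random

    def pick_bucket(buckets, rng: random.Random):
        """Choose the bucket to draw the next batch from, weighting the choice
        by how much duration each bucket has left (sketch only)."""
        nonempty = [b for b in buckets if b.dur() > 0]
        if len(nonempty) == 1:
            return nonempty[0]
        a, b = rng.sample(nonempty, 2)  # two candidate buckets, chosen uniformly
        p_a = a.dur() / (a.dur() + b.dur())
        return a if rng.random() < p_a else b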

@pzelasko
Collaborator

I'm not sure I understand why some buckets would deplete sooner. With the new method, each bucket has the same duration of speech -- since the batches are gathered to satisfy a max total duration, it implies that each bucket should yield the same number of batches. And since the buckets are chosen with uniform probabilities, they should all deplete at the same time.

Still, the method you're proposing could be useful as a source selection strategy for a "mux" sampler (that I'm planning to add soon) so that it avoids depleting its source samplers prematurely.

@pkufool
Contributor

pkufool commented Aug 12, 2021

My training crashed with a batch-size mismatch; there may be a bug related to bucket_method='equal_duration'.

The log showed that the returned x and mask have different batch sizes (52 vs 46). It may be a problem with subsampling and masking; I am not sure and will figure it out.

If I do not set bucket_method='equal_duration', the crash does not happen.

https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/conformer_ctc/transformer.py#L210-L217

        x = self.encoder_embed(x)
        x = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
        mask = encoder_padding_mask(x.size(0), supervisions)
        mask = mask.to(x.device) if mask is not None else None
        x = self.encoder(x, src_key_padding_mask=mask)  # (T, N, C)
        return x, mask 

@pzelasko
Collaborator

Okay, thanks for reporting. I’m not sure what might have gone wrong at the moment; I’ll try to reproduce this issue. Is this LibriSpeech 100h or 960h?

@danpovey
Collaborator Author

danpovey commented Aug 12, 2021 via email

@pzelasko
Collaborator

Piotr: my formula was just taking into account the statistics of how random counts with equal probability are not quite the same in practice.

I had a feeling that I was missing something... that's a cool observation. Let me see if I got it right: since at every step the bucket selection probabilities are uniform, there is a relatively "high" probability that at least one of the buckets will deplete "early" (for some definition of "high" and "early"). And in expectation, "early" is about sqrt(num_bucket_batches) * num_buckets batches before the end of the epoch.
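A quick way to see the effect in simulation (sketch only; the bucket and batch counts below are illustrative, not taken from the actual runs):

    import random

    def mean_depletion_gap(num_buckets=30, batches_per_bucket=200, trials=20, seed=0):
        """Simulate uniform selection among non-empty buckets and return the
        average number of draws between the first bucket running dry and the
        end of the epoch."""
        rng = random.Random(seed)
        gaps = []
        for _ in range(trials):
            remaining = [batches_per_bucket] * num_buckets
            total = num_buckets * batches_per_bucket
            first_empty_at = None
            for step in range(total):
                idx = rng.choice([i for i, r in enumerate(remaining) if r > 0])
                remaining[idx] -= 1
                if first_empty_at is None and remaining[idx] == 0:
                    first_empty_at = step + 1
            gaps.append(total - first_empty_at)
        return sum(gaps) / len(gaps)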

@pkufool
Contributor

pkufool commented Aug 12, 2021

Okay, thanks for reporting. I’m not sure what might have gone wrong at the moment; I’ll try to reproduce this issue. Is this LibriSpeech 100h or 960h?

LibriSpeech 960h; you can reproduce this issue by running the training script in the icefall conformer_ctc folder.

@danpovey
Collaborator Author

Yes. The factor of num_buckets is simply because I want to measure the time in minibatches before the end of the epoch for the training process, not in minibatches per bucket.

@pkufool
Contributor

pkufool commented Aug 12, 2021

@pzelasko fangjun told me the batch-size mismatch is a bug in masking (posted here: k2-fsa/snowfall#240). The bug is not triggered when --concatenate-cuts False is set.

We will try to fix the bug in masking.

@pzelasko
Collaborator

@pzelasko fangjun told me the batch-size mismatch is a bug in masking (posted here: k2-fsa/snowfall#240). The bug is not triggered when --concatenate-cuts False is set.

We will try to fix the bug in masking.

Thanks for letting me know!

@pkufool
Contributor

pkufool commented Aug 13, 2021

Please check out both options: drop_last=True (#357) and bucket_method='equal_duration' (#365). Let me know if either of them helps.

After adding these two options, the sawtooth pattern in the losses disappeared. The picture below shows the loss values for the first 3 epochs. The training is ongoing; we will see whether this helps reduce the WER and will post results later.

[Image: training loss curves for the first three epochs]

@csukuangfj
Contributor

csukuangfj commented Aug 15, 2021

I am using bucket_method='equal_duration', drop_last=True, and max_duration=400.
Sometimes the total duration of all cuts in a batch is more than 500 seconds; does that indicate a problem?

The information for one such batch is printed below:

 id: 9bcbb594-3446-4236-b5db-0ec30dd4c2e1, duration: 6.42775 seconds
 id: 17e1cd8f-c7ad-4401-a3cc-3f842de01914, duration: 6.42775 seconds
 id: 1624-168623-0033-13945-0_sp0.9, duration: 6.3 seconds
 id: 72ba08ef-55df-486f-8152-13971345fc87, duration: 6.42775 seconds
 id: 2989-138035-0073-3067-0, duration: 6.27 seconds
 id: c992bc14-56f6-4385-8943-7661fe6d04a2, duration: 6.42775 seconds
 id: 4406-16883-0001-6169-0_sp1.1, duration: 6.2 seconds
 id: 8419-286676-0029-8688-0_sp0.9, duration: 6.188875 seconds
 id: 250962e6-4263-49ee-88d3-698119303671, duration: 6.42775 seconds
 id: 5703-47212-0010-24110-0_sp0.9, duration: 6.0444375 seconds
 id: 2989-138035-0022-3016-0_sp1.1, duration: 5.92275 seconds
 id: 1eb02a55-bc38-4874-a1e6-e162b90c7dc7, duration: 6.42775 seconds
 id: 4929e838-e8a4-4857-a869-225a86dec79c, duration: 6.42775 seconds
 id: 97f2f7f4-61d6-43ed-91c4-dccc862fee86, duration: 6.42775 seconds
 id: 8324-286683-0021-1312-0_sp0.9, duration: 5.72225 seconds
 id: f62d5a18-11bc-46a9-8bc3-3abfa56be98e, duration: 6.42775 seconds
 id: 7c50b511-cc7a-4a48-bece-5cdbb11bc15a, duration: 6.42775 seconds
 id: 7402-90848-0045-22228-0_sp0.9, duration: 5.538875 seconds
 id: cc32a855-bcdb-415f-94a4-15bf9a5fe248, duration: 6.42775 seconds
 id: 730-358-0039-12108-0_sp1.1, duration: 5.5090625 seconds
 id: 8088-284756-0092-5440-0_sp0.9, duration: 5.488875 seconds
 id: d03fdbd8-c252-4e17-a0ce-f6f27bfaf37d, duration: 6.42775 seconds
 id: 83-11691-0009-9950-0_sp1.1, duration: 5.3318125 seconds
 id: 53639567-5891-4abb-b11d-d083ae0de5da, duration: 6.42775 seconds
 id: 88329e35-a9de-47e1-a25c-bb87a8c569c6, duration: 6.42775 seconds
 id: bac37588-b33a-4050-955e-9fae852b3f92, duration: 6.42775 seconds
 id: 8063-274116-0022-8870-0_sp0.9, duration: 5.2166875 seconds
 id: f9115a1c-c45c-45f2-aeb4-abcfd1dd0880, duration: 6.42775 seconds
 id: 12552983-4ccb-4c83-8bcc-f0488d68219a, duration: 6.42775 seconds
 id: 44018d69-e31d-4d2c-8fa1-0bed26328691, duration: 6.42775 seconds
 id: bc327bac-30e2-460c-8f78-7e52fb75145b, duration: 6.42775 seconds
 id: 7f3f499a-9010-417e-a34c-647ab00964a4, duration: 6.42775 seconds
 id: 4b56c4fa-58dd-4c93-9fa8-e3397d8b9ff3, duration: 6.42775 seconds
 id: b2c1ae16-bde7-4abf-9824-eac34612ced1, duration: 6.42775 seconds
 id: a43012bd-562e-4f15-8400-c8ffde18ead8, duration: 6.42775 seconds
 id: 4c2b145c-2b1b-486c-9737-a68a3840e9ca, duration: 6.42775 seconds
 id: 1926-147987-0006-19961-0_sp0.9, duration: 4.638875 seconds
 id: 730-358-0056-12125-0_sp0.9, duration: 4.62225 seconds
 id: 3436-172171-0057-22518-0_sp1.1, duration: 4.6 seconds
 id: edcb0110-1a16-4d54-95e8-efdafd66d91f, duration: 6.42775 seconds
 id: 6925-80680-0007-2360-0_sp1.1, duration: 4.513625 seconds
 id: 5561-39621-0035-7247-0_sp1.1, duration: 4.5 seconds
 id: 8238-283452-0030-30-0, duration: 4.49 seconds
 id: 3214-167602-0024-24726-0, duration: 4.46 seconds
 id: 1034-121119-0038-9552-0, duration: 4.415 seconds
 id: 57f4c3f5-3118-43dd-bfa9-c6b4874d6f1d, duration: 6.42775 seconds
 id: 39881cc6-7304-4fd4-8763-20fdda58db79, duration: 6.42775 seconds
 id: 2989-138035-0047-3041-0, duration: 4.32 seconds
 id: 2c4e44e7-426c-4006-a9be-052783726d42, duration: 6.42775 seconds
 id: 5688-15787-0015-11314-0, duration: 4.27 seconds
 id: 7bd8ec9b-14d7-45b9-aefc-a6a7bfe8b84d, duration: 6.42775 seconds
 id: 8108-280354-0031-6499-0, duration: 4.17 seconds
 id: aad9c266-59b0-469c-acc4-076cd39288be, duration: 6.42775 seconds
 id: 26d24872-0156-4593-986b-c0fee5cab9fa, duration: 6.42775 seconds
 id: d30de34e-73f6-48b4-8c85-ff1411ffe171, duration: 6.42775 seconds
 id: 2391-145015-0072-13814-0_sp1.1, duration: 3.9909375 seconds
 id: ec33251d-5a63-447f-864f-f3bf18848b5f, duration: 6.42775 seconds
 id: 83-11691-0013-9954-0, duration: 3.905 seconds
 id: 69769ce7-00a9-47e5-93c9-37f1cf271d77, duration: 6.42775 seconds
 id: 669-129074-0018-4182-0_sp0.9, duration: 3.72225 seconds
 id: 7253e979-4ca1-479f-849b-cd6ae51aad55, duration: 6.42775 seconds
 id: 27-124992-0065-19617-0, duration: 3.52 seconds
 id: a0917d43-fada-4ca1-8a29-b88105780614, duration: 6.42775 seconds
 id: 1a162898-12fc-4fd9-af45-40c3cf7dc3a0, duration: 6.42775 seconds
 id: 307-127539-0005-5149-0_sp1.1, duration: 3.37275 seconds
 id: 2989-138028-0051-2969-0_sp0.9, duration: 3.338875 seconds
 id: 89cd0b24-a125-4ee3-b68a-2cc856beb816, duration: 6.42775 seconds
 id: 314ecf14-be73-49fd-a3ed-80948c1faa19, duration: 6.42775 seconds
 id: 10674199-afb7-480e-9e0d-d0f5be9c1228, duration: 6.42775 seconds
 id: 6f08a4dc-917a-41cb-8230-9df6abd6051a, duration: 6.42775 seconds
 id: 61e2fe71-44ef-4c69-9b06-e41967289b84, duration: 6.42775 seconds
 id: 322-124147-0020-14699-0_sp1.1, duration: 3.0909375 seconds
 id: c5bb01bc-2d31-439b-8ce5-3ffad03d38e1, duration: 6.42775 seconds
 id: 587-54108-0046-5636-0_sp0.9, duration: 2.97225 seconds
 id: 3259-158083-0035-18053-0, duration: 2.97 seconds
 id: 1455-134435-0043-1109-0, duration: 2.795 seconds
 id: 7148-59157-0036-4585-0_sp0.9, duration: 2.77775 seconds
 id: 118-47824-0073-11418-0_sp0.9, duration: 2.75 seconds
 id: 37f784aa-0578-4d87-9363-56eddb4d73a5, duration: 6.42775 seconds
 id: 118-121721-0019-11473-0_sp1.1, duration: 2.67725 seconds
 id: 7a678233-e305-4f68-add1-ea1c389c1073, duration: 6.42775 seconds
 id: 1fb1d7d5-320a-4b6a-a32a-7b2e4d440e2d, duration: 6.42775 seconds
 id: 1355-39947-0041-14250-0_sp1.1, duration: 2.5181875 seconds
 id: 69902f44-4032-4d83-b2b1-7aefb81316ee, duration: 6.42775 seconds
 id: 0ef622be-854c-488b-af16-6f79c8eb0b57, duration: 6.42775 seconds
 id: 839-130898-0097-1892-0_sp0.9, duration: 2.488875 seconds
 id: 79350116-59df-4f3d-836b-f20614673248, duration: 6.42775 seconds
 id: 8629-261140-0006-27476-0, duration: 2.375 seconds
 id: 67ecb3b0-c6d3-498a-9ebc-6847c8cb5a5c, duration: 6.42775 seconds
 id: 408fb1c7-4bcf-473f-bf9f-4fe266c988f4, duration: 6.42775 seconds
 id: 1513a7e9-a422-4557-bd9c-7d50a62dfdb2, duration: 6.42775 seconds
 id: 9a84d063-ba93-4a63-80c7-0ef45d92abf9, duration: 6.42775 seconds
 id: 19-198-0000-6573-0, duration: 1.965 seconds
 total duration: 510.634 s
 max duration: 6.428 s


[EDITED]
Part of the training log is shown below. You can see that the total duration in a batch is about 400 seconds most of the time.

2021-08-15 10:13:33,228 INFO [train-960.py:559] Epoch 0, batch 250, batch avg ctc loss 2.1768, batch avg att loss 0.9815, batch avg loss 1.3401, total
 avg ctc loss: 2.3224, total avg att loss: 0.9080, total avg loss: 1.3323, batch size: 28, total duration: 395.252 s, max cut duration: 14.150 s
2021-08-15 10:13:47,234 INFO [train-960.py:559] Epoch 0, batch 260, batch avg ctc loss 1.6297, batch avg att loss 0.9172, batch avg loss 1.1309, total
 avg ctc loss: 2.2863, total avg att loss: 0.9067, total avg loss: 1.3206, batch size: 29, total duration: 398.832 s, max cut duration: 13.791 s
2021-08-15 10:14:01,180 INFO [train-960.py:559] Epoch 0, batch 270, batch avg ctc loss 2.2542, batch avg att loss 0.8885, batch avg loss 1.2982, total
 avg ctc loss: 2.2934, total avg att loss: 0.9076, total avg loss: 1.3233, batch size: 28, total duration: 389.664 s, max cut duration: 13.970 s
2021-08-15 10:14:13,836 INFO [train-960.py:559] Epoch 0, batch 280, batch avg ctc loss 1.9333, batch avg att loss 0.8957, batch avg loss 1.2070, total
 avg ctc loss: 2.2604, total avg att loss: 0.9041, total avg loss: 1.3110, batch size: 26, total duration: 390.496 s, max cut duration: 15.050 s
2021-08-15 10:14:26,862 INFO [train-960.py:559] Epoch 0, batch 290, batch avg ctc loss 1.7030, batch avg att loss 0.7987, batch avg loss 1.0700, total
 avg ctc loss: 2.2311, total avg att loss: 0.9005, total avg loss: 1.2997, batch size: 23, total duration: 396.639 s, max cut duration: 17.350 s
2021-08-15 10:14:39,908 INFO [train-960.py:559] Epoch 0, batch 300, batch avg ctc loss 2.7862, batch avg att loss 0.8740, batch avg loss 1.4477, total
 avg ctc loss: 2.2282, total avg att loss: 0.9000, total avg loss: 1.2984, batch size: 33, total duration: 400.319 s, max cut duration: 12.265 s
2021-08-15 10:14:54,443 INFO [train-960.py:559] Epoch 0, batch 310, batch avg ctc loss 2.0327, batch avg att loss 0.9512, batch avg loss 1.2756, total
 avg ctc loss: 2.2188, total avg att loss: 0.8971, total avg loss: 1.2936, batch size: 30, total duration: 400.378 s, max cut duration: 13.395 s
2021-08-15 10:15:08,943 INFO [train-960.py:559] Epoch 0, batch 320, batch avg ctc loss 4.0124, batch avg att loss 0.8752, batch avg loss 1.8164, total
 avg ctc loss: 2.1969, total avg att loss: 0.8957, total avg loss: 1.2861, batch size: 93, total duration: 507.692 s, max cut duration: 6.428 s
2021-08-15 10:15:22,945 INFO [train-960.py:559] Epoch 0, batch 330, batch avg ctc loss 1.8711, batch avg att loss 0.8965, batch avg loss 1.1889, total
 avg ctc loss: 2.1873, total avg att loss: 0.8940, total avg loss: 1.2820, batch size: 31, total duration: 398.022 s, max cut duration: 12.900 s
2021-08-15 10:15:36,046 INFO [train-960.py:559] Epoch 0, batch 340, batch avg ctc loss 1.7361, batch avg att loss 0.8450, batch avg loss 1.1123, total
 avg ctc loss: 2.1540, total avg att loss: 0.8911, total avg loss: 1.2700, batch size: 26, total duration: 400.698 s, max cut duration: 15.456 s
2021-08-15 10:15:48,664 INFO [train-960.py:559] Epoch 0, batch 350, batch avg ctc loss 1.6065, batch avg att loss 0.8385, batch avg loss 1.0689, total
 avg ctc loss: 2.1344, total avg att loss: 0.8880, total avg loss: 1.2619, batch size: 26, total duration: 400.479 s, max cut duration: 15.465 s

@pzelasko
Collaborator

It is likely due to padding: the sampler counts the duration of non-padded cuts (i.e. speech-only duration for LibriSpeech). Then, inside the dataset, padding happens and adds extra duration (this is also true of transforms such as cut concatenation).
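For example (a sketch with made-up numbers, just to show where the extra seconds come from):

    # Durations the sampler sees (speech only) vs. what the collated batch covers.
    durations = [6.3, 4.5, 2.0]                       # counted against max_duration
    sampler_total = sum(durations)                    # ~12.8 s budgeted by the sampler
    collated_total = len(durations) * max(durations)  # ~18.9 s once every cut is padded to the longest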

@csukuangfj
Contributor

It is likely due to padding: the sampler counts the duration of non-padded cuts (i.e. speech-only duration for LibriSpeech). Then, inside the dataset, padding happens and adds extra duration (this is also true of transforms such as cut concatenation).

I see. Thanks!

BTW: I use --concatenate-cuts=0.

@pzelasko
Collaborator

Can I close this issue?

@csukuangfj
Contributor
