[MXNET-807] Support integer label type in ctc_loss operator #12468
Changes from 14 commits: f0a757b, 7af7274, 5e99e7e, eb30964, 774c61b, d9dc6e6, 59f48f2, 1b3d141, 299b1e7, ec5cc3c, 59e5d7c, c8b7cd4, 973daca, 217069e, 4574c7c, fa61a0a, 3fbb3f5
```diff
@@ -4516,6 +4516,27 @@ def test_ctc_loss():
     true_loss = np.array([7.3557, 5.4091], dtype=np.float32)  # from Torch
     check_ctc_loss(acts2, labels2, true_loss)
 
+    # Test 3: check using an integer type as the label
+    labels3 = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32)
+    true_loss = np.array([7.3557, 5.4091], dtype=np.float32)  # from Torch
+    check_ctc_loss(acts2, labels3, true_loss)
+
+@with_seed(1)
+def test_ctc_loss_with_large_classes():
+    ctx = default_context()
+    batch_size = 1024
+    seq_len = 35
+    label_len = 10
+    num_classes = 6000
+    x = np.random.uniform(size=(seq_len, batch_size, num_classes))
+    y = np.random.randint(0, num_classes, size=(batch_size, label_len))
+
+    data = mx.nd.array(x, ctx=ctx)
+    label = mx.nd.array(y, ctx=ctx)
+    loss = mx.nd.contrib.ctc_loss(data=data, label=label)
+    loss = mx.nd.make_loss(loss)
+    expected_output_sum = 282733.95318603516
+    assert np.isclose(sum(loss.asnumpy()), expected_output_sum)
+
 @with_seed()
 def test_ctc_loss_grad():
```

Inline review thread on the new test:

Reviewer: Again, this does not seem like a good way of testing this.

Author: Any suggestion for testing the large classes? I could compare this with the WarpCTC implementation's result if that can be treated as golden.

Reviewer: Make a small example, calculate the value, and test for that, like in any other CTC test. Since this is for testing the type, the batch size and sequence lengths are irrelevant.

Author: The label type is tested in line 4520. This test case exercises the large number of classes that caused the crash reported in issue #10995.
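A minimal, type-focused check along the lines the reviewer suggests could look like the sketch below. This is not code from this PR: it assumes the usual `mxnet`/`numpy` imports and that this PR's int32-label support is in place, and instead of a hand-computed constant it checks that the int32 label path agrees with the previously supported float32 path.

```python
import mxnet as mx
import numpy as np

# Sketch: with integer label support, the loss computed from int32 labels
# should match the loss computed from the same labels cast to float32.
data = mx.nd.array(np.random.uniform(size=(10, 2, 5)))  # (seq_len, batch, alphabet)
labels = np.array([[2, 3, 1], [2, 0, 0]], dtype=np.int32)

loss_int = mx.nd.contrib.ctc_loss(data=data, label=mx.nd.array(labels, dtype='int32'))
loss_f32 = mx.nd.contrib.ctc_loss(data=data, label=mx.nd.array(labels.astype(np.float32)))
assert np.allclose(loss_int.asnumpy(), loss_f32.asnumpy())
```

Comparing the two dtype paths exercises the new code without baking a fragile floating-point constant into the test.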
Review conversation:

Reviewer: How does this help to verify the correctness?

Author: I simply used the example reported in the original issue to make sure this fix addresses it.

Reviewer: The issue that needs testing is the type of the labels, so a large batch size doesn't seem helpful or necessary for verifying the correctness.

Reviewer: Also, tests with a fixed seed are treated as a test quality issue and are being eliminated right now.

Author: The label type is tested in line 4520. This test case exercises the large number of classes that caused the crash reported in issue #10995.

Reviewer: OK, then make a test for it. Batch size is still not relevant, is it?

Author: It is not the batch_size in training; it is the size of the vocabulary. We need this variable to create the 3D tensor.

Author: Updated the variable name and removed the fixed seed.

Reviewer: No, it really is not the vocabulary size, regardless of how you name it. Please check the API doc and see its usage.

Author: Sorry for misunderstanding the API. I have updated the unit tests based on your suggestion. Please review it again. Thanks!
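For context on the shape confusion above: per the `contrib.ctc_loss` API doc, `data` has shape `(sequence_length, batch_size, alphabet_size)` and `label` has shape `(batch_size, label_sequence_length)`, so the 1024 in the test really is the batch size; the vocabulary (alphabet) size is the last axis of `data`. A small sketch of the convention, with illustrative values not taken from the PR:

```python
import mxnet as mx
import numpy as np

# mx.nd.contrib.ctc_loss shape convention:
#   data:  (sequence_length, batch_size, alphabet_size)
#   label: (batch_size, label_sequence_length)
seq_len, batch_size, alphabet_size, label_len = 20, 4, 6000, 10  # illustrative values
data = mx.nd.array(np.random.uniform(size=(seq_len, batch_size, alphabet_size)))
# With the default blank_label='first', 0 is reserved for the blank symbol
# and padding, so valid labels lie in [1, alphabet_size).
label = mx.nd.array(np.random.randint(1, alphabet_size, size=(batch_size, label_len)))
loss = mx.nd.contrib.ctc_loss(data=data, label=label)
print(loss.shape)  # (4,): one loss value per batch element
```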