Commit: Use torch.testing.assert_close for better signals (#1742)
Showing 9 changed files with 342 additions and 342 deletions.
334 changes: 167 additions & 167 deletions
apex/contrib/test/transducer/test_transducer_joint.py
@@ -1,167 +1,167 @@
import unittest

import torch

SKIP_TEST = None
try:
    from apex.contrib.transducer import TransducerJoint
    from apex.contrib.transducer import _transducer_ref as transducer_ref
except ImportError as e:
    SKIP_TEST = e


@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}")
class TransducerJointTest(unittest.TestCase):
    def setUp(self, seed=1234):
        torch.manual_seed(seed)

    def gen_input(self, for_vector_kernel):
        self.B = 4
        T_min = 51
        T_max = 101
        U_min = 12
        U_max = 25
        if for_vector_kernel:
            H = 512
        else:
            H = 509
        dtype = torch.float16
        device = "cuda"

        self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
        self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
        self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
        self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
        self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
        self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
        self.dropout_prob = 0.5

        # Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
        # the loss function
        for b in range(self.B):
            self.h_grad[b, self.f_len[b]:, :, :] = 0
            self.h_grad[b, :, self.g_len[b]:, :] = 0
        self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)


    def _pack(self, x, f_len, g_len):
        B = x.size(0)
        list_x = []
        for b in range(B):
            list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
            x_row = torch.cat(list_x_row)
            list_x.append(x_row)
        x_packed = torch.cat(list_x).data.clone()
        x_packed.requires_grad = True
        batch_offset = torch.cumsum(f_len * g_len, dim=0)
        return x_packed

    def _unpack(self, x, f_len, g_len):
        batch_offset = torch.cumsum(f_len * g_len, dim=0)
        x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
        B = self.h_grad.size(0)
        H = self.h_grad.size(-1)
        for b in range(B):
            my_batch_offset = 0 if b == 0 else batch_offset[b-1]
            my_f_len = f_len[b]
            my_g_len = g_len[b]
            for t in range(my_f_len):
                x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
                                                my_batch_offset + t*my_g_len + my_g_len]
        return x_unpacked

    def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
        self.gen_input(for_vector_kernel=for_vector_kernel)
        # Generate reference
        f_ref = self.f_tst.data.clone()
        g_ref = self.g_tst.data.clone()
        f_ref.requires_grad = True
        g_ref.requires_grad = True

        my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
                                   dropout_prob=self.dropout_prob, probe_mask=True)
        if not pack_output:
            h_tst = my_joint(f=self.f_tst,
                             g=self.g_tst,
                             f_len=self.f_len,
                             g_len=self.g_len)
            h_tst.backward(self.h_grad)
            if dropout:
                mask = my_joint.mask_probe[0]
        else:
            batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
            h_tst = my_joint(f=self.f_tst,
                             g=self.g_tst,
                             f_len=self.f_len,
                             g_len=self.g_len,
                             batch_offset=batch_offset,
                             packed_batch=batch_offset[-1])
            h_tst.backward(self.h_grad_packed)
            if dropout:
                mask_packed = my_joint.mask_probe[0]
                mask = self._unpack(mask_packed, self.f_len, self.g_len)

        # reference
        h_ref, f_grad_ref, g_grad_ref \
            = transducer_ref.transducer_joint_reference(f=f_ref,
                                                        g=g_ref,
                                                        h_grad=self.h_grad,
                                                        f_len=self.f_len,
                                                        g_len=self.g_len,
                                                        pack_output=pack_output,
                                                        relu=relu,
                                                        dropout=dropout,
                                                        dropout_prob=self.dropout_prob,
                                                        mask=mask if dropout else None)

        f_grad_tst = self.f_tst.grad
        g_grad_tst = self.g_tst.grad

-        self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
-        self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
-        self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
+        torch.testing.assert_close(h_ref, h_tst, atol=1e-5, rtol=1e-5)
+        torch.testing.assert_close(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5)
+        torch.testing.assert_close(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)

    def test_transducer_joint(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)

    def test_transducer_joint_vec(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)

    def test_transducer_joint_pack(self):
        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)

    def test_transducer_joint_vec_pack(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)

    def test_transducer_joint_relu(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)

    def test_transducer_joint_vec_relu(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)

    def test_transducer_joint_pack_relu(self):
        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)

    def test_transducer_joint_vec_pack_relu(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)

    @unittest.expectedFailure
    def test_transducer_joint_relu_dropout(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)

    @unittest.expectedFailure
    def test_transducer_joint_vec_relu_dropout(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)

    @unittest.expectedFailure
    def test_transducer_joint_pack_relu_dropout(self):
        self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)

    @unittest.expectedFailure
    def test_transducer_joint_vec_pack_relu_dropout(self):
        self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)


if __name__ == '__main__':
    unittest.main()
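A usage note on the new API (a general observation about torch.testing.assert_close, not part of this commit): rtol and atol must be passed together, as the updated assertions do, and by default the function also checks that the compared tensors match in shape, dtype, and device, reporting any such mismatch with a descriptive AssertionError rather than the bare RuntimeError that torch.allclose raises.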