Skip to content

Commit 19cc873

Browse files
authored
Use torch.testing.assert_close for better signals (#1742)
1 parent 5afc6a0 commit 19cc873

9 files changed

+342
-342
lines changed

apex/contrib/test/fmha/test_fmha.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def run_test(self, s: int, b: int, zero_tensors: bool):
9696
ctx = ctx.view(b,s,h,d)
9797

9898
ctx_ref = py_mha(qkv, amask, b,s,h,d)
99-
self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))
99+
torch.testing.assert_close(ctx_ref.float(), ctx.float(), atol=1e-3)
100100

101101
labels = torch.randn_like(ctx_ref)
102102
diff = ctx_ref - labels
@@ -114,7 +114,7 @@ def run_test(self, s: int, b: int, zero_tensors: bool):
114114

115115
dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)
116116

117-
self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
117+
torch.testing.assert_close(qkv.grad.float(), dqkv2.float(), atol=1e-3)
118118

119119
def test_128(self):
120120
self.run_test(128, 32, False)
Original file line numberDiff line numberDiff line change
@@ -1,167 +1,167 @@
1-
import unittest
2-
3-
import torch
4-
5-
SKIP_TEST = None
6-
try:
7-
from apex.contrib.transducer import TransducerJoint
8-
from apex.contrib.transducer import _transducer_ref as transducer_ref
9-
except ImportError as e:
10-
SKIP_TEST = e
11-
12-
13-
@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}")
14-
class TransducerJointTest(unittest.TestCase):
15-
def setUp(self, seed=1234):
16-
torch.manual_seed(seed)
17-
18-
def gen_input(self, for_vector_kernel):
19-
self.B = 4
20-
T_min = 51
21-
T_max = 101
22-
U_min = 12
23-
U_max = 25
24-
if for_vector_kernel:
25-
H = 512
26-
else:
27-
H = 509
28-
dtype = torch.float16
29-
device = "cuda"
30-
31-
self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
32-
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
33-
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
34-
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
35-
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
36-
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
37-
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
38-
self.dropout_prob = 0.5
39-
40-
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
41-
# the loss function
42-
for b in range(self.B):
43-
self.h_grad[b, self.f_len[b]:, :, :] = 0
44-
self.h_grad[b, :, self.g_len[b]:, :] = 0
45-
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)
46-
47-
48-
def _pack(self, x, f_len, g_len):
49-
B = x.size(0)
50-
list_x = []
51-
for b in range(B):
52-
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
53-
x_row = torch.cat(list_x_row)
54-
list_x.append(x_row)
55-
x_packed = torch.cat(list_x).data.clone()
56-
x_packed.requires_grad = True
57-
batch_offset = torch.cumsum(f_len * g_len, dim=0)
58-
return x_packed
59-
60-
def _unpack(self, x, f_len, g_len):
61-
batch_offset = torch.cumsum(f_len * g_len, dim=0)
62-
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
63-
B = self.h_grad.size(0)
64-
H = self.h_grad.size(-1)
65-
for b in range(B):
66-
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
67-
my_f_len = f_len[b]
68-
my_g_len = g_len[b]
69-
for t in range(my_f_len):
70-
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
71-
my_batch_offset + t*my_g_len + my_g_len]
72-
return x_unpacked
73-
74-
def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
75-
self.gen_input(for_vector_kernel=for_vector_kernel)
76-
# Generate reference
77-
f_ref = self.f_tst.data.clone()
78-
g_ref = self.g_tst.data.clone()
79-
f_ref.requires_grad = True
80-
g_ref.requires_grad = True
81-
82-
my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
83-
dropout_prob=self.dropout_prob, probe_mask=True)
84-
if not pack_output:
85-
h_tst = my_joint( f=self.f_tst,
86-
g=self.g_tst,
87-
f_len=self.f_len,
88-
g_len=self.g_len)
89-
h_tst.backward(self.h_grad)
90-
if dropout:
91-
mask = my_joint.mask_probe[0]
92-
else:
93-
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
94-
h_tst = my_joint( f=self.f_tst,
95-
g=self.g_tst,
96-
f_len=self.f_len,
97-
g_len=self.g_len,
98-
batch_offset=batch_offset,
99-
packed_batch=batch_offset[-1])
100-
h_tst.backward(self.h_grad_packed)
101-
if dropout:
102-
mask_packed = my_joint.mask_probe[0]
103-
mask = self._unpack(mask_packed, self.f_len, self.g_len)
104-
105-
# reference
106-
h_ref, f_grad_ref, g_grad_ref \
107-
= transducer_ref.transducer_joint_reference(f=f_ref,
108-
g=g_ref,
109-
h_grad=self.h_grad,
110-
f_len=self.f_len,
111-
g_len=self.g_len,
112-
pack_output=pack_output,
113-
relu=relu,
114-
dropout=dropout,
115-
dropout_prob=self.dropout_prob,
116-
mask=mask if dropout else None)
117-
118-
f_grad_tst = self.f_tst.grad
119-
g_grad_tst = self.g_tst.grad
120-
121-
self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
122-
self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
123-
self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
124-
125-
def test_transducer_joint(self):
126-
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
127-
128-
def test_transducer_joint_vec(self):
129-
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
130-
131-
def test_transducer_joint_pack(self):
132-
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
133-
134-
def test_transducer_joint_vec_pack(self):
135-
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
136-
137-
def test_transducer_joint_relu(self):
138-
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
139-
140-
def test_transducer_joint_vec_relu(self):
141-
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
142-
143-
def test_transducer_joint_pack_relu(self):
144-
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
145-
146-
def test_transducer_joint_vec_pack_relu(self):
147-
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
148-
149-
@unittest.expectedFailure
150-
def test_transducer_joint_relu_dropout(self):
151-
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
152-
153-
@unittest.expectedFailure
154-
def test_transducer_joint_vec_relu_dropout(self):
155-
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
156-
157-
@unittest.expectedFailure
158-
def test_transducer_joint_pack_relu_dropout(self):
159-
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
160-
161-
@unittest.expectedFailure
162-
def test_transducer_joint_vec_pack_relu_dropout(self):
163-
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
164-
165-
166-
if __name__ == '__main__':
167-
unittest.main()
1+
import unittest
2+
3+
import torch
4+
5+
SKIP_TEST = None
6+
try:
7+
from apex.contrib.transducer import TransducerJoint
8+
from apex.contrib.transducer import _transducer_ref as transducer_ref
9+
except ImportError as e:
10+
SKIP_TEST = e
11+
12+
13+
@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}")
14+
class TransducerJointTest(unittest.TestCase):
15+
def setUp(self, seed=1234):
16+
torch.manual_seed(seed)
17+
18+
def gen_input(self, for_vector_kernel):
19+
self.B = 4
20+
T_min = 51
21+
T_max = 101
22+
U_min = 12
23+
U_max = 25
24+
if for_vector_kernel:
25+
H = 512
26+
else:
27+
H = 509
28+
dtype = torch.float16
29+
device = "cuda"
30+
31+
self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
32+
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
33+
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
34+
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
35+
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
36+
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
37+
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
38+
self.dropout_prob = 0.5
39+
40+
# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
41+
# the loss function
42+
for b in range(self.B):
43+
self.h_grad[b, self.f_len[b]:, :, :] = 0
44+
self.h_grad[b, :, self.g_len[b]:, :] = 0
45+
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)
46+
47+
48+
def _pack(self, x, f_len, g_len):
49+
B = x.size(0)
50+
list_x = []
51+
for b in range(B):
52+
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
53+
x_row = torch.cat(list_x_row)
54+
list_x.append(x_row)
55+
x_packed = torch.cat(list_x).data.clone()
56+
x_packed.requires_grad = True
57+
batch_offset = torch.cumsum(f_len * g_len, dim=0)
58+
return x_packed
59+
60+
def _unpack(self, x, f_len, g_len):
61+
batch_offset = torch.cumsum(f_len * g_len, dim=0)
62+
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
63+
B = self.h_grad.size(0)
64+
H = self.h_grad.size(-1)
65+
for b in range(B):
66+
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
67+
my_f_len = f_len[b]
68+
my_g_len = g_len[b]
69+
for t in range(my_f_len):
70+
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
71+
my_batch_offset + t*my_g_len + my_g_len]
72+
return x_unpacked
73+
74+
def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
75+
self.gen_input(for_vector_kernel=for_vector_kernel)
76+
# Generate reference
77+
f_ref = self.f_tst.data.clone()
78+
g_ref = self.g_tst.data.clone()
79+
f_ref.requires_grad = True
80+
g_ref.requires_grad = True
81+
82+
my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
83+
dropout_prob=self.dropout_prob, probe_mask=True)
84+
if not pack_output:
85+
h_tst = my_joint( f=self.f_tst,
86+
g=self.g_tst,
87+
f_len=self.f_len,
88+
g_len=self.g_len)
89+
h_tst.backward(self.h_grad)
90+
if dropout:
91+
mask = my_joint.mask_probe[0]
92+
else:
93+
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
94+
h_tst = my_joint( f=self.f_tst,
95+
g=self.g_tst,
96+
f_len=self.f_len,
97+
g_len=self.g_len,
98+
batch_offset=batch_offset,
99+
packed_batch=batch_offset[-1])
100+
h_tst.backward(self.h_grad_packed)
101+
if dropout:
102+
mask_packed = my_joint.mask_probe[0]
103+
mask = self._unpack(mask_packed, self.f_len, self.g_len)
104+
105+
# reference
106+
h_ref, f_grad_ref, g_grad_ref \
107+
= transducer_ref.transducer_joint_reference(f=f_ref,
108+
g=g_ref,
109+
h_grad=self.h_grad,
110+
f_len=self.f_len,
111+
g_len=self.g_len,
112+
pack_output=pack_output,
113+
relu=relu,
114+
dropout=dropout,
115+
dropout_prob=self.dropout_prob,
116+
mask=mask if dropout else None)
117+
118+
f_grad_tst = self.f_tst.grad
119+
g_grad_tst = self.g_tst.grad
120+
121+
torch.testing.assert_close(h_ref, h_tst, atol=1e-5, rtol=1e-5)
122+
torch.testing.assert_close(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5)
123+
torch.testing.assert_close(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)
124+
125+
def test_transducer_joint(self):
126+
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
127+
128+
def test_transducer_joint_vec(self):
129+
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)
130+
131+
def test_transducer_joint_pack(self):
132+
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)
133+
134+
def test_transducer_joint_vec_pack(self):
135+
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)
136+
137+
def test_transducer_joint_relu(self):
138+
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
139+
140+
def test_transducer_joint_vec_relu(self):
141+
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)
142+
143+
def test_transducer_joint_pack_relu(self):
144+
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)
145+
146+
def test_transducer_joint_vec_pack_relu(self):
147+
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)
148+
149+
@unittest.expectedFailure
150+
def test_transducer_joint_relu_dropout(self):
151+
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
152+
153+
@unittest.expectedFailure
154+
def test_transducer_joint_vec_relu_dropout(self):
155+
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)
156+
157+
@unittest.expectedFailure
158+
def test_transducer_joint_pack_relu_dropout(self):
159+
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)
160+
161+
@unittest.expectedFailure
162+
def test_transducer_joint_vec_pack_relu_dropout(self):
163+
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)
164+
165+
166+
if __name__ == '__main__':
167+
unittest.main()

0 commit comments

Comments
 (0)