Use torch.testing.assert_close for better signals (#1742)
eqy authored Oct 17, 2023
1 parent 5afc6a0 commit 19cc873
Showing 9 changed files with 342 additions and 342 deletions.
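For context, a minimal sketch (not part of the commit; the tensor values are illustrative) of why torch.testing.assert_close gives a better failure signal than self.assertTrue(torch.allclose(...)):

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.5])

# Old pattern: a failed assertTrue only reports "False is not true".
# self.assertTrue(torch.allclose(a, b, atol=1e-3))

# New pattern: the AssertionError describes the mismatch, e.g.
# "Mismatched elements: 1 / 3" plus the greatest absolute/relative
# difference and its index. Note that assert_close requires rtol and
# atol to be specified together (torch.allclose defaults rtol to 1e-5).
torch.testing.assert_close(a, b, atol=1e-3, rtol=1e-5)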
4 changes: 2 additions & 2 deletions apex/contrib/test/fmha/test_fmha.py
@@ -96,7 +96,7 @@ def run_test(self, s: int, b: int, zero_tensors: bool):
ctx = ctx.view(b,s,h,d)

ctx_ref = py_mha(qkv, amask, b,s,h,d)
-        self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3))
+        torch.testing.assert_close(ctx_ref.float(), ctx.float(), atol=1e-3, rtol=1e-5)

labels = torch.randn_like(ctx_ref)
diff = ctx_ref - labels
@@ -114,7 +114,7 @@ def run_test(self, s: int, b: int, zero_tensors: bool):

dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d)

-        self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3))
+        torch.testing.assert_close(qkv.grad.float(), dqkv2.float(), atol=1e-3, rtol=1e-5)

def test_128(self):
self.run_test(128, 32, False)
334 changes: 167 additions & 167 deletions apex/contrib/test/transducer/test_transducer_joint.py
@@ -1,167 +1,167 @@
import unittest

import torch

SKIP_TEST = None
try:
from apex.contrib.transducer import TransducerJoint
from apex.contrib.transducer import _transducer_ref as transducer_ref
except ImportError as e:
SKIP_TEST = e


@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}")
class TransducerJointTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)

def gen_input(self, for_vector_kernel):
self.B = 4
T_min = 51
T_max = 101
U_min = 12
U_max = 25
if for_vector_kernel:
H = 512
else:
H = 509
dtype = torch.float16
device = "cuda"

self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device)
self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device)
self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device)
self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device)
self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device)
self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max
self.dropout_prob = 0.5

# Make sure gradients from out-of-bound locations are zero. This should be guaranteed by
# the loss function
for b in range(self.B):
self.h_grad[b, self.f_len[b]:, :, :] = 0
self.h_grad[b, :, self.g_len[b]:, :] = 0
self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len)


def _pack(self, x, f_len, g_len):
B = x.size(0)
list_x = []
for b in range(B):
list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])]
x_row = torch.cat(list_x_row)
list_x.append(x_row)
x_packed = torch.cat(list_x).data.clone()
x_packed.requires_grad = True
batch_offset = torch.cumsum(f_len * g_len, dim=0)
return x_packed

def _unpack(self, x, f_len, g_len):
batch_offset = torch.cumsum(f_len * g_len, dim=0)
x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8)
B = self.h_grad.size(0)
H = self.h_grad.size(-1)
for b in range(B):
my_batch_offset = 0 if b == 0 else batch_offset[b-1]
my_f_len = f_len[b]
my_g_len = g_len[b]
for t in range(my_f_len):
x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len :
my_batch_offset + t*my_g_len + my_g_len]
return x_unpacked

def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout):
self.gen_input(for_vector_kernel=for_vector_kernel)
# Generate reference
f_ref = self.f_tst.data.clone()
g_ref = self.g_tst.data.clone()
f_ref.requires_grad = True
g_ref.requires_grad = True

my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout,
dropout_prob=self.dropout_prob, probe_mask=True)
if not pack_output:
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len)
h_tst.backward(self.h_grad)
if dropout:
mask = my_joint.mask_probe[0]
else:
batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0)
h_tst = my_joint( f=self.f_tst,
g=self.g_tst,
f_len=self.f_len,
g_len=self.g_len,
batch_offset=batch_offset,
packed_batch=batch_offset[-1])
h_tst.backward(self.h_grad_packed)
if dropout:
mask_packed = my_joint.mask_probe[0]
mask = self._unpack(mask_packed, self.f_len, self.g_len)

# reference
h_ref, f_grad_ref, g_grad_ref \
= transducer_ref.transducer_joint_reference(f=f_ref,
g=g_ref,
h_grad=self.h_grad,
f_len=self.f_len,
g_len=self.g_len,
pack_output=pack_output,
relu=relu,
dropout=dropout,
dropout_prob=self.dropout_prob,
mask=mask if dropout else None)

f_grad_tst = self.f_tst.grad
g_grad_tst = self.g_tst.grad

-        self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5))
-        self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5))
-        self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4))
+        torch.testing.assert_close(h_ref, h_tst, atol=1e-5, rtol=1e-5)
+        torch.testing.assert_close(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5)
+        torch.testing.assert_close(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)

def test_transducer_joint(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)

def test_transducer_joint_vec(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False)

def test_transducer_joint_pack(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False)

def test_transducer_joint_vec_pack(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False)

def test_transducer_joint_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)

def test_transducer_joint_vec_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False)

def test_transducer_joint_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False)

def test_transducer_joint_vec_pack_relu(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False)

@unittest.expectedFailure
def test_transducer_joint_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)

@unittest.expectedFailure
def test_transducer_joint_vec_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True)

@unittest.expectedFailure
def test_transducer_joint_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True)

@unittest.expectedFailure
def test_transducer_joint_vec_pack_relu_dropout(self):
self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True)


if __name__ == '__main__':
unittest.main()
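
A standalone sketch (not part of the commit; the lengths below are illustrative) of the packed layout the _pack and _unpack helpers above assume: each batch element b contributes a contiguous block of f_len[b] * g_len[b] rows, concatenated in batch order, so batch_offset = cumsum(f_len * g_len) marks the block boundaries.

import torch

f_len = torch.tensor([2, 3])  # per-batch T lengths (illustrative)
g_len = torch.tensor([2, 1])  # per-batch U lengths (illustrative)
batch_offset = torch.cumsum(f_len * g_len, dim=0)  # tensor([4, 7])

# Row (b, t, u) lands at packed index:
#   (batch_offset[b-1] if b > 0 else 0) + t * g_len[b] + u
# which mirrors the slicing arithmetic in _unpack above.
b, t, u = 1, 2, 0
start = 0 if b == 0 else batch_offset[b - 1].item()
print(start + t * g_len[b].item() + u)  # -> 6, the last of 7 packed rows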