forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_decomp.py
718 lines (618 loc) · 28 KB
/
test_decomp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
# Owner(s): ["module: primTorch", "module: decompositions"]
from collections import defaultdict
from torch import Tensor
import torch.autograd
from torch._decomp import decomposition_table
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
is_iterable_of_tensors,
TestCase,
skipIfCrossRef,
suppress_warnings,
TEST_WITH_ASAN,
run_tests,
skipIfTorchDynamo,
)
from torch.testing._internal.common_device_type import (
onlyNativeDeviceTypes,
ops,
instantiate_device_type_tests,
onlyCUDA,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import has_key, DispatchKey
import itertools
import functools
from functools import partial
import unittest
# Shorthand for the aten operator namespace; used below to look up overloads
# such as ``aten.get_gradients.default``.
aten = torch.ops.aten
# TODO: this isn't going to work with non-aten namespaces
def overload_to_aten_name(overload):
    """Return the operator name of ``overload`` with its namespace stripped.

    The schema name has the form ``"<namespace>::<name>"`` (e.g. ``"aten::add"``);
    the second ``::``-separated component (``"add"``) is returned.
    """
    qualified_name = overload._schema.name
    return qualified_name.split("::")[1]
# All operators that can have decomp tests: the namespace-less names of every
# overload that has an entry in ``decomposition_table``.
decomposition_names = set(map(overload_to_aten_name, decomposition_table))

# OpInfos whose forward (``aten_name``) or backward (``aten_backward_name``)
# op has a decomposition; TestDecomp.test_quick runs only these.
_decomp_test_ops = [
    op
    for op in op_db
    if any(
        name in decomposition_names
        for name in (op.aten_name, op.aten_backward_name)
    )
]
def diff_arg(arg, requires_grad=True):
    """Return whether ``arg`` should be treated as a differentiable input.

    Args:
        arg: a sample argument (a tensor, an iterable of tensors, or anything
            else).
        requires_grad: if True, a tensor counts as differentiable when its
            ``requires_grad`` flag is set; if False, when it is floating-point
            or complex.

    Returns:
        For a tensor, whether it is differentiable. For an iterable of tensors,
        True if every element is differentiable and False if none is. For any
        other value, False.

    Raises:
        RuntimeError: if ``arg`` is an iterable of tensors that mixes
            differentiable and non-differentiable elements.
    """

    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        # Evaluate the predicate once per element, then test both outcomes.
        flags = [is_differentiable_arg(a) for a in arg]
        if all(flags):
            return True
        if not any(flags):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)
# Version of autograd.grad with some differences:
# - pytree inputs is allowed (but leaves of the pytree have to all
# be tensors)
# - if an input is not used as part of derivatives, we will return a
# zero-filled tensor for the result
def _autograd_grad(
outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
inputs, inputs_spec = tree_flatten(inputs)
diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
if grad_outputs is None:
diff_outputs = tuple(out for out in outputs if out.requires_grad)
else:
diff_grad_outputs = [
(out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
]
if len(diff_grad_outputs) == 0:
diff_outputs, grad_outputs = (), ()
else:
diff_outputs, grad_outputs = zip(*diff_grad_outputs)
grad_inputs = torch.autograd.grad(
diff_outputs,
diff_inputs,
grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=True,
)
result = []
grad_inputs_iter = iter(grad_inputs)
for inp in inputs:
if inp.requires_grad:
grad_input = next(grad_inputs_iter)
if grad_input is None:
result.append(torch.zeros_like(inp))
else:
result.append(grad_input)
else:
result.append(torch.zeros_like(inp))
return tree_unflatten(result, inputs_spec)
def _as_tuple(val):
if isinstance(val, tuple):
return val
return (val,)
def ref_vjp_no_create(f, *primals):
    """Evaluate ``f(*primals)`` and return ``(output, vjp_fn)``.

    ``vjp_fn(cotangents)`` computes the vector-Jacobian product of ``output``
    with respect to ``primals`` via ``_autograd_grad`` with
    ``create_graph=False``. Both ``output`` and ``cotangents`` may be a single
    value or a tuple.
    """
    out = f(*primals)

    def vjp_fn(cotangents):
        return _autograd_grad(
            _as_tuple(out),
            primals,
            _as_tuple(cotangents),
            create_graph=False,
        )

    return out, vjp_fn
dtype_precisions = {
torch.float16: (0.001, 1e-5),
torch.bfloat16: (0.016, 1e-4),
torch.float32: (1.3e-6, 1e-5),
torch.float64: (1e-7, 1e-7),
torch.complex32: (0.001, 1e-5),
torch.complex64: (1.3e-6, 1e-5),
torch.complex128: (1e-7, 1e-7),
}
# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
def _getDefaultRtolAndAtol(dtype0, dtype1):
rtol = max(
dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
)
atol = max(
dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
)
return rtol, atol
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    """Check a decomposition output against a higher-precision reference.

    Rather than comparing ``orig`` (eager kernel) and ``decomp``
    (decomposition) directly, both are measured against ``ref``. In
    ``do_cross_ref``, ``ref`` is the eager kernel run on float64-upcast inputs.
    The check fails only if the decomposition's max absolute error w.r.t.
    ``ref`` exceeds the eager kernel's error by more than a per-(dtype, op)
    slack, which defaults to 1e-7. When ``ref`` is not floating-point, it falls
    back to ``test_case.assertEqual(orig, decomp)``.

    Args:
        test_case: the running TestCase; used only for the non-float fallback.
        op: the aten overload being checked.
        test_dtype: dtype the test was instantiated with (slack table key).
        i: index of this output among the op's flattened outputs.
        orig, decomp, ref: eager, decomposed and reference outputs.
        args, kwargs: the op's inputs, echoed in failure messages.

    Raises:
        AssertionError: on a dtype/shape/numel mismatch between ``orig`` and
            ``decomp``.
        RuntimeError: if the decomposition is measurably less accurate.
    """
    assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        # Nothing to compare numerically; only require both sides to be empty.
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation: {op}"

    # Extra absolute error the decomposition may have over the eager kernel.
    slack_by_op = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
    }

    if not ref.is_floating_point():
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )
        return

    eager_err = (orig - ref).abs().max()
    decomp_err = (decomp - ref).abs().max()
    slack = slack_by_op.get((test_dtype, op), 1e-7)
    if decomp_err > eager_err + slack:
        raise RuntimeError(
            f"Difference from float64 is larger with decomposition {op.__name__}"
            f" than original on output {i}. Original max diff: {eager_err}, Decomp max diff: {decomp_err}\n"
            f"atol = {slack}\n"
            f"args = {args}\n"
            f"kwargs = {kwargs}"
        )
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    """Assert that a decomposition's output matches the eager kernel's output.

    ``orig`` and ``decomp`` must share a dtype and be close under
    ``test_case.assertEqual``. The (rtol, atol) pair is taken from a
    per-(dtype, op) override table, falling back to
    ``_getDefaultRtolAndAtol``. ``test_dtype`` is not used by this function.
    """
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )

    # Before adding an entry to this table, make sure your decomposition is right :)
    tolerance_overrides = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (1e-3, 1e-3),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
        # This exceeds default tolerances only on CPU, on CUDA it's fine
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
        # Exceeds tolerances on CUDA, likely due to fma
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
    }
    # The decomposition is TOO correct. It computes everything in int64, so sometimes
    # there's an off-by-one error. See
    # https://github.com/pytorch/pytorch/issues/81996
    # https://github.com/pytorch/pytorch/issues/82230
    tolerance_overrides.update(
        {
            (int_dtype, torch.ops.aten.linspace.default): (0, 1)
            for int_dtype in (
                torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64
            )
        }
    )

    override = tolerance_overrides.get((decomp.dtype, op))
    if override is None:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    else:
        rtol, atol = override
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    """Wrap ``f`` so it is called with only its differentiable inputs.

    Given f, returns an f' such that:
    - f' takes only positional arguments: one per leaf of ``args`` that
      ``diff_arg`` (with the given ``requires_grad``) deems differentiable.
      The remaining leaves and ``kwargs`` are held fixed.
    - f' applies ``output_process_fn_grad`` (if any) to f's result. When that
      result is a tuple, only its floating-point / complex Tensor entries are
      kept.

    Returns:
        ``(wrapped, primals)``: the wrapped function and the current values of
        the differentiable leaves, in flattened order.

    Raises:
        AssertionError: if no leaf of ``args`` is differentiable, or if a tuple
            result contains no floating-point / complex tensors.
    """
    leaves, spec = tree_flatten(args)
    diff_positions = tuple(
        idx
        for idx, leaf in enumerate(leaves)
        if diff_arg(leaf, requires_grad=requires_grad)
    )
    assert len(diff_positions) > 0
    primals = tuple(leaves[idx] for idx in diff_positions)

    @functools.wraps(f)
    def wrapped(*new_primals):
        rebuilt = list(leaves)
        for idx, value in zip(diff_positions, new_primals):
            rebuilt[idx] = value
        out = f(*tree_unflatten(rebuilt, spec), **kwargs)
        if output_process_fn_grad is not None:
            out = output_process_fn_grad(out)
        if not isinstance(out, tuple):
            return out
        # TODO We should check that the integer outputs also agree
        out = tuple(
            item
            for item in out
            if isinstance(item, Tensor) and (item.is_floating_point() or item.is_complex())
        )
        assert len(out) > 0
        return out

    return wrapped, primals
# NB: This also upcasts dtype arguments
# TODO: handle complex correctly
def upcast_tensor(x, dtype=torch.float32):
    """Cast floating-point tensors and low-precision float dtypes to ``dtype``.

    - A Tensor with a floating-point dtype is returned as
      ``x.to(dtype=dtype)``. Other tensors (integer, bool, complex) are
      returned unchanged.
    - A ``torch.dtype`` equal to float16, bfloat16 or float32 is replaced by
      ``dtype``. Any other dtype is returned unchanged.
    - Every other value is returned unchanged.

    ``do_cross_ref`` uses this as a ``tree_map`` callback over an op's
    args/kwargs.
    """
    if isinstance(x, Tensor):
        return x.to(dtype=dtype) if x.dtype.is_floating_point else x
    if isinstance(x, torch.dtype) and x in (torch.float16, torch.bfloat16, torch.float):
        return dtype
    return x
def normalize_op_input_output(f, sample, requires_grad=True):
    """Apply ``normalize_op_input_output2`` to an OpInfo ``SampleInput``.

    ``sample.input`` becomes the first positional argument, followed by
    ``sample.args``. ``sample.kwargs`` and ``sample.output_process_fn_grad``
    are forwarded unchanged.
    """
    positional = (sample.input, *sample.args)
    return normalize_op_input_output2(
        f,
        positional,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )
# (device type, dtype, OpInfo name) triples for which the decomposition
# cross-reference test is skipped entirely. do_cross_ref looks up
# (device type, dtype, name), (None, dtype, name) and (None, None, name), so an
# entry with device None matches every device, and one with both device and
# dtype None matches everything.
# The device slot must be a device-type *string* such as "cpu" (not the
# ``torch.cpu`` module): it is compared with ``torch.device(device).type``.
CROSS_REF_EXCLUDE_SET = {
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    # randomness
    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),
    # It's the only in-place op without an out-of-place equivalent in the Python API
    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
    (None, None, "zero_"),
    # No idea what's going on here
    # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
    # in the test, but it seems to pass when tested locally and in the logsumexp test
    (None, torch.float32, "masked.logsumexp"),
    (None, torch.float64, "masked.logsumexp"),
    # exp_vml_cpu not implemented for Half
    ("cpu", torch.float16, "signal.windows.exponential"),
    ("cpu", torch.float16, "signal.windows.gaussian"),
    # sin_vml_cpu not implemented for Half
    ("cpu", torch.float16, "signal.windows.cosine"),
    # CompositeAutogradImplicit
    # See https://github.com/pytorch/pytorch/issues/81669
    (None, None, "nn.functional.relu6"),
    (None, None, "meshgrid"),
    # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit)
    (None, None, "diag"),
    # _softmax_backward_data's CPU kernel for bfloat16 always returns the grad_input as float32
    ("cpu", torch.bfloat16, "_softmax_backward_data"),
    (None, None, "norm"),
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
    (None, None, "native_batch_norm"),
}
# (device type, dtype, OpInfo name) triples for which only the backward (VJP)
# cross-reference in do_cross_ref is skipped; the forward check still runs.
# Both entries: the decomposed backward formula is not as precise.
CROSS_REF_BACKWARD_EXCLUDE_SET = {
    ("cuda", torch.float16, "nn.functional.cross_entropy"),
    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
}
# Module-level bookkeeping filled in by DecompCrossRefMode across all tests:
# - all_decomposed: every operator overload whose decomposition was run and
#   cross-checked
# - all_called: how many times each operator overload was dispatched
all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing coverage (copy into the file to enable)
"""
import atexit


def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))


atexit.register(check_coverage)
"""

# Helpful snippet for Horace to create his google sheet :)
"""
import atexit


def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')


atexit.register(dump_ops)
"""
def any_unsupported(args, kwargs):
    """Return True if any pytree leaf of ``args``/``kwargs`` should bypass decomp checks.

    A leaf is unsupported when it is a plain ``torch.Tensor`` or
    ``torch.nn.Parameter`` that is sparse (COO or CSR), mkldnn, quantized,
    nested, or a functional tensor; or when it is any other tensor-like object
    (e.g. a Tensor subclass).
    """

    def _is_unsupported(leaf):
        if type(leaf) in (torch.Tensor, torch.nn.Parameter):
            # These are all things that we haven't coded decompositions
            # to handle correctly. Maybe they should.
            return any((
                leaf.is_sparse_csr, leaf.is_sparse, leaf.is_mkldnn, leaf.is_quantized,
                leaf.is_nested, torch._is_functional_tensor(leaf),
            ))
        # Decompositions will generally change the behavior of Tensor-like
        # subclasses, so bypass tests in this case too
        return torch.overrides.is_tensor_like(leaf)

    leaves = itertools.chain(tree_flatten(args)[0], tree_flatten(kwargs)[0])
    return any(_is_unsupported(leaf) for leaf in leaves)
class TestDecomp(TestCase):
    """Cross-check registered decompositions against the eager aten kernels.

    Each OpInfo sample is run under ``DecompCrossRefMode`` (defined in
    ``do_cross_ref``). For every dispatched op that has an entry in
    ``decomposition_table``, the mode runs both the decomposition and the real
    kernel and compares their outputs.
    """

    # unittest: append custom failure messages to the default ones.
    longMessage = True

    # NB: This actually overlaps with test_comprehensive, but it only
    # runs on things that are definitely decomposed so it's a lot faster
    # to run
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        """Cross-reference OpInfos whose forward or backward op is decomposed (run_all=False)."""
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        """Cross-reference every OpInfo in op_db, decomposing recursively (run_all=True)."""
        self.do_cross_ref(device, dtype, op, run_all=True)

    def test_uniform(self, device):
        """aten.uniform and its Python decomposition agree under the same RNG seed."""
        size = (2, 3, 4, 5)
        dtype = torch.float32
        x = make_tensor(size, dtype=dtype, device=device)
        low = 0.3
        high = 0.9

        # Reseed before each call so both draw the same random numbers.
        torch.manual_seed(123)
        ref = torch.ops.aten.uniform(x, low, high)
        torch.manual_seed(123)
        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
        self.assertEqual(ref, res)

    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def do_cross_ref(self, device, dtype, op, *, run_all):
        """Run every sample of ``op`` under DecompCrossRefMode and check decompositions.

        Args:
            device: device string the test was instantiated for.
            dtype: dtype the samples are generated in.
            op: the OpInfo under test.
            run_all: when True, each decomposition is executed with the mode
                still active, so ops it dispatches to are checked as well. The
                backward and no-grad paths then run for every op. When False,
                those paths run only for ops whose relevant (backward or
                forward) aten name is in ``decomposition_names``, and
                ``check_decomposed`` asserts that the decomposition was
                actually used.

        Skips the test if (device type, dtype, op.name), (None, dtype, op.name)
        or (None, None, op.name) is in CROSS_REF_EXCLUDE_SET.
        """
        test_keys = [
            (torch.device(device).type, dtype, op.name),
            (None, dtype, op.name),
            (None, None, op.name),
        ]
        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        # Same key scheme, but only disables the backward (VJP) check below.
        skip_decomp_vjp = any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys)
        test_dtype = dtype

        # We check the correctness of each decomposition right after running it.
        # So, when we encounter a decomposition, we run the function normally, and
        # then run the decomposition, and ensure they're identical.
        called = set()
        decomposed = set()

        saved_precision = self.precision
        saved_rel_tol = self.rel_tol
        # The mode class below shadows `self`, so keep a handle on the test case.
        test_case = self

        class DecompCrossRefMode(TorchDispatchMode):
            """Dispatch mode that checks each decomposable op against its real kernel."""

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                # Restore the tolerances captured when do_cross_ref started
                # (presumably so per-op adjustments don't leak between ops; confirm).
                test_case.precision = saved_precision
                test_case.rel_tol = saved_rel_tol

                called.add(func)
                all_called[func] += 1

                # Stuff we shouldn't bother testing
                # (TODO: remove detach from the decomp table?)
                # N.b. Testing in-place ops would need dedicated logic
                # NOTE(review): if func.name() includes the overload suffix
                # (e.g. "aten::add_.Tensor"), this only catches in-place ops whose
                # overload is the default one. Confirm this is intended.
                in_place = func.name()[-1] == '_'
                # NOTE(review): kwargs is assumed to be a dict here; `**kwargs`
                # would fail if the default None were ever used.
                if func not in decomposition_table or func in [
                    torch.ops.aten.detach.default,
                    # non-deterministic ops
                    torch.ops.aten.empty.memory_format,
                    torch.ops.aten.empty_like.default,
                    torch.ops.aten.new_empty.default,
                    torch.ops.aten.empty_strided.default,
                    torch.ops.aten.new_empty_strided.default,
                    torch.ops.aten.randn.default,
                    torch.ops.aten.native_dropout.default,
                ] or any_unsupported(args, kwargs) or in_place:
                    return func(*args, **kwargs)

                decomposed.add(func)
                all_decomposed.add(func)

                # We take 2 main strategies for verifying correctness/numerical stability of decompositions
                # The first one is simply tolerance checking between decomp_out and pytorch_out
                # However, for fp16/bf16 and reductions, this becomes very
                # finicky, as there are not many guarantees we can make.
                # So, for fp16/bf16, we instead compare the difference of
                # {decomp_out, pytorch_out_64} and {pytorch_out,
                # pytorch_out_64}. In other words, we compare how far the
                # decomposition and pytorch are from the "ground truth" (i.e.
                # fp64). If the decomposition results in more error, we error

                # We also decompose the decomposition recursively for
                # further coverage, as some paths may not be exercised directly by
                # OpInfos (sadly) but just by other ops

                decomposition = decomposition_table[func]

                do_relative_check = test_dtype in [torch.float16, torch.bfloat16]
                if run_all:
                    # Execute recursively via DFS, to find the root of a possible error first
                    # (re-entering this mode makes ops dispatched by the decomposition
                    # go through __torch_dispatch__ too)
                    with self:
                        decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
                else:
                    decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))

                # At this stage we should not be decomposing an in-place op
                # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops
                # because decompositions are run after functionalisation and we would not like them to
                # de-functionalise the graph, as that would break AoTAutograd
                # We run the real function *after* the decomposition to make sure that the
                # decomposition does not modify any of the inputs in-place. If it does,
                # real_out should differ from decomp_out, so we should catch this
                real_out_unflat = func(*args, **kwargs)
                real_out, _ = tree_flatten(real_out_unflat)

                assert len(real_out) == len(decomp_out)

                if do_relative_check:
                    # Reference: the real kernel on float64-upcast inputs.
                    upcast = partial(upcast_tensor, dtype=torch.float64)
                    real_out_double, _ = tree_flatten(
                        func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                    )
                    for i, (orig, decomp, ref) in enumerate(zip(real_out, decomp_out, real_out_double)):
                        if not isinstance(orig, torch.Tensor):
                            assert type(orig) == type(decomp)
                            assert orig == decomp
                            continue
                        op_assert_ref(test_case, func, test_dtype, i, orig, decomp, ref, args, kwargs)
                else:
                    for orig, decomp in zip(real_out, decomp_out):
                        if not isinstance(orig, torch.Tensor):
                            assert type(orig) == type(decomp)
                            assert orig == decomp
                            continue
                        op_assert_equal(test_case, func, test_dtype, orig, decomp, args, kwargs)

                # The caller always receives the real kernel's output.
                return real_out_unflat

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(torch.device(device).type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm. The problem is not
            # complex32 per-se (which is supported by data movement only ops)
            # but that when we do backwards we expect other ops like add to work
            and not dtype == torch.complex32
        )
        samples = op.sample_inputs(device, test_dtype, requires_grad=requires_grad)

        def check_decomposed(aten_name):
            # Assert that some overload named `aten_name` went through its
            # decomposition since `decomposed` was last cleared.
            self.assertTrue(
                any(overload_to_aten_name(c) == aten_name for c in decomposed),
                msg=(f"aten.{aten_name} was not decomposed, saw calls for: "
                     f"{', '.join(map(str, list(called)))}. If your op is "
                     f"CompositeImplicitAutograd you should skip this test "
                     "by updating CROSS_REF_EXCLUDE_SET.")
            )

        aten_name = op.decomp_aten_name or op.aten_name

        func = op.get_op()
        for sample_input in samples:
            if requires_grad:
                fn, primals = normalize_op_input_output(func, sample_input)
                # NOTE(review): the lambda returns x in both branches, so this
                # tree_map is a no-op.
                primals = tree_map(
                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
                )

                # Once https://github.com/pytorch/pytorch/pull/75965/ I can
                # store the called list on the mode object instance and no
                # explicit clearing is necessary as I will create a fresh mode
                # for each region
                decomposed.clear()
                with DecompCrossRefMode(), enable_python_dispatcher():
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    check_decomposed(aten_name)

                if not skip_decomp_vjp and (op.aten_backward_name in decomposition_names or run_all):
                    # Random cotangents shaped like the forward outputs.
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    decomposed.clear()
                    with DecompCrossRefMode(), enable_python_dispatcher():
                        decomp_vjp_fn(cotangents)
                    if not run_all:
                        check_decomposed(op.aten_backward_name)

            elif aten_name in decomposition_names or run_all:
                # Forward-only path: call the op directly on the sample.
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                decomposed.clear()
                with DecompCrossRefMode(), enable_python_dispatcher():
                    func(*args, **kwargs)
                if not run_all:
                    check_decomposed(aten_name)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )


instantiate_device_type_tests(TestDecomp, globals())
class DecompContiguousTests(TestCase):
    """Softmax-family decompositions must produce the same strides as eager."""

    def _assert_same_strides(self, device, aten_fn, decomp_fn):
        # Input with dim 1 outermost in memory, as if a contiguous
        # (4, 2, 3, 3) tensor had its first two dims transposed.
        shape = (2, 4, 3, 3)
        strides = (9, 18, 3, 1)
        base = torch.randn(shape, dtype=torch.float32, device=device)
        inp = torch.as_strided(base, shape, strides)
        # Both are called as fn(input, dim=-1, half_to_float=False) positionally.
        expected = aten_fn(inp, -1, False)
        actual = decomp_fn(inp, -1, False)
        self.assertEqual(expected.stride(), actual.stride())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_softmax(self, device):
        self._assert_same_strides(
            device,
            torch.ops.aten._softmax,
            torch._decomp.decompositions._softmax,
        )

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_log_softmax(self, device):
        self._assert_same_strides(
            device,
            torch.ops.aten._log_softmax,
            torch._decomp.decompositions._log_softmax,
        )


instantiate_device_type_tests(DecompContiguousTests, globals())
class DecompAmpTests(TestCase):
    """native_batch_norm_backward decomposition with float16 data and float32 stats."""

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @onlyCUDA
    def test_amp_batch_norm_backward(self):
        device = "cuda"
        half_randn = functools.partial(torch.randn, dtype=torch.float16, device=device)
        float_randn = functools.partial(torch.randn, dtype=torch.float32, device=device)

        # Activations and incoming gradient in float16 ...
        grad_out = half_randn((1, 2, 16, 16))
        x = half_randn((1, 2, 16, 16))
        # ... per-channel weight and running statistics in float32.
        weight = float_randn((2,))
        rmean = float_randn((2,))
        rvar = float_randn((2,))
        # One empty tensor fills both of the next two positional slots
        # (presumably save_mean / save_invstd, unused with train=False; confirm).
        mean = float_randn((0,))

        shared_args = (grad_out, x, weight, rmean, rvar, mean, mean, False, 1e-05)
        ref = torch.ops.aten.native_batch_norm_backward(*shared_args, [True, True, True])
        res = torch._decomp.decompositions.native_batch_norm_backward(
            *shared_args, [True, True, True]
        )
        for expected, actual in zip(ref, res):
            self.assertEqual(expected.stride(), actual.stride())
            self.assertEqual(expected.dtype, actual.dtype)


instantiate_device_type_tests(DecompAmpTests, globals())
class HasDecompTest(TestCase):
    """Track aten overloads that can appear in a trace but have no decomposition."""

    def setUp(self):
        super().setUp()
        # Show the full diff when the expect-file comparison fails.
        self.maxDiff = None

    def test_has_decomposition(self):
        """Compare traceable aten overloads lacking a decomposition with the expect file."""

        def can_appear_in_trace(op) -> bool:
            # Only ops that take or return a Tensor are interesting.
            has_tensor_arg = any(
                "Tensor" in str(a.type)
                for a in itertools.chain(op._schema.arguments, op._schema.returns))
            if not has_tensor_arg:
                return False

            try:
                # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
                return not has_key(op, DispatchKey.CompositeImplicitAutograd)
            except RuntimeError as e:
                # has_key fails for some jit-registered ops, which shouldn't be
                # relevant here anyway
                if 'does not exist' in str(e):
                    return False
                raise

        def all_aten_overloads():
            # Yield an OpOverload for every registered "aten::name[.overload]".
            for name in torch._C._dispatch_get_all_op_names():
                if not name.startswith("aten::"):
                    continue

                name = name[len("aten::"):]
                if "." in name:
                    packet_name, overload_name = name.split(".")
                else:
                    packet_name, overload_name = name, "default"

                packet = getattr(aten, packet_name)
                assert isinstance(packet, torch._ops.OpOverloadPacket)
                op = getattr(packet, overload_name)
                yield op

        # This is for operators that are only registered in some CI
        # configurations, so would cause the test to fail
        allow_list = {aten.get_gradients.default}

        overloads_wanting_decomp = {op for op in all_aten_overloads() if can_appear_in_trace(op)}
        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
        ops_missing_decomp -= allow_list
        self.assertExpected("".join(sorted(op.name() + "\n" for op in ops_missing_decomp)))
if __name__ == "__main__":
    # Run directly via run_tests from torch.testing._internal.common_utils.
    run_tests()