Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# EasyBuild easyconfig for PyTorch 1.12.1 with CUDA support, built with the
# foss/2022a toolchain.  The bare names used below (GITHUB_RELEASE, SYSTEM,
# OS_PKG_IBVERBS_DEV, SHLIB_EXT) are template constants injected by the
# EasyBuild framework when it evaluates this file.
name = 'PyTorch'
version = '1.12.1'
# Module version suffix advertising the CUDA dependency; %(cudaver)s is
# resolved from the CUDA entry in the dependencies list below.
versionsuffix = '-CUDA-%(cudaver)s'

homepage = 'https://pytorch.org/'
description = """Tensors and Dynamic neural networks in Python with strong GPU acceleration.
PyTorch is a deep learning framework that puts Python first."""

toolchain = {'name': 'foss', 'version': '2022a'}

# Release tarball is fetched from GitHub releases; %(namelower)s/%(version)s
# are EasyBuild templates expanding to 'pytorch' and '1.12.1'.
source_urls = [GITHUB_RELEASE]
sources = ['%(namelower)s-v%(version)s.tar.gz']

# Patches carried forward from earlier PyTorch easyconfigs (1.7.0/1.10.0/1.11.0
# prefixes) plus 1.12.1-specific fixes.  The order of this list must match the
# order of the checksums list below (after the source tarball entry).
patches = [
'PyTorch-1.7.0_avoid-nan-in-test-torch.patch',
'PyTorch-1.7.0_disable-dev-shm-test.patch',
'PyTorch-1.10.0_fix-kineto-crash.patch',
'PyTorch-1.10.0_fix-test-dataloader-fixed-affinity.patch',
'PyTorch-1.10.0_fix-test-model_dump.patch',
'PyTorch-1.10.0_fix-vsx-vector-functions.patch',
'PyTorch-1.10.0_skip-nnapi-test-without-qnnpack.patch',
'PyTorch-1.11.0_fix-fsdp-fp16-test.patch',
'PyTorch-1.11.0_fix-test_utils.patch',
'PyTorch-1.11.0_increase_c10d_gloo_timeout.patch',
'PyTorch-1.11.0_increase-distributed-test-timeout.patch',
'PyTorch-1.11.0_install-vsx-vec-headers.patch',
'PyTorch-1.12.1_fix-cuda-gcc-version-check.patch',
'PyTorch-1.12.1_fix-skip-decorators.patch',
'PyTorch-1.12.1_fix-test_cpp_extensions_jit.patch',
'PyTorch-1.12.1_fix-test_wishart_log_prob.patch',
'PyTorch-1.12.1_fix-TestCudaFuser.test_unary_ops.patch',
'PyTorch-1.12.1_fix-TestTorch.test_to.patch',
'PyTorch-1.12.1_fix-use-after-free-in-tensorpipe-agent.patch',
'PyTorch-1.12.1_fix-vsx-vector-funcs.patch',
'PyTorch-1.12.1_fix-vsx-loadu.patch',
'PyTorch-1.12.1_increase-test-adadelta-tolerance.patch',
'PyTorch-1.12.1_increase-tolerance-test_ops.patch',
'PyTorch-1.12.1_no-cuda-stubs-rpath.patch',
'PyTorch-1.12.1_python-3.10-annotation-fix.patch',
'PyTorch-1.12.1_python-3.10-compat.patch',
'PyTorch-1.12.1_remove-flaky-test-in-testnn.patch',
'PyTorch-1.12.1_skip-ao-sparsity-test-without-fbgemm.patch',
'PyTorch-1.12.1_skip-failing-grad-test.patch',
'PyTorch-1.12.1_skip-test_round_robin_create_destroy.patch',
]
# SHA256 checksums: the first entry is for the source tarball, followed by one
# checksum per patch in the same order as the patches list above.  Entries for
# long patch names put the filename in a comment on the preceding line.
checksums = [
'031c71073db73da732b5d01710220564ce6dd88d812ba053f0cc94296401eccb', # pytorch-v1.12.1.tar.gz
'b899aa94d9e60f11ee75a706563312ccefa9cf432756c470caa8e623991c8f18', # PyTorch-1.7.0_avoid-nan-in-test-torch.patch
'622cb1eaeadc06e13128a862d9946bcc1f1edd3d02b259c56a9aecc4d5406b8a', # PyTorch-1.7.0_disable-dev-shm-test.patch
# PyTorch-1.10.0_fix-kineto-crash.patch
'dc467333b28162149af8f675929d8c6bf219f23230bfc0d39af02ba4f6f882eb',
# PyTorch-1.10.0_fix-test-dataloader-fixed-affinity.patch
'313dca681f45ce3bc7c4557fdcdcbe0b77216d2c708fa30a2ec0e22c44876707',
# PyTorch-1.10.0_fix-test-model_dump.patch
'339148ae1a028cda6e750ac93fa38a599f66c7abe26586c9219f1a206ea14557',
# PyTorch-1.10.0_fix-vsx-vector-functions.patch
'7bef5f96cb83b2d655d2f76dd7468a171d446f0b3e06da2232ec7f886484d312',
# PyTorch-1.10.0_skip-nnapi-test-without-qnnpack.patch
'34ba476a7bcddec323bf9eca083cb4623d0f569d081aa3add3769c24f22849d2',
'bb1c4e6d6fd4b0cf57ff8b824c797331b533bb1ffc63f5db0bae3aee10c3dc13', # PyTorch-1.11.0_fix-fsdp-fp16-test.patch
'4f7e25c4e2eb7094f92607df74488c6a4a35849fabf05fcf6c3655fa3f44a861', # PyTorch-1.11.0_fix-test_utils.patch
# PyTorch-1.11.0_increase_c10d_gloo_timeout.patch
'20cd4a8663f74ab326fdb032b926bf5c7e94d9750c515ab9050927ba00cf1953',
# PyTorch-1.11.0_increase-distributed-test-timeout.patch
'087ad20163a1291773ae3457569b80523080eb3731e210946459b2333a919f3f',
'f2e6b9625733d9a471bb75e1ea20e28814cf1380b4f9089aa838ee35ddecf07d', # PyTorch-1.11.0_install-vsx-vec-headers.patch
# PyTorch-1.12.1_fix-cuda-gcc-version-check.patch
'a650f4576f06c749f244cada52ff9c02499fa8f182019129488db3845e0756ab',
'e3ca6e42b2fa592ea095939fb59ab875668a058479407db3f3684cc5c6f4146c', # PyTorch-1.12.1_fix-skip-decorators.patch
# PyTorch-1.12.1_fix-test_cpp_extensions_jit.patch
'1efc9850c431d702e9117d4766277d3f88c5c8b3870997c9974971bce7f2ab83',
# PyTorch-1.12.1_fix-test_wishart_log_prob.patch
'cf475ae6e6234b96c8d1bf917597c5176c94b3ccd940b72f2e1cd0c979580f45',
# PyTorch-1.12.1_fix-TestCudaFuser.test_unary_ops.patch
'8e6e844c6b0541e0c8115911ee1a9d548613254b36dfbdada202fd723fc26aa2',
'75f27987c3f25c501e719bd2b1c70a029ae0ee28514a97fe447516aee02b1535', # PyTorch-1.12.1_fix-TestTorch.test_to.patch
# PyTorch-1.12.1_fix-use-after-free-in-tensorpipe-agent.patch
'0bd7e88b92c4c6f0fecf01746009858ba19f2df68b10b88c41485328a531875d',
'caccbf60f62eac313896c1eaec78b08f5d0fdfcb907079087490bb13d1561aa2', # PyTorch-1.12.1_fix-vsx-vector-funcs.patch
'8bfe3c94ada1dd1f7974a1261a8b576fb7ae944050fa1c7830fca033831123b2', # PyTorch-1.12.1_fix-vsx-loadu.patch
# PyTorch-1.12.1_increase-test-adadelta-tolerance.patch
'944ed1af5ad4bbe20cbb042764a88dad1eef6cd33218617cf3d4cd90c6764695',
# PyTorch-1.12.1_increase-tolerance-test_ops.patch
'1c1fa520801e2ee5faf56a3d6dc96321e7c11664fd16bffd7c6ee437e68357fb',
'2905826ca713752b47c84e4ec8b177c90cbd91fca498ba2ba546f495c4cf70a6', # PyTorch-1.12.1_no-cuda-stubs-rpath.patch
# PyTorch-1.12.1_python-3.10-annotation-fix.patch
'11e168fd429d9e156fc79dd806b08125f3640651ad9998abd810446b2ed0c2d7',
'81402420a878b40f824778f0333fbec6504325a6a1b06a22749c4cac3eaccf67', # PyTorch-1.12.1_python-3.10-compat.patch
# PyTorch-1.12.1_remove-flaky-test-in-testnn.patch
'e81b678e354dd137c0d6d974605cdedbf672096fdbdf567c347bc2fbfc73471d',
# PyTorch-1.12.1_skip-ao-sparsity-test-without-fbgemm.patch
'edd464ec8c37b44c07a72008d732604f6837f2dd61c7810c391a86ba4945ca39',
'1c89e7e67287fe6b9a95480a4178d3653b94d0ab2fe68edf227606c8ae548fdc', # PyTorch-1.12.1_skip-failing-grad-test.patch
# PyTorch-1.12.1_skip-test_round_robin_create_destroy.patch
'1435fcac3234edc865479199673b902eb67f6a2bd046af7d731141f03594666d',
]

# InfiniBand verbs development headers must be provided by the OS package
# manager (name resolved per-distro by EasyBuild).
osdependencies = [OS_PKG_IBVERBS_DEV]

# Build-time-only dependencies (not retained in the generated module).
builddependencies = [
('CMake', '3.23.1'),
('hypothesis', '6.46.7'),
]

# Runtime dependencies.  Entries with SYSTEM as toolchain (CUDA, cuDNN) are
# installed outside the foss toolchain hierarchy.
dependencies = [
('CUDA', '11.7.0', '', SYSTEM),
('Ninja', '1.10.2'), # Required for JIT compilation of C++ extensions
('Python', '3.10.4'),
('protobuf', '3.19.4'),
('protobuf-python', '3.19.4'),
('pybind11', '2.9.2'),
('SciPy-bundle', '2022.05'),
('PyYAML', '6.0'),
('MPFR', '4.1.0'),
('GMP', '6.2.1'),
('numactl', '2.0.14'),
('FFmpeg', '4.4.2'),
('Pillow', '9.1.1'),
('cuDNN', '8.4.1.50', '-CUDA-%(cudaver)s', SYSTEM),
('magma', '2.6.2', '-CUDA-%(cudaver)s'),
('NCCL', '2.12.12', '-CUDA-%(cudaver)s'),
('expecttest', '0.1.3'),
]

# default CUDA compute capabilities to use (override via --cuda-compute-capabilities)
cuda_compute_capabilities = ['3.5', '3.7', '5.2', '6.0', '6.1', '7.0', '7.2', '7.5', '8.0', '8.6']

# Tests excluded for every CUDA compute capability (the '' key applies
# unconditionally); reasons documented inline.
excluded_tests = {
'': [
# This test seems to take too long on NVIDIA Ampere at least.
'distributed/test_distributed_spawn',
# Broken on CUDA 11.6/11.7: https://github.com/pytorch/pytorch/issues/75375
'distributions/test_constraints',
]
}

# Test-suite driver; %(python)s and %(excluded_tests)s are filled in by the
# PyTorch easyblock.  PYTHONUNBUFFERED keeps log output ordered.
runtest = 'cd test && PYTHONUNBUFFERED=1 %(python)s run_test.py --continue-through-error --verbose %(excluded_tests)s'

# The readelf sanity check command can be taken out once the TestRPATH test from
# https://github.com/pytorch/pytorch/pull/87593 is accepted, since it is then checked as part of the PyTorch test suite
local_libcaffe2 = "$EBROOTPYTORCH/lib/python%%(pyshortver)s/site-packages/torch/lib/libcaffe2_nvrtc.%s" % SHLIB_EXT
# Verify the installed library carries a real RPATH/RUNPATH that does not
# point at the CUDA stubs directory (see the no-cuda-stubs-rpath patch above).
sanity_check_commands = [
"readelf -d %s | egrep 'RPATH|RUNPATH' | grep -v stubs" % local_libcaffe2,
]

# Extra post-install test script run by the PyTorch easyblock.
tests = ['PyTorch-check-cpp-extension.py']

moduleclass = 'ai'
133 changes: 133 additions & 0 deletions easybuild/easyconfigs/p/PyTorch/PyTorch-1.12.1-foss-2022a.eb
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# EasyBuild easyconfig for PyTorch 1.12.1 (CPU-only variant, no CUDA), built
# with the foss/2022a toolchain.  Kept deliberately in sync with the
# CUDA-enabled 1.12.1 easyconfig: same sources, patches and checksums; it
# differs only in dependencies and test configuration.  The bare names
# (GITHUB_RELEASE, OS_PKG_IBVERBS_DEV) are EasyBuild template constants.
name = 'PyTorch'
version = '1.12.1'

homepage = 'https://pytorch.org/'
description = """Tensors and Dynamic neural networks in Python with strong GPU acceleration.
PyTorch is a deep learning framework that puts Python first."""

toolchain = {'name': 'foss', 'version': '2022a'}

# Release tarball fetched from GitHub releases via EasyBuild templates.
source_urls = [GITHUB_RELEASE]
sources = ['%(namelower)s-v%(version)s.tar.gz']

# Patch set identical to the CUDA variant; order must match the checksums
# list below (after the source tarball entry).
patches = [
'PyTorch-1.7.0_avoid-nan-in-test-torch.patch',
'PyTorch-1.7.0_disable-dev-shm-test.patch',
'PyTorch-1.10.0_fix-kineto-crash.patch',
'PyTorch-1.10.0_fix-test-dataloader-fixed-affinity.patch',
'PyTorch-1.10.0_fix-test-model_dump.patch',
'PyTorch-1.10.0_fix-vsx-vector-functions.patch',
'PyTorch-1.10.0_skip-nnapi-test-without-qnnpack.patch',
'PyTorch-1.11.0_fix-fsdp-fp16-test.patch',
'PyTorch-1.11.0_fix-test_utils.patch',
'PyTorch-1.11.0_increase_c10d_gloo_timeout.patch',
'PyTorch-1.11.0_increase-distributed-test-timeout.patch',
'PyTorch-1.11.0_install-vsx-vec-headers.patch',
'PyTorch-1.12.1_fix-cuda-gcc-version-check.patch',
'PyTorch-1.12.1_fix-skip-decorators.patch',
'PyTorch-1.12.1_fix-test_cpp_extensions_jit.patch',
'PyTorch-1.12.1_fix-test_wishart_log_prob.patch',
'PyTorch-1.12.1_fix-TestCudaFuser.test_unary_ops.patch',
'PyTorch-1.12.1_fix-TestTorch.test_to.patch',
'PyTorch-1.12.1_fix-use-after-free-in-tensorpipe-agent.patch',
'PyTorch-1.12.1_fix-vsx-vector-funcs.patch',
'PyTorch-1.12.1_fix-vsx-loadu.patch',
'PyTorch-1.12.1_increase-test-adadelta-tolerance.patch',
'PyTorch-1.12.1_increase-tolerance-test_ops.patch',
'PyTorch-1.12.1_no-cuda-stubs-rpath.patch',
'PyTorch-1.12.1_python-3.10-annotation-fix.patch',
'PyTorch-1.12.1_python-3.10-compat.patch',
'PyTorch-1.12.1_remove-flaky-test-in-testnn.patch',
'PyTorch-1.12.1_skip-ao-sparsity-test-without-fbgemm.patch',
'PyTorch-1.12.1_skip-failing-grad-test.patch',
'PyTorch-1.12.1_skip-test_round_robin_create_destroy.patch',
]
# SHA256 checksums: first the source tarball, then one checksum per patch in
# the same order as the patches list above.
checksums = [
'031c71073db73da732b5d01710220564ce6dd88d812ba053f0cc94296401eccb', # pytorch-v1.12.1.tar.gz
'b899aa94d9e60f11ee75a706563312ccefa9cf432756c470caa8e623991c8f18', # PyTorch-1.7.0_avoid-nan-in-test-torch.patch
'622cb1eaeadc06e13128a862d9946bcc1f1edd3d02b259c56a9aecc4d5406b8a', # PyTorch-1.7.0_disable-dev-shm-test.patch
# PyTorch-1.10.0_fix-kineto-crash.patch
'dc467333b28162149af8f675929d8c6bf219f23230bfc0d39af02ba4f6f882eb',
# PyTorch-1.10.0_fix-test-dataloader-fixed-affinity.patch
'313dca681f45ce3bc7c4557fdcdcbe0b77216d2c708fa30a2ec0e22c44876707',
# PyTorch-1.10.0_fix-test-model_dump.patch
'339148ae1a028cda6e750ac93fa38a599f66c7abe26586c9219f1a206ea14557',
# PyTorch-1.10.0_fix-vsx-vector-functions.patch
'7bef5f96cb83b2d655d2f76dd7468a171d446f0b3e06da2232ec7f886484d312',
# PyTorch-1.10.0_skip-nnapi-test-without-qnnpack.patch
'34ba476a7bcddec323bf9eca083cb4623d0f569d081aa3add3769c24f22849d2',
'bb1c4e6d6fd4b0cf57ff8b824c797331b533bb1ffc63f5db0bae3aee10c3dc13', # PyTorch-1.11.0_fix-fsdp-fp16-test.patch
'4f7e25c4e2eb7094f92607df74488c6a4a35849fabf05fcf6c3655fa3f44a861', # PyTorch-1.11.0_fix-test_utils.patch
# PyTorch-1.11.0_increase_c10d_gloo_timeout.patch
'20cd4a8663f74ab326fdb032b926bf5c7e94d9750c515ab9050927ba00cf1953',
# PyTorch-1.11.0_increase-distributed-test-timeout.patch
'087ad20163a1291773ae3457569b80523080eb3731e210946459b2333a919f3f',
'f2e6b9625733d9a471bb75e1ea20e28814cf1380b4f9089aa838ee35ddecf07d', # PyTorch-1.11.0_install-vsx-vec-headers.patch
# PyTorch-1.12.1_fix-cuda-gcc-version-check.patch
'a650f4576f06c749f244cada52ff9c02499fa8f182019129488db3845e0756ab',
'e3ca6e42b2fa592ea095939fb59ab875668a058479407db3f3684cc5c6f4146c', # PyTorch-1.12.1_fix-skip-decorators.patch
# PyTorch-1.12.1_fix-test_cpp_extensions_jit.patch
'1efc9850c431d702e9117d4766277d3f88c5c8b3870997c9974971bce7f2ab83',
# PyTorch-1.12.1_fix-test_wishart_log_prob.patch
'cf475ae6e6234b96c8d1bf917597c5176c94b3ccd940b72f2e1cd0c979580f45',
# PyTorch-1.12.1_fix-TestCudaFuser.test_unary_ops.patch
'8e6e844c6b0541e0c8115911ee1a9d548613254b36dfbdada202fd723fc26aa2',
'75f27987c3f25c501e719bd2b1c70a029ae0ee28514a97fe447516aee02b1535', # PyTorch-1.12.1_fix-TestTorch.test_to.patch
# PyTorch-1.12.1_fix-use-after-free-in-tensorpipe-agent.patch
'0bd7e88b92c4c6f0fecf01746009858ba19f2df68b10b88c41485328a531875d',
'caccbf60f62eac313896c1eaec78b08f5d0fdfcb907079087490bb13d1561aa2', # PyTorch-1.12.1_fix-vsx-vector-funcs.patch
'8bfe3c94ada1dd1f7974a1261a8b576fb7ae944050fa1c7830fca033831123b2', # PyTorch-1.12.1_fix-vsx-loadu.patch
# PyTorch-1.12.1_increase-test-adadelta-tolerance.patch
'944ed1af5ad4bbe20cbb042764a88dad1eef6cd33218617cf3d4cd90c6764695',
# PyTorch-1.12.1_increase-tolerance-test_ops.patch
'1c1fa520801e2ee5faf56a3d6dc96321e7c11664fd16bffd7c6ee437e68357fb',
'2905826ca713752b47c84e4ec8b177c90cbd91fca498ba2ba546f495c4cf70a6', # PyTorch-1.12.1_no-cuda-stubs-rpath.patch
# PyTorch-1.12.1_python-3.10-annotation-fix.patch
'11e168fd429d9e156fc79dd806b08125f3640651ad9998abd810446b2ed0c2d7',
'81402420a878b40f824778f0333fbec6504325a6a1b06a22749c4cac3eaccf67', # PyTorch-1.12.1_python-3.10-compat.patch
# PyTorch-1.12.1_remove-flaky-test-in-testnn.patch
'e81b678e354dd137c0d6d974605cdedbf672096fdbdf567c347bc2fbfc73471d',
# PyTorch-1.12.1_skip-ao-sparsity-test-without-fbgemm.patch
'edd464ec8c37b44c07a72008d732604f6837f2dd61c7810c391a86ba4945ca39',
'1c89e7e67287fe6b9a95480a4178d3653b94d0ab2fe68edf227606c8ae548fdc', # PyTorch-1.12.1_skip-failing-grad-test.patch
# PyTorch-1.12.1_skip-test_round_robin_create_destroy.patch
'1435fcac3234edc865479199673b902eb67f6a2bd046af7d731141f03594666d',
]

# InfiniBand verbs development headers must come from the OS package manager.
osdependencies = [OS_PKG_IBVERBS_DEV]

# Build-time-only dependencies.
builddependencies = [
('CMake', '3.23.1'),
('hypothesis', '6.46.7'),
]

# Runtime dependencies; no CUDA/cuDNN/magma/NCCL in this CPU-only variant.
dependencies = [
('Ninja', '1.10.2'), # Required for JIT compilation of C++ extensions
('Python', '3.10.4'),
('protobuf', '3.19.4'),
('protobuf-python', '3.19.4'),
('pybind11', '2.9.2'),
('SciPy-bundle', '2022.05'),
('PyYAML', '6.0'),
('MPFR', '4.1.0'),
('GMP', '6.2.1'),
('numactl', '2.0.14'),
('FFmpeg', '4.4.2'),
('Pillow', '9.1.1'),
('expecttest', '0.1.3'),
]

# Tests excluded unconditionally (the '' key); reasons documented inline.
excluded_tests = {
'': [
# This test seems to take too long on NVIDIA Ampere at least.
'distributed/test_distributed_spawn',
# Broken on CUDA 11.6/11.7: https://github.com/pytorch/pytorch/issues/75375
'distributions/test_constraints',
]
}

# Test-suite driver; %(python)s and %(excluded_tests)s are filled in by the
# PyTorch easyblock.  PYTHONUNBUFFERED keeps log output ordered.
runtest = 'cd test && PYTHONUNBUFFERED=1 %(python)s run_test.py --continue-through-error --verbose %(excluded_tests)s'

# Extra post-install test script run by the PyTorch easyblock.
tests = ['PyTorch-check-cpp-extension.py']

moduleclass = 'ai'
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
Fix a TestDistributions.test_wishart_log_prob failure in test_distributions
See https://github.com/pytorch/pytorch/pull/87977

Author: Alexander Grund (TU Dresden)

diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py
index 127018516e123..219eacf4790b0 100644
--- a/test/distributions/test_distributions.py
+++ b/test/distributions/test_distributions.py
@@ -2036,7 +2036,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()

self.assertEqual(batched_prob.shape, unbatched_prob.shape)
- self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
+ self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_lowrank_multivariate_normal_sample(self):
@@ -2176,7 +2176,7 @@ def test_multivariate_normal_log_prob(self):
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()

self.assertEqual(batched_prob.shape, unbatched_prob.shape)
- self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
+ self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_multivariate_normal_sample(self):
@@ -2331,7 +2331,7 @@ def test_wishart_log_prob(self):
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()

self.assertEqual(batched_prob.shape, unbatched_prob.shape)
- self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
+ self.assertEqual(batched_prob, unbatched_prob, atol=1e-3, rtol=0)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_wishart_sample(self):
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
Handle change of type annotation handling in Python 3.10 in the JIT module.
Fixes failures in e.g. test_jit.

From https://github.com/pytorch/pytorch/pull/81334 & https://github.com/pytorch/pytorch/pull/81506
Backport: Alexander Grund (TU Dresden)

diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 3c067d5c1c5..2cc3dba89b4 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -1059,11 +1059,19 @@ def _get_named_tuple_properties(obj):
if field in obj._field_defaults]
else:
defaults = []
+ # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
+ # Also, annotations from base class are not inherited so they need to be queried explicitly
+ if sys.version_info[:2] < (3, 10):
+ obj_annotations = getattr(obj, '__annotations__', {})
+ else:
+ obj_annotations = inspect.get_annotations(obj)
+ if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
+ obj_annotations = inspect.get_annotations(obj.__base__)
+
annotations = []
- has_annotations = hasattr(obj, '__annotations__')
for field in obj._fields:
- if has_annotations and field in obj.__annotations__:
- the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
+ if field in obj_annotations:
+ the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.getInferred())
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index 8175d14fe5d..aa36a4561d4 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -5,6 +5,7 @@ import collections
import textwrap
import functools
import warnings
+import sys
from typing import Dict, List, Set, Type

import torch._jit_internal as _jit_internal
@@ -134,7 +135,22 @@ def infer_concrete_type_builder(nn_module, share_types=True):
if isinstance(nn_module, (torch.nn.ParameterDict)):
concrete_type_builder.set_parameter_dict()

- class_annotations = getattr(nn_module, '__annotations__', {})
+ def get_annotations(obj):
+ if sys.version_info < (3, 10):
+ return getattr(obj, '__annotations__', {})
+ # In Python-3.10+ it is recommended to use inspect.get_annotations
+ # See https://docs.python.org/3.10/howto/annotations.html
+ # But also, in 3.10 annotations from base class are not inherited
+ # by unannotated derived one, so they must be manually extracted
+ annotations = inspect.get_annotations(obj)
+ if len(annotations) > 0:
+ return annotations
+ cls = obj if isinstance(obj, type) else type(obj)
+ if len(cls.__bases__) == 0:
+ return {}
+ return inspect.get_annotations(cls.__bases__[0])
+
+ class_annotations = get_annotations(nn_module)
if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)):
class_annotations = {}

Loading