diff --git a/setup.py b/setup.py index 68baf9c939bdb..eabcf063013c7 100644 --- a/setup.py +++ b/setup.py @@ -168,5 +168,9 @@ def append_nvcc_threads(nvcc_extra_args): ], ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension} if ext_modules else {}, - python_requires=">=3.7" + python_requires=">=3.7", + install_requires=[ + "torch", + "einops", + ], )