From a8f4e08c345fc6bc0d869a0b0dad686705218ca5 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 27 Jun 2024 09:43:59 -0700 Subject: [PATCH] remove torch dependency (#449) * remove torch dependency * push * push * push * push * push * push * Update pre_build_script.sh * Update version.txt * push --- packaging/pre_build_script.sh | 8 ++++++++ setup.py | 10 ---------- torchao/prototype/mx_formats/custom_cast.py | 2 +- version.txt | 2 +- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh index 4bbd09694a..1690c62750 100644 --- a/packaging/pre_build_script.sh +++ b/packaging/pre_build_script.sh @@ -9,4 +9,12 @@ set -eux echo "This script is run before building torchao binaries" +python -m pip install --upgrade pip +if [ -z "$PYTORCH_VERSION" ]; then + PYTORCH_DEP="torch" +else + PYTORCH_DEP="torch==$PYTORCH_VERSION" +fi +pip install $PYTORCH_DEP + pip install setuptools wheel twine auditwheel diff --git a/setup.py b/setup.py index a6a2da8a75..fc85dcea9c 100644 --- a/setup.py +++ b/setup.py @@ -110,15 +110,6 @@ def get_extensions(): return ext_modules -# Mimic code from torchvision https://github.com/pytorch/vision/blob/143d078b28f00471156a4e562dd3836370acc9ee/setup.py#L58 -pytorch_dep = "torch" -if os.getenv("PYTORCH_VERSION"): - pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") - -requirements = [ - pytorch_dep, -] - setup( name=package_name, version=version+version_suffix, @@ -128,7 +119,6 @@ def get_extensions(): "torchao.kernel.configs": ["*.pkl"], }, ext_modules=get_extensions() if use_cpp != "0" else None, - install_requires=requirements, extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", long_description=open("README.md").read(), diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 00082acad9..381f91c4a6 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -15,7 +15,7 @@ # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert # at the callsite prevents usage of this on unsupported versions. -if TORCH_VERSION_AFTER_2_4: +if TORCH_VERSION_AFTER_2_4 and has_triton(): from torch._inductor.runtime.triton_helpers import libdevice from torchao.prototype.mx_formats.constants import ( diff --git a/version.txt b/version.txt index 0d91a54c7d..9e11b32fca 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.0 +0.3.1