diff --git a/.github/workflows/run_tests/action.yml b/.github/workflows/run_tests/action.yml index 935a5a7e099..64d009f7a2d 100644 --- a/.github/workflows/run_tests/action.yml +++ b/.github/workflows/run_tests/action.yml @@ -16,6 +16,9 @@ inputs: runs: using: "composite" steps: + - name: Setup MSVC for torch.compile + if: runner.os == 'Windows' + uses: ilammy/msvc-dev-cmd@v1 - name: Install dependencies working-directory: python shell: bash diff --git a/python/pyproject.toml b/python/pyproject.toml index 815e28c6666..4b6814aa070 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -64,7 +64,7 @@ tests = [ ] dev = ["ruff==0.4.1", "pyright"] benchmarks = ["pytest-benchmark"] -torch = ["torch"] +torch = ["torch>=2.0"] geo = [ "geoarrow-rust-core", "geoarrow-rust-io", @@ -115,9 +115,13 @@ filterwarnings = [ 'ignore:.*datetime\.datetime\.utcnow\(\) is deprecated.*:DeprecationWarning', # Pandas 2.2 on Python 2.12 'ignore:.*datetime\.datetime\.utcfromtimestamp\(\) is deprecated.*:DeprecationWarning', - # Pytorch 2.2 on Python 2.12 + # Pytorch 2.2 on Python 3.12 'ignore:.*is deprecated and will be removed in Python 3\.14.*:DeprecationWarning', 'ignore:.*The distutils package is deprecated.*:DeprecationWarning', + # Pytorch inductor uses deprecated load_module() in its code cache + 'ignore:.*the load_module\(\) method is deprecated.*:DeprecationWarning', + # Pytorch uses deprecated jit.script_method internally (torch/utils/mkldnn.py) + 'ignore:.*torch\.jit\.script_method.*is deprecated.*:DeprecationWarning', # TensorFlow/Keras import can emit NumPy deprecation FutureWarnings in some environments. # Keep FutureWarnings as errors generally, but ignore this known-noisy import-time warning. 'ignore:.*np\.object.*:FutureWarning', diff --git a/python/python/lance/torch/distance.py b/python/python/lance/torch/distance.py index 3c5280f4b11..81201027c87 100644 --- a/python/python/lance/torch/distance.py +++ b/python/python/lance/torch/distance.py @@ -1,19 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright The Lance Authors -import warnings from typing import Optional, Tuple -# Suppress torch.jit.script deprecation warning in PyTorch 2.10+ -# TODO: migrate to torch.compile when feasible -warnings.filterwarnings( - "ignore", - message=r".*torch\.jit\.script.*deprecated.*", - category=DeprecationWarning, -) - -from lance.dependencies import torch # noqa: E402 -from lance.log import LOGGER # noqa: E402 +from lance.dependencies import torch +from lance.log import LOGGER __all__ = [ "pairwise_cosine", @@ -24,7 +15,7 @@ ] -@torch.jit.script +@torch.compile def _pairwise_cosine( x: torch.Tensor, y: torch.Tensor, y2: torch.Tensor ) -> torch.Tensor: @@ -57,7 +48,7 @@ def pairwise_cosine( return _pairwise_cosine(x, y, y2) -@torch.jit.script +@torch.compile def _cosine_distance( vectors: torch.Tensor, centroids: torch.Tensor, split_size: int ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -122,7 +113,7 @@ def cosine_distance( raise RuntimeError("Cosine distance out of memory") -@torch.jit.script +@torch.compile def argmin_l2(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x = x.reshape(1, x.shape[0], -1) y = y.reshape(1, y.shape[0], -1) @@ -133,7 +124,7 @@ def argmin_l2(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten return min_dists.pow(2), idx -@torch.jit.script +@torch.compile def pairwise_l2( x: torch.Tensor, y: torch.Tensor, y2: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -178,7 +169,7 @@ def pairwise_l2( return dists.type(origin_dtype) -@torch.jit.script +@torch.compile def _l2_distance( x: torch.Tensor, y: torch.Tensor, @@ -245,7 +236,7 @@ def l2_distance( raise RuntimeError("L2 distance out of memory") -@torch.jit.script +@torch.compile def dot_distance(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Pair-wise dot distance between two 2-D Tensors.