From 8fd7160f13c8544c50c0d4e52ce4a0158a75667b Mon Sep 17 00:00:00 2001 From: Andrew Shao Date: Mon, 8 Jan 2024 09:27:30 -0800 Subject: [PATCH] Fix index when installing torch through smart build (#449) Torch changed something in their indexing when trying to install from their provided wheels. This updates the `pip install` command within `smart build` to ensure that the appropriate packages can be found. [ committed by @ashao ] [ reviewed by @ankona ] --- smartsim/_core/_cli/build.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 474d96c8a..87fbff5fb 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -226,9 +226,10 @@ def build_redis_ai( logger.info("ML Backends and RedisAI build complete!") -def check_py_torch_version(versions: Versioner, device: _TDeviceStr = "cpu") -> None: +def check_py_torch_version(versions: Versioner, device_in: _TDeviceStr = "cpu") -> None: """Check Python environment for TensorFlow installation""" + device = device_in.lower() if BuildEnv.is_macos(): if device == "gpu": raise BuildError("SmartSim does not support GPU on MacOS") @@ -260,10 +261,11 @@ def check_py_torch_version(versions: Versioner, device: _TDeviceStr = "cpu") -> "Torch version not found in python environment. " "Attempting to install via `pip`" ) + wheel_device = device if device == "cpu" else device_suffix.replace("+","") pip( "install", - "-f", - "https://download.pytorch.org/whl/torch_stable.html", + "--extra-index-url", + f"https://download.pytorch.org/whl/{wheel_device}", *(f"{package}=={version}" for package, version in torch_deps.items()), ) elif missing or conflicts: