diff --git a/smartsim/_core/_cli/build.py b/smartsim/_core/_cli/build.py index 474d96c8a..87fbff5fb 100644 --- a/smartsim/_core/_cli/build.py +++ b/smartsim/_core/_cli/build.py @@ -226,9 +226,10 @@ def build_redis_ai( logger.info("ML Backends and RedisAI build complete!") -def check_py_torch_version(versions: Versioner, device: _TDeviceStr = "cpu") -> None: +def check_py_torch_version(versions: Versioner, device_in: _TDeviceStr = "cpu") -> None: """Check Python environment for TensorFlow installation""" + device = device_in.lower() if BuildEnv.is_macos(): if device == "gpu": raise BuildError("SmartSim does not support GPU on MacOS") @@ -260,10 +261,11 @@ def check_py_torch_version(versions: Versioner, device: _TDeviceStr = "cpu") -> "Torch version not found in python environment. " "Attempting to install via `pip`" ) + wheel_device = device if device == "cpu" else device_suffix.replace("+","") pip( "install", - "-f", - "https://download.pytorch.org/whl/torch_stable.html", + "--extra-index-url", + f"https://download.pytorch.org/whl/{wheel_device}", *(f"{package}=={version}" for package, version in torch_deps.items()), ) elif missing or conflicts: