Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions one_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def torch_version():
torver = [line for line in torch_version_file if '__version__' in line][0].split('__version__ = ')[1].strip("'")
else:
from torch import __version__ as torver

return torver


Expand Down Expand Up @@ -203,7 +204,7 @@ def install_webui():

# Find the proper Pytorch installation command
install_git = "conda install -y -k ninja git"
install_pytorch = "python -m pip install torch torchvision torchaudio"
install_pytorch = "python -m pip install torch==2.1.* torchvision==0.16.* torchaudio==2.1.* "

use_cuda118 = "N"
if any((is_windows(), is_linux())) and selected_gpu == "NVIDIA":
Expand All @@ -219,20 +220,20 @@ def install_webui():

if use_cuda118 == 'Y':
print("CUDA: 11.8")
install_pytorch += "--index-url https://download.pytorch.org/whl/cu118"
else:
print("CUDA: 12.1")

install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}"
install_pytorch += "--index-url https://download.pytorch.org/whl/cu121"
elif not is_macos() and selected_gpu == "AMD":
if is_linux():
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6"
install_pytorch += "--index-url https://download.pytorch.org/whl/rocm5.6"
else:
print("AMD GPUs are only supported on Linux. Exiting...")
sys.exit(1)
elif is_linux() and selected_gpu in ["APPLE", "NONE"]:
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
install_pytorch += "--index-url https://download.pytorch.org/whl/cpu"
elif selected_gpu == "INTEL":
install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 intel_extension_for_pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
install_pytorch += "intel_extension_for_pytorch==2.1.* --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"

# Install Git and then Pytorch
print_big_message("Installing PyTorch.")
Expand Down