# Burn Torch Backend
This crate provides a Torch backend for Burn utilizing the
[tch-rs](https://github.com/LaurentMazare/tch-rs) crate, which offers a Rust
interface to the PyTorch C++ API.

The backend supports CPU (multithreaded), CUDA (multiple GPUs), and MPS devices
(macOS).
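As a quick illustration, the sketch below shows how each device type is
typically selected with this backend. The `LibTorch` and `LibTorchDevice` names
assume a recent `burn-tch` release and are not taken from this README, so check
the API docs for your version.

```rust
// Minimal sketch (assumed API): selecting a device for the Torch backend.
use burn_tch::{LibTorch, LibTorchDevice};

// The backend type, parameterized by its default float element type.
type Backend = LibTorch<f32>;

fn main() {
    let _cpu = LibTorchDevice::Cpu;      // multithreaded CPU
    let _cuda = LibTorchDevice::Cuda(0); // CUDA GPU selected by index
    let _mps = LibTorchDevice::Mps;      // Metal Performance Shaders (macOS)
}
```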
## Installation

`tch-rs` requires the C++ PyTorch library (LibTorch) to be available on your
system. By default, the CPU distribution of LibTorch v2.2.0 is installed, as
required by `tch-rs`.
### CUDA

To install the latest compatible CUDA distribution, set the `TORCH_CUDA_VERSION`
environment variable before the `tch-rs` dependency is retrieved with `cargo`.

```shell
export TORCH_CUDA_VERSION=cu121
```

On Windows:

```powershell
$Env:TORCH_CUDA_VERSION = "cu121"
```
For example, running the validation sample for the first time could be done
with the following commands:

```shell
export TORCH_CUDA_VERSION=cu121
cargo run --bin cuda --release
```
**Important:** make sure your driver version is compatible with the selected
CUDA version. A CUDA Toolkit installation is not required, since LibTorch ships
with the appropriate CUDA runtimes. Having the latest driver version is
recommended, but you can always take a look at the toolkit driver version table
or the minimum required driver version (limited feature set; it might not work
with all operations).
Once your installation is complete, you should be able to build/run your
project. You can also validate your installation by running the appropriate
`cpu`, `cuda` or `mps` sample as below.

```shell
cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release
```
**Note:** no MPS distribution is available for automatic download at this time;
please check out the manual instructions below.
## Manual Download

To install `tch-rs` with a different LibTorch distribution, you will have to
manually download the desired LibTorch distribution. The instructions are
detailed in the sections below for each platform.
| Compute Platform | CPU     | GPU | Linux | macOS | Windows | Android | iOS | WASM |
| ---------------- | ------- | --- | ----- | ----- | ------- | ------- | --- | ---- |
| CPU              | Yes     | No  | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| CUDA             | Yes [1] | Yes | Yes   | No    | Yes     | No      | No  | No   |
| Metal (MPS)      | No      | Yes | No    | Yes   | No      | No      | No  | No   |
| Vulkan           | Yes     | Yes | Yes   | Yes   | Yes     | Yes     | No  | No   |

[1] The LibTorch CUDA distribution also comes with CPU support.
### CPU

#### 🐧 Linux

First, download the LibTorch CPU distribution.

```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcpu.zip
unzip libtorch.zip
```

Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH`
environment variables before building `burn-tch` or a crate which depends on it.

```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```
#### 🍎 Mac

First, download the LibTorch CPU distribution.

```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
unzip libtorch.zip
```

Then, point to that installation using the `LIBTORCH` and `DYLD_LIBRARY_PATH`
environment variables before building `burn-tch` or a crate which depends on it.

```shell
export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
```
#### 🪟 Windows

First, download the LibTorch CPU distribution.

```powershell
wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.2.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```

Then, set the `LIBTORCH` environment variable and append the library to your
path using the PowerShell commands below before building `burn-tch` or a crate
which depends on it.

```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
### CUDA

LibTorch 2.2.0 currently includes binary distributions with CUDA 11.8 or 12.1
runtimes. The manual installation instructions are detailed below.
#### CUDA 11.8

##### 🐧 Linux

First, download the LibTorch CUDA 11.8 distribution.

```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu118.zip
unzip libtorch.zip
```

Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH`
environment variables before building `burn-tch` or a crate which depends on it.

```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```

**Note:** make sure your CUDA installation is in your `PATH` and
`LD_LIBRARY_PATH`.
##### 🪟 Windows

First, download the LibTorch CUDA 11.8 distribution.

```powershell
wget https://download.pytorch.org/libtorch/cu118/libtorch-win-shared-with-deps-2.2.0%2Bcu118.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```

Then, set the `LIBTORCH` environment variable and append the library to your
path using the PowerShell commands below before building `burn-tch` or a crate
which depends on it.

```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
#### CUDA 12.1

##### 🐧 Linux

First, download the LibTorch CUDA 12.1 distribution.

```shell
wget -O libtorch.zip https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip
unzip libtorch.zip
```

Then, point to that installation using the `LIBTORCH` and `LD_LIBRARY_PATH`
environment variables before building `burn-tch` or a crate which depends on it.

```shell
export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
```

**Note:** make sure your CUDA installation is in your `PATH` and
`LD_LIBRARY_PATH`.
##### 🪟 Windows

First, download the LibTorch CUDA 12.1 distribution.

```powershell
wget https://download.pytorch.org/libtorch/cu121/libtorch-win-shared-with-deps-2.2.0%2Bcu121.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip
```

Then, set the `LIBTORCH` environment variable and append the library to your
path using the PowerShell commands below before building `burn-tch` or a crate
which depends on it.

```powershell
$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/"
```
### MPS

There is no official LibTorch distribution with MPS support at this time, so
the easiest alternative is to use a PyTorch installation. This requires a
Python installation.

**Note:** MPS acceleration is available on macOS 12.3+.

```shell
pip install torch==2.2.0
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH
```
## Example Usage

For a simple example, check out any of the test programs in `src/bin/`. Each
program sets the device to use and performs a simple element-wise addition.
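A minimal sketch of such a program is shown below. It is not copied from
`src/bin/`, and the `LibTorch`, `LibTorchDevice`, and `Tensor::from_floats`
names assume a recent `burn` release, so treat it as illustrative rather than
exact:

```rust
// Illustrative sketch (assumed API): pick a device, then add two tensors
// element-wise, mirroring what the validation programs do.
use burn::tensor::Tensor;
use burn_tch::{LibTorch, LibTorchDevice};

type Backend = LibTorch<f32>;

fn main() {
    // Swap in LibTorchDevice::Cuda(0) or LibTorchDevice::Mps as appropriate.
    let device = LibTorchDevice::Cpu;

    let lhs = Tensor::<Backend, 1>::from_floats([1.0, 2.0, 3.0], &device);
    let rhs = Tensor::<Backend, 1>::from_floats([4.0, 5.0, 6.0], &device);

    // Element-wise addition runs on the selected device.
    println!("{}", lhs + rhs);
}
```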
For a more complete example using the `tch` backend, take a look at the
Burn mnist example.