Skip to content

Commit

Permalink
Improve CUDA support (#612)
Browse files Browse the repository at this point in the history
* Perform CUDA --device-link.

This allows to perform the final link with system linker.

* Add 'cudart' method mimicking the '--cudart' nvcc command-line option.

Try to locate the library in standard location relative to nvcc command.
If it fails, user is held responsible for specifying one in RUSTFLAGS.

* Add dummy CUDA test to cc-test.

Execution is bound to fail without card, but the failure is ignored.
It's rather a compile-n-link test. The test is suppressed if 'nvcc'
is not found on the $PATH.

* Add dummy CUDA CI test.

* Harmonize CUDA support with NVCC default --cudart static.

This can interfere with current deployments in the wild, in which
case some adjustments might be required. Most notably one might
have to add .cuda("none") to the corresponding Builder instantiation
to restore the original behaviour.
  • Loading branch information
dot-asm committed Aug 2, 2021
1 parent a11e066 commit 4a6e8b7
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 0 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,25 @@ jobs:
- run: cargo test ${{ matrix.no_run }} --manifest-path cc-test/Cargo.toml --target ${{ matrix.target }} --features parallel
- run: cargo test ${{ matrix.no_run }} --manifest-path cc-test/Cargo.toml --target ${{ matrix.target }} --release

cuda:
name: Test CUDA support
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@master
- name: Install cuda-minimal-build-11-4
shell: bash
run: |
# https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin
sudo mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600
sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /"
sudo apt-get update
sudo apt-get -y install cuda-minimal-build-11-4
- name: Test 'cudart' feature
shell: bash
run: env PATH=/usr/local/cuda/bin:$PATH cargo test --manifest-path cc-test/Cargo.toml --features test_cuda

msrv:
name: MSRV
runs-on: ${{ matrix.os }}
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ fn main() {
cc::Build::new()
// Switch to CUDA C++ library compilation using NVCC.
.cuda(true)
.cudart("static")
// Generate code for Maxwell (GTX 970, 980, 980 Ti, Titan X).
.flag("-gencode").flag("arch=compute_52,code=sm_52")
// Generate code for Maxwell (Jetson TX1).
Expand Down
2 changes: 2 additions & 0 deletions cc-test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ test = false

[build-dependencies]
cc = { path = ".." }
which = "^4.0"

[features]
parallel = ["cc/parallel"]
test_cuda = []
18 changes: 18 additions & 0 deletions cc-test/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,24 @@ fn main() {
.cpp(true)
.compile("baz");

if env::var("CARGO_FEATURE_TEST_CUDA").is_ok() {
// Detect if there is CUDA compiler and engage "cuda" feature.
let nvcc = match env::var("NVCC") {
Ok(var) => which::which(var),
Err(_) => which::which("nvcc"),
};
if nvcc.is_ok() {
cc::Build::new()
.cuda(true)
.cudart("static")
.file("src/cuda.cu")
.compile("libcuda.a");

// Communicate [cfg(feature = "cuda")] to test/all.rs.
println!("cargo:rustc-cfg=feature=\"cuda\"");
}
}

if target.contains("windows") {
cc::Build::new().file("src/windows.c").compile("windows");
}
Expand Down
5 changes: 5 additions & 0 deletions cc-test/src/cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <cuda.h>

__global__ void kernel() {}

extern "C" void cuda_kernel() { kernel<<<1, 1>>>(); }
11 changes: 11 additions & 0 deletions cc-test/tests/all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,14 @@ fn opt_linkage() {
assert_eq!(answer(), 42);
}
}

#[cfg(feature = "cuda")]
#[test]
fn cuda_here() {
extern "C" {
fn cuda_kernel();
}
unsafe {
cuda_kernel();
}
}
104 changes: 104 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pub struct Build {
cpp_link_stdlib: Option<Option<String>>,
cpp_set_stdlib: Option<String>,
cuda: bool,
cudart: Option<String>,
target: Option<String>,
host: Option<String>,
out_dir: Option<PathBuf>,
Expand Down Expand Up @@ -298,6 +299,7 @@ impl Build {
cpp_link_stdlib: None,
cpp_set_stdlib: None,
cuda: false,
cudart: None,
target: None,
host: None,
out_dir: None,
Expand Down Expand Up @@ -611,6 +613,20 @@ impl Build {
self.cuda = cuda;
if cuda {
self.cpp = true;
self.cudart = Some("static".to_string());
}
self
}

/// Link CUDA run-time.
///
/// This option mimics the `--cudart` NVCC command-line option. Just like
/// the original it accepts `{none|shared|static}`, with default being
/// `static`. The method has to be invoked after `.cuda(true)`, or not
/// at all, if the default is right for the project.
pub fn cudart(&mut self, cudart: &str) -> &mut Build {
if self.cuda {
self.cudart = Some(cudart.to_string());
}
self
}
Expand Down Expand Up @@ -996,6 +1012,56 @@ impl Build {
}
}

let cudart = match &self.cudart {
Some(opt) => opt.as_str(), // {none|shared|static}
None => "none",
};
if cudart != "none" {
if let Some(nvcc) = which(&self.get_compiler().path) {
// Try to figure out the -L search path. If it fails,
// it's on user to specify one by passing it through
// RUSTFLAGS environment variable.
let mut libtst = false;
let mut libdir = nvcc;
libdir.pop(); // remove 'nvcc'
libdir.push("..");
let target_arch = env::var("CARGO_CFG_TARGET_ARCH").unwrap();
if cfg!(target_os = "linux") {
libdir.push("targets");
libdir.push(target_arch.to_owned() + "-linux");
libdir.push("lib");
libtst = true;
} else if cfg!(target_env = "msvc") {
libdir.push("lib");
match target_arch.as_str() {
"x86_64" => {
libdir.push("x64");
libtst = true;
}
"x86" => {
libdir.push("Win32");
libtst = true;
}
_ => libtst = false,
}
}
if libtst && libdir.is_dir() {
println!(
"cargo:rustc-link-search=native={}",
libdir.to_str().unwrap()
);
}

// And now the -l flag.
let lib = match cudart {
"shared" => "cudart",
"static" => "cudart_static",
bad => panic!("unsupported cudart option: {}", bad),
};
println!("cargo:rustc-link-lib={}", lib);
}
}

Ok(())
}

Expand Down Expand Up @@ -1205,6 +1271,9 @@ impl Build {
if !msvc || !is_asm || !is_arm {
cmd.arg("-c");
}
if self.cuda && self.files.len() > 1 {
cmd.arg("--device-c");
}
cmd.arg(&obj.src);
if cfg!(target_os = "macos") {
self.fix_env_for_apple_os(&mut cmd)?;
Expand Down Expand Up @@ -1811,6 +1880,21 @@ impl Build {
self.assemble_progressive(dst, chunk)?;
}

if self.cuda {
// Link the device-side code and add it to the target library,
// so that non-CUDA linker can link the final binary.

let out_dir = self.get_out_dir()?;
let dlink = out_dir.join(lib_name.to_owned() + "_dlink.o");
let mut nvcc = self.get_compiler().to_command();
nvcc.arg("--device-link")
.arg("-o")
.arg(dlink.clone())
.arg(dst);
run(&mut nvcc, "nvcc")?;
self.assemble_progressive(dst, &[dlink])?;
}

let target = self.get_target()?;
if target.contains("msvc") {
// The Rust compiler will look for libfoo.a and foo.lib, but the
Expand Down Expand Up @@ -3100,3 +3184,23 @@ fn map_darwin_target_from_rust_to_compiler_architecture(target: &str) -> Option<
None
}
}

fn which(tool: &Path) -> Option<PathBuf> {
fn check_exe(exe: &mut PathBuf) -> bool {
let exe_ext = std::env::consts::EXE_EXTENSION;
exe.exists() || (!exe_ext.is_empty() && exe.set_extension(exe_ext) && exe.exists())
}

// If |tool| is not just one "word," assume it's an actual path...
if tool.components().count() > 1 {
let mut exe = PathBuf::from(tool);
return if check_exe(&mut exe) { Some(exe) } else { None };
}

// Loop through PATH entries searching for the |tool|.
let path_entries = env::var_os("PATH")?;
env::split_paths(&path_entries).find_map(|path_entry| {
let mut exe = path_entry.join(tool);
return if check_exe(&mut exe) { Some(exe) } else { None };
})
}

0 comments on commit 4a6e8b7

Please sign in to comment.