Skip to content

Commit

Permalink
Rollup merge of rust-lang#94703 - kjetilkjeka:nvptx-kernel-args-abi2,…
Browse files Browse the repository at this point in the history
… r=nagisa

Fix codegen bug in "ptx-kernel" abi related to arg passing

I found a codegen bug in the nvptx abi related to that args are passed as ptrs ([see comment](rust-lang#38788 (comment))), this is not as specified in the [ptx-interoperability doc](https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/) or how C/C++ does it. It will also almost always fail in practice since device/host uses different memory spaces for most hardware.

This PR fixes the bug and add tests for passing structs to ptx kernels.

I observed that all nvptx assembly tests had been marked as [ignore a long time ago](rust-lang#59752 (comment)). I'm not sure if the new one should be marked as ignore, it passed on my computer but it might fail if ptx-linker is missing on the server? I guess this is outside scope for this PR and should be looked at in a different issue/PR.

I only fixed the nvptx64-nvidia-cuda target and not the potential code paths for the non-existing 32bit target. Even though 32bit nvptx is not a supported target there are still some code under the hood supporting codegen for 32 bit ptx. I was advised to create an MCP to find out if this code should be removed or updated.

Perhaps ``@RDambrosio016`` would have interest in taking a quick look at this.
  • Loading branch information
GuillaumeGomez authored Apr 26, 2022
2 parents eaf8beb + 5bf5acc commit fe49981
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 9 deletions.
16 changes: 16 additions & 0 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2592,6 +2592,22 @@ where

pointee_info
}

fn is_adt(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Adt(..))
}

fn is_never(this: TyAndLayout<'tcx>) -> bool {
this.ty.kind() == &ty::Never
}

fn is_tuple(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Tuple(..))
}

fn is_unit(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0)
}
}

impl<'tcx> ty::Instance<'tcx> {
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/ty/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ impl<T> List<T> {
static EMPTY_SLICE: InOrder<usize, MaxAlign> = InOrder(0, MaxAlign);
unsafe { &*(&EMPTY_SLICE as *const _ as *const List<T>) }
}

pub fn len(&self) -> usize {
self.len
}
}

impl<T: Copy> List<T> {
Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_target/src/abi/call/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> {
"sparc" => sparc::compute_abi_info(cx, self),
"sparc64" => sparc64::compute_abi_info(cx, self),
"nvptx" => nvptx::compute_abi_info(self),
"nvptx64" => nvptx64::compute_abi_info(self),
"nvptx64" => {
if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel {
nvptx64::compute_ptx_kernel_abi_info(cx, self)
} else {
nvptx64::compute_abi_info(self)
}
}
"hexagon" => hexagon::compute_abi_info(self),
"riscv32" | "riscv64" => riscv::compute_abi_info(cx, self),
"wasm32" | "wasm64" => {
Expand Down
47 changes: 39 additions & 8 deletions compiler/rustc_target/src/abi/call/nvptx64.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
// Reference: PTX Writer's Guide to Interoperability
// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability

use crate::abi::call::{ArgAbi, FnAbi};
use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
use crate::abi::{HasDataLayout, TyAbiInterface};

fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 {
ret.make_indirect();
} else {
ret.extend_integer_width_to(64);
}
}

fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 {
arg.make_indirect();
} else {
arg.extend_integer_width_to(64);
}
}

fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>)
where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
let align_bytes = arg.layout.align.abi.bytes();

let unit = match align_bytes {
1 => Reg::i8(),
2 => Reg::i16(),
4 => Reg::i32(),
8 => Reg::i64(),
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) });
}
}

Expand All @@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
classify_arg(arg);
}
}

pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() {
panic!("Kernels should not return anything other than () or !");
}

for arg in &mut fn_abi.args {
if arg.is_ignore() {
continue;
}
classify_arg_kernel(cx, arg);
}
}
32 changes: 32 additions & 0 deletions compiler/rustc_target/src/abi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,10 @@ pub trait TyAbiInterface<'a, C>: Sized {
cx: &C,
offset: Size,
) -> Option<PointeeInfo>;
fn is_adt(this: TyAndLayout<'a, Self>) -> bool;
fn is_never(this: TyAndLayout<'a, Self>) -> bool;
fn is_tuple(this: TyAndLayout<'a, Self>) -> bool;
fn is_unit(this: TyAndLayout<'a, Self>) -> bool;
}

impl<'a, Ty> TyAndLayout<'a, Ty> {
Expand Down Expand Up @@ -1396,6 +1400,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
_ => false,
}
}

pub fn is_adt<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_adt(self)
}

pub fn is_never<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_never(self)
}

pub fn is_tuple<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_tuple(self)
}

pub fn is_unit<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_unit(self)
}
}

impl<'a, Ty> TyAndLayout<'a, Ty> {
Expand Down
Loading

0 comments on commit fe49981

Please sign in to comment.