Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 64 additions & 8 deletions compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::Builder;
Expand Down Expand Up @@ -69,6 +71,57 @@ impl<'ll> OffloadGlobals<'ll> {
}
}

/// LLVM values describing the launch dimensions of an offloaded kernel.
///
/// Built by `OffloadKernelDims::from_operands` and destructured by
/// `gen_call_handling`.
pub(crate) struct OffloadKernelDims<'ll> {
/// Product of the three elements of `workgroup_dims`.
num_workgroups: &'ll Value,
/// Product of the three elements of `thread_dims`.
threads_per_block: &'ll Value,
/// The 3-element `i32` array value loaded from the workgroup operand.
workgroup_dims: &'ll Value,
/// The 3-element `i32` array value loaded from the thread operand.
thread_dims: &'ll Value,
}

impl<'ll> OffloadKernelDims<'ll> {
pub(crate) fn from_operands<'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
) -> Self {
let cx = builder.cx;
let arr_ty = cx.type_array(cx.type_i32(), 3);
let four = Align::from_bytes(4).unwrap();

let OperandValue::Ref(place) = workgroup_op.val else {
bug!("expected array operand by reference");
};
let workgroup_val = builder.load(arr_ty, place.llval, four);

let OperandValue::Ref(place) = thread_op.val else {
bug!("expected array operand by reference");
};
let thread_val = builder.load(arr_ty, place.llval, four);

fn mul_dim3<'ll, 'tcx>(
builder: &mut Builder<'_, 'll, 'tcx>,
arr: &'ll Value,
) -> &'ll Value {
let x = builder.extract_value(arr, 0);
let y = builder.extract_value(arr, 1);
let z = builder.extract_value(arr, 2);

let xy = builder.mul(x, y);
builder.mul(xy, z)
}

let num_workgroups = mul_dim3(builder, workgroup_val);
let threads_per_block = mul_dim3(builder, thread_val);

OffloadKernelDims {
workgroup_dims: workgroup_val,
thread_dims: thread_val,
num_workgroups,
threads_per_block,
}
}
}

// ; Function Attrs: nounwind
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
Expand Down Expand Up @@ -204,12 +257,12 @@ impl KernelArgsTy {
num_args: u64,
memtransfer_types: &'ll Value,
geps: [&'ll Value; 3],
workgroup_dims: &'ll Value,
thread_dims: &'ll Value,
) -> [(Align, &'ll Value); 13] {
let four = Align::from_bytes(4).expect("4 Byte alignment should work");
let eight = Align::EIGHT;

let ti32 = cx.type_i32();
let ci32_0 = cx.get_const_i32(0);
[
(four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
(four, cx.get_const_i32(num_args)),
Expand All @@ -222,8 +275,8 @@ impl KernelArgsTy {
(eight, cx.const_null(cx.type_ptr())), // dbg
(eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
(eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
(four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
(four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
(four, workgroup_dims),
(four, thread_dims),
(four, cx.get_const_i32(0)),
]
}
Expand Down Expand Up @@ -413,10 +466,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
types: &[&Type],
metadata: &[OffloadMetadata],
offload_globals: &OffloadGlobals<'ll>,
offload_dims: &OffloadKernelDims<'ll>,
) {
let cx = builder.cx;
let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
offload_data;
let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
offload_dims;

let tgt_decl = offload_globals.launcher_fn;
let tgt_target_kernel_ty = offload_globals.launcher_ty;
Expand Down Expand Up @@ -554,7 +610,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
num_args,
s_ident_t,
);
let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
let values =
KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);

// Step 3)
// Here we fill the KernelArgsTy, see the documentation above
Expand All @@ -567,9 +624,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
s_ident_t,
// FIXME(offload) give users a way to select which GPU to use.
cx.get_const_i64(u64::MAX), // MAX == -1.
// FIXME(offload): Don't hardcode the numbers of threads in the future.
cx.get_const_i32(2097152),
cx.get_const_i32(256),
num_workgroups,
threads_per_block,
region_id,
a5,
];
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use tracing::debug;
use crate::abi::FnAbiLlvmExt;
use crate::builder::Builder;
use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call};
use crate::builder::gpu_offload::{gen_call_handling, gen_define_handling};
use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling};
use crate::context::CodegenCx;
use crate::declare::declare_raw_fn;
use crate::errors::{
Expand Down Expand Up @@ -1384,7 +1384,8 @@ fn codegen_offload<'ll, 'tcx>(
}
};

let args = get_args_from_tuple(bx, args[1], fn_target);
let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]);
let args = get_args_from_tuple(bx, args[3], fn_target);
let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE);

let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder();
Expand All @@ -1403,7 +1404,7 @@ fn codegen_offload<'ll, 'tcx>(
}
};
let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals);
gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims);
}

fn get_args_from_tuple<'ll, 'tcx>(
Expand Down
14 changes: 12 additions & 2 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_abi::ExternAbi;
use rustc_errors::DiagMessage;
use rustc_hir::{self as hir, LangItem};
use rustc_middle::traits::{ObligationCause, ObligationCauseCode};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_span::def_id::LocalDefId;
use rustc_span::{Span, Symbol, sym};

Expand Down Expand Up @@ -315,7 +315,17 @@ pub(crate) fn check_intrinsic_type(
let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity();
(0, 0, vec![type_id, type_id], tcx.types.bool)
}
sym::offload => (3, 0, vec![param(0), param(1)], param(2)),
sym::offload => (
3,
0,
vec![
param(0),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)),
param(1),
],
param(2),
),
sym::offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::arith_offset => (
1,
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,12 @@ impl<'tcx> Ty<'tcx> {
*self.kind() == Str
}

/// Returns `true` if this type is an immutable reference to `str` (i.e. `&str`).
///
/// The lifetime of the reference is not inspected; `&mut str` and every other
/// type yield `false`.
#[inline]
pub fn is_imm_ref_str(self) -> bool {
    match self.kind() {
        ty::Ref(_, inner, hir::Mutability::Not) => inner.is_str(),
        _ => false,
    }
}

#[inline]
pub fn is_param(self, index: u32) -> bool {
match self.kind() {
Expand Down
9 changes: 7 additions & 2 deletions compiler/rustc_mir_build/src/builder/matches/buckets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

(
TestKind::Eq { value: test_val, .. },
TestKind::StringEq { value: test_val, .. },
TestableCase::Constant { value: case_val, kind: PatConstKind::String },
)
| (
TestKind::ScalarEq { value: test_val, .. },
TestableCase::Constant {
value: case_val,
kind: PatConstKind::Float | PatConstKind::Other,
Expand Down Expand Up @@ -347,7 +351,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
| TestKind::If
| TestKind::SliceLen { .. }
| TestKind::Range { .. }
| TestKind::Eq { .. }
| TestKind::StringEq { .. }
| TestKind::ScalarEq { .. }
| TestKind::Deref { .. },
_,
) => {
Expand Down
15 changes: 14 additions & 1 deletion compiler/rustc_mir_build/src/builder/matches/match_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

use rustc_abi::FieldIdx;
use rustc_middle::mir::*;
use rustc_middle::span_bug;
use rustc_middle::thir::*;
use rustc_middle::ty::{self, Ty, TypeVisitableExt};

Expand Down Expand Up @@ -173,9 +174,21 @@ impl<'tcx> MatchPairTree<'tcx> {
PatConstKind::IntOrChar
} else if pat_ty.is_floating_point() {
PatConstKind::Float
} else if pat_ty.is_str() {
// Deref-patterns can cause string-literal patterns to have
// type `str` instead of the usual `&str`.
if !cx.tcx.features().deref_patterns() {
span_bug!(
pattern.span,
"const pattern has type `str` but deref_patterns is not enabled"
);
}
PatConstKind::String
} else if pat_ty.is_imm_ref_str() {
PatConstKind::String
} else {
// FIXME(Zalathar): This still covers several different
// categories (e.g. raw pointer, string, pattern-type)
// categories (e.g. raw pointer, pattern-type)
// which could be split out into their own kinds.
PatConstKind::Other
};
Expand Down
23 changes: 15 additions & 8 deletions compiler/rustc_mir_build/src/builder/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1290,9 +1290,10 @@ enum PatConstKind {
/// These types don't support `SwitchInt` and require an equality test,
/// but can also interact with range pattern tests.
Float,
/// Constant string values, tested via string equality.
String,
/// Any other constant-pattern is usually tested via some kind of equality
/// check. Types that might be encountered here include:
/// - `&str`
/// - raw pointers derived from integer values
/// - pattern types, e.g. `pattern_type!(u32 is 1..)`
Other,
Expand Down Expand Up @@ -1368,14 +1369,20 @@ enum TestKind<'tcx> {
/// Test whether a `bool` is `true` or `false`.
If,

/// Test for equality with value, possibly after an unsizing coercion to
/// `cast_ty`,
Eq {
/// Tests the place against a string constant using string equality.
StringEq {
/// Constant `&str` value to test against.
value: ty::Value<'tcx>,
// Integer types are handled by `SwitchInt`, and constants with ADT
// types and `&[T]` types are converted back into patterns, so this can
// only be `&str` or floats.
cast_ty: Ty<'tcx>,
/// Type of the corresponding pattern node. Usually `&str`, but could
/// be `str` for patterns like `deref!("..."): String`.
pat_ty: Ty<'tcx>,
},

/// Tests the place against a constant using scalar equality.
ScalarEq {
value: ty::Value<'tcx>,
/// Type of the corresponding pattern node.
pat_ty: Ty<'tcx>,
},

/// Test whether the value falls within an inclusive or exclusive range.
Expand Down
Loading
Loading