diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 5c4aa7da51431..fba92f996aa6b 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -2,7 +2,9 @@ use std::ffi::CString; use llvm::Linkage::*; use rustc_abi::Align; +use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; +use rustc_middle::bug; use rustc_middle::ty::offload_meta::OffloadMetadata; use crate::builder::Builder; @@ -69,6 +71,57 @@ impl<'ll> OffloadGlobals<'ll> { } } +pub(crate) struct OffloadKernelDims<'ll> { + num_workgroups: &'ll Value, + threads_per_block: &'ll Value, + workgroup_dims: &'ll Value, + thread_dims: &'ll Value, +} + +impl<'ll> OffloadKernelDims<'ll> { + pub(crate) fn from_operands<'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, + workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>, + thread_op: &OperandRef<'tcx, &'ll llvm::Value>, + ) -> Self { + let cx = builder.cx; + let arr_ty = cx.type_array(cx.type_i32(), 3); + let four = Align::from_bytes(4).unwrap(); + + let OperandValue::Ref(place) = workgroup_op.val else { + bug!("expected array operand by reference"); + }; + let workgroup_val = builder.load(arr_ty, place.llval, four); + + let OperandValue::Ref(place) = thread_op.val else { + bug!("expected array operand by reference"); + }; + let thread_val = builder.load(arr_ty, place.llval, four); + + fn mul_dim3<'ll, 'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, + arr: &'ll Value, + ) -> &'ll Value { + let x = builder.extract_value(arr, 0); + let y = builder.extract_value(arr, 1); + let z = builder.extract_value(arr, 2); + + let xy = builder.mul(x, y); + builder.mul(xy, z) + } + + let num_workgroups = mul_dim3(builder, workgroup_val); + let threads_per_block = mul_dim3(builder, thread_val); + + OffloadKernelDims { + workgroup_dims: workgroup_val, + thread_dims: thread_val, + num_workgroups, + threads_per_block, + } + } +} + // ; Function Attrs: nounwind // declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2 fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) { @@ -204,12 +257,12 @@ impl KernelArgsTy { num_args: u64, memtransfer_types: &'ll Value, geps: [&'ll Value; 3], + workgroup_dims: &'ll Value, + thread_dims: &'ll Value, ) -> [(Align, &'ll Value); 13] { let four = Align::from_bytes(4).expect("4 Byte alignment should work"); let eight = Align::EIGHT; - let ti32 = cx.type_i32(); - let ci32_0 = cx.get_const_i32(0); [ (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)), (four, cx.get_const_i32(num_args)), @@ -222,8 +275,8 @@ impl KernelArgsTy { (eight, cx.const_null(cx.type_ptr())), // dbg (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)), (eight, cx.get_const_i64(KernelArgsTy::FLAGS)), - (four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])), - (four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])), + (four, workgroup_dims), + (four, thread_dims), (four, cx.get_const_i32(0)), ] } @@ -413,10 +466,13 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( types: &[&Type], metadata: &[OffloadMetadata], offload_globals: &OffloadGlobals<'ll>, + offload_dims: &OffloadKernelDims<'ll>, ) { let cx = builder.cx; let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } = offload_data; + let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } = + offload_dims; let tgt_decl = offload_globals.launcher_fn; let tgt_target_kernel_ty = offload_globals.launcher_ty; @@ -554,7 +610,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( num_args, s_ident_t, ); - let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps); + let values = + KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims); // Step 3) // Here we fill the KernelArgsTy, see the documentation above @@ -567,9 +624,8 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( s_ident_t, // FIXME(offload) give users a way to select which GPU to use. cx.get_const_i64(u64::MAX), // MAX == -1. - // FIXME(offload): Don't hardcode the numbers of threads in the future. - cx.get_const_i32(2097152), - cx.get_const_i32(256), + num_workgroups, + threads_per_block, region_id, a5, ]; diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index b4057eea735ea..481f75f337d63 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -30,7 +30,7 @@ use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; -use crate::builder::gpu_offload::{gen_call_handling, gen_define_handling}; +use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling}; use crate::context::CodegenCx; use crate::declare::declare_raw_fn; use crate::errors::{ @@ -1384,7 +1384,8 @@ fn codegen_offload<'ll, 'tcx>( } }; - let args = get_args_from_tuple(bx, args[1], fn_target); + let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]); + let args = get_args_from_tuple(bx, args[3], fn_target); let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE); let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder(); @@ -1403,7 +1404,7 @@ fn codegen_offload<'ll, 'tcx>( } }; let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals); - gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals); + gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims); } fn get_args_from_tuple<'ll, 'tcx>( diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 4e8333f678b66..9eaf5319cb040 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -4,7 +4,7 @@ use rustc_abi::ExternAbi; use rustc_errors::DiagMessage; use rustc_hir::{self as hir, LangItem}; use rustc_middle::traits::{ObligationCause, ObligationCauseCode}; -use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_middle::ty::{self, Const, Ty, TyCtxt}; use rustc_span::def_id::LocalDefId; use rustc_span::{Span, Symbol, sym}; @@ -315,7 +315,17 @@ pub(crate) fn check_intrinsic_type( let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity(); (0, 0, vec![type_id, type_id], tcx.types.bool) } - sym::offload => (3, 0, vec![param(0), param(1)], param(2)), + sym::offload => ( + 3, + 0, + vec![ + param(0), + Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)), + Ty::new_array_with_const_len(tcx, tcx.types.u32, Const::from_target_usize(tcx, 3)), + param(1), + ], + param(2), + ), sym::offset => (2, 0, vec![param(0), param(1)], param(0)), sym::arith_offset => ( 1, diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index c3834607e9236..c282f2211f650 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -1218,6 +1218,12 @@ impl<'tcx> Ty<'tcx> { *self.kind() == Str } + /// Returns true if this type is `&str`. The reference's lifetime is ignored. + #[inline] + pub fn is_imm_ref_str(self) -> bool { + matches!(self.kind(), ty::Ref(_, inner, hir::Mutability::Not) if inner.is_str()) + } + #[inline] pub fn is_param(self, index: u32) -> bool { match self.kind() { diff --git a/compiler/rustc_mir_build/src/builder/matches/buckets.rs b/compiler/rustc_mir_build/src/builder/matches/buckets.rs index fce35aa9ef306..0d2e9bf87585d 100644 --- a/compiler/rustc_mir_build/src/builder/matches/buckets.rs +++ b/compiler/rustc_mir_build/src/builder/matches/buckets.rs @@ -314,7 +314,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } ( - TestKind::Eq { value: test_val, .. }, + TestKind::StringEq { value: test_val, .. }, + TestableCase::Constant { value: case_val, kind: PatConstKind::String }, + ) + | ( + TestKind::ScalarEq { value: test_val, .. }, TestableCase::Constant { value: case_val, kind: PatConstKind::Float | PatConstKind::Other, @@ -347,7 +351,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | TestKind::If | TestKind::SliceLen { .. } | TestKind::Range { .. } - | TestKind::Eq { .. } + | TestKind::StringEq { .. } + | TestKind::ScalarEq { .. } | TestKind::Deref { .. }, _, ) => { diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 386ca61a61241..f0114c2193c3e 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use rustc_abi::FieldIdx; use rustc_middle::mir::*; +use rustc_middle::span_bug; use rustc_middle::thir::*; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; @@ -173,9 +174,21 @@ impl<'tcx> MatchPairTree<'tcx> { PatConstKind::IntOrChar } else if pat_ty.is_floating_point() { PatConstKind::Float + } else if pat_ty.is_str() { + // Deref-patterns can cause string-literal patterns to have + // type `str` instead of the usual `&str`. + if !cx.tcx.features().deref_patterns() { + span_bug!( + pattern.span, + "const pattern has type `str` but deref_patterns is not enabled" + ); + } + PatConstKind::String + } else if pat_ty.is_imm_ref_str() { + PatConstKind::String } else { // FIXME(Zalathar): This still covers several different - // categories (e.g. raw pointer, string, pattern-type) + // categories (e.g. raw pointer, pattern-type) // which could be split out into their own kinds. PatConstKind::Other }; diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index 421065a894119..9080e2ba801bf 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -1290,9 +1290,10 @@ enum PatConstKind { /// These types don't support `SwitchInt` and require an equality test, /// but can also interact with range pattern tests. Float, + /// Constant string values, tested via string equality. + String, /// Any other constant-pattern is usually tested via some kind of equality /// check. Types that might be encountered here include: - /// - `&str` /// - raw pointers derived from integer values /// - pattern types, e.g. `pattern_type!(u32 is 1..)` Other, @@ -1368,14 +1369,20 @@ enum TestKind<'tcx> { /// Test whether a `bool` is `true` or `false`. If, - /// Test for equality with value, possibly after an unsizing coercion to - /// `cast_ty`, - Eq { + /// Tests the place against a string constant using string equality. + StringEq { + /// Constant `&str` value to test against. value: ty::Value<'tcx>, - // Integer types are handled by `SwitchInt`, and constants with ADT - // types and `&[T]` types are converted back into patterns, so this can - // only be `&str` or floats. - cast_ty: Ty<'tcx>, + /// Type of the corresponding pattern node. Usually `&str`, but could + /// be `str` for patterns like `deref!("..."): String`. + pat_ty: Ty<'tcx>, + }, + + /// Tests the place against a constant using scalar equality. + ScalarEq { + value: ty::Value<'tcx>, + /// Type of the corresponding pattern node. + pat_ty: Ty<'tcx>, }, /// Test whether the value falls within an inclusive or exclusive range. diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index 972d9f66faddc..c2e39d47a92ca 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -38,11 +38,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { TestableCase::Constant { value: _, kind: PatConstKind::IntOrChar } => { TestKind::SwitchInt } - TestableCase::Constant { value, kind: PatConstKind::Float } => { - TestKind::Eq { value, cast_ty: match_pair.pattern_ty } + TestableCase::Constant { value, kind: PatConstKind::String } => { + TestKind::StringEq { value, pat_ty: match_pair.pattern_ty } } - TestableCase::Constant { value, kind: PatConstKind::Other } => { - TestKind::Eq { value, cast_ty: match_pair.pattern_ty } + TestableCase::Constant { value, kind: PatConstKind::Float | PatConstKind::Other } => { + TestKind::ScalarEq { value, pat_ty: match_pair.pattern_ty } } TestableCase::Range(ref range) => { @@ -141,17 +141,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.cfg.terminate(block, self.source_info(match_start_span), terminator); } - TestKind::Eq { value, mut cast_ty } => { + TestKind::StringEq { value, pat_ty } => { let tcx = self.tcx; let success_block = target_block(TestBranch::Success); let fail_block = target_block(TestBranch::Failure); - let mut expect_ty = value.ty; - let mut expect = self.literal_operand(test.span, Const::from_ty_value(tcx, value)); + let expected_value_ty = value.ty; + let expected_value_operand = + self.literal_operand(test.span, Const::from_ty_value(tcx, value)); - let mut place = place; + let mut actual_value_ty = pat_ty; + let mut actual_value_place = place; - match cast_ty.kind() { + match pat_ty.kind() { ty::Str => { // String literal patterns may have type `str` if `deref_patterns` is // enabled, in order to allow `deref!("..."): String`. In this case, `value` @@ -172,11 +174,43 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ref_place, Rvalue::Ref(re_erased, BorrowKind::Shared, place), ); - place = ref_place; - cast_ty = ref_str_ty; + actual_value_place = ref_place; + actual_value_ty = ref_str_ty; } + _ => {} + } + + assert_eq!(expected_value_ty, actual_value_ty); + assert!(actual_value_ty.is_imm_ref_str()); + + // Compare two strings using `::eq`. + // (Interestingly this means that exhaustiveness analysis relies, for soundness, + // on the `PartialEq` impl for `str` to be correct!) + self.string_compare( + block, + success_block, + fail_block, + source_info, + expected_value_operand, + Operand::Copy(actual_value_place), + ); + } + + TestKind::ScalarEq { value, pat_ty } => { + let tcx = self.tcx; + let success_block = target_block(TestBranch::Success); + let fail_block = target_block(TestBranch::Failure); + + let mut expected_value_ty = value.ty; + let mut expected_value_operand = + self.literal_operand(test.span, Const::from_ty_value(tcx, value)); + + let mut actual_value_ty = pat_ty; + let mut actual_value_place = place; + + match pat_ty.kind() { &ty::Pat(base, _) => { - assert_eq!(cast_ty, value.ty); + assert_eq!(pat_ty, value.ty); assert!(base.is_trivially_pure_clone_copy()); let transmuted_place = self.temp(base, test.span); @@ -184,7 +218,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { block, self.source_info(scrutinee_span), transmuted_place, - Rvalue::Cast(CastKind::Transmute, Operand::Copy(place), base), + Rvalue::Cast( + CastKind::Transmute, + Operand::Copy(actual_value_place), + base, + ), ); let transmuted_expect = self.temp(base, test.span); @@ -192,54 +230,29 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { block, self.source_info(test.span), transmuted_expect, - Rvalue::Cast(CastKind::Transmute, expect, base), + Rvalue::Cast(CastKind::Transmute, expected_value_operand, base), ); - place = transmuted_place; - expect = Operand::Copy(transmuted_expect); - cast_ty = base; - expect_ty = base; + actual_value_place = transmuted_place; + actual_value_ty = base; + expected_value_operand = Operand::Copy(transmuted_expect); + expected_value_ty = base; } _ => {} } - assert_eq!(expect_ty, cast_ty); - if !cast_ty.is_scalar() { - // Use `PartialEq::eq` instead of `BinOp::Eq` - // (the binop can only handle primitives) - // Make sure that we do *not* call any user-defined code here. - // The only type that can end up here is string literals, which have their - // comparison defined in `core`. - // (Interestingly this means that exhaustiveness analysis relies, for soundness, - // on the `PartialEq` impl for `str` to b correct!) - match *cast_ty.kind() { - ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => {} - _ => { - span_bug!( - source_info.span, - "invalid type for non-scalar compare: {cast_ty}" - ) - } - }; - self.string_compare( - block, - success_block, - fail_block, - source_info, - expect, - Operand::Copy(place), - ); - } else { - self.compare( - block, - success_block, - fail_block, - source_info, - BinOp::Eq, - expect, - Operand::Copy(place), - ); - } + assert_eq!(expected_value_ty, actual_value_ty); + assert!(actual_value_ty.is_scalar()); + + self.compare( + block, + success_block, + fail_block, + source_info, + BinOp::Eq, + expected_value_operand, + Operand::Copy(actual_value_place), + ); } TestKind::Range(ref range) => { diff --git a/library/alloc/src/ffi/c_str.rs b/library/alloc/src/ffi/c_str.rs index 59f5857b97aa1..d6dcba7107a9c 100644 --- a/library/alloc/src/ffi/c_str.rs +++ b/library/alloc/src/ffi/c_str.rs @@ -636,7 +636,7 @@ impl CString { Self { inner: v.into_boxed_slice() } } - /// Attempts to converts a [Vec]<[u8]> to a [`CString`]. + /// Attempts to convert a [Vec]<[u8]> to a [`CString`]. /// /// Runtime checks are present to ensure there is only one nul byte in the /// [`Vec`], its last element. diff --git a/library/core/src/convert/mod.rs b/library/core/src/convert/mod.rs index 89cda30c03036..ef4ab15f93c0b 100644 --- a/library/core/src/convert/mod.rs +++ b/library/core/src/convert/mod.rs @@ -308,8 +308,8 @@ pub const trait AsRef: PointeeSized { /// both `AsMut>` and `AsMut<[T]>`. /// /// In the following, the example functions `caesar` and `null_terminate` provide a generic -/// interface which work with any type that can be converted by cheap mutable-to-mutable conversion -/// into a byte slice (`[u8]`) or byte vector (`Vec`), respectively. +/// interface which works with any type that can be converted by cheap mutable-to-mutable conversion +/// into a byte slice (`[u8]`) or a byte vector (`Vec`), respectively. /// /// [dereference]: core::ops::DerefMut /// [target type]: core::ops::Deref::Target diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index d46d3ed9d5137..0ae8d3d4a4ce1 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3385,11 +3385,17 @@ pub const fn autodiff(f: F, df: G, args: T) -> /// - `T`: A tuple of arguments passed to `f`. /// - `R`: The return type of the kernel. /// +/// Arguments: +/// - `f`: The kernel function to offload. +/// - `workgroup_dim`: A 3D size specifying the number of workgroups to launch. +/// - `thread_dim`: A 3D size specifying the number of threads per workgroup. +/// - `args`: A tuple of arguments forwarded to `f`. +/// /// Example usage (pseudocode): /// /// ```rust,ignore (pseudocode) /// fn kernel(x: *mut [f64; 128]) { -/// core::intrinsics::offload(kernel_1, (x,)) +/// core::intrinsics::offload(kernel_1, [256, 1, 1], [32, 1, 1], (x,)) /// } /// /// #[cfg(target_os = "linux")] @@ -3408,7 +3414,12 @@ pub const fn autodiff(f: F, df: G, args: T) -> /// . #[rustc_nounwind] #[rustc_intrinsic] -pub const fn offload(f: F, args: T) -> R; +pub const fn offload( + f: F, + workgroup_dim: [u32; 3], + thread_dim: [u32; 3], + args: T, +) -> R; /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] diff --git a/library/unwind/src/libunwind.rs b/library/unwind/src/libunwind.rs index 9ac9b54ed4a29..091efa9c51292 100644 --- a/library/unwind/src/libunwind.rs +++ b/library/unwind/src/libunwind.rs @@ -78,8 +78,8 @@ pub const unwinder_private_data_size: usize = 20; #[cfg(all(target_arch = "wasm32", target_os = "linux"))] pub const unwinder_private_data_size: usize = 2; -#[cfg(all(target_arch = "hexagon", target_os = "linux"))] -pub const unwinder_private_data_size: usize = 35; +#[cfg(target_arch = "hexagon")] +pub const unwinder_private_data_size: usize = 5; #[cfg(any(target_arch = "loongarch32", target_arch = "loongarch64"))] pub const unwinder_private_data_size: usize = 2; diff --git a/src/build_helper/src/npm.rs b/src/build_helper/src/npm.rs index f250ced4dc8a7..2a558b5618b3d 100644 --- a/src/build_helper/src/npm.rs +++ b/src/build_helper/src/npm.rs @@ -22,7 +22,16 @@ pub fn install(src_root_path: &Path, out_dir: &Path, yarn: &Path) -> Result::from(format!( + "unable to run yarn: {}", + err.kind() + ))) + })? + .wait()?; if !exit_status.success() { eprintln!("yarn install did not exit successfully"); return Err(io::Error::other(Box::::from(format!( diff --git a/src/doc/rustc-dev-guide/src/offload/usage.md b/src/doc/rustc-dev-guide/src/offload/usage.md index 062534a4b6556..4d3222123aaff 100644 --- a/src/doc/rustc-dev-guide/src/offload/usage.md +++ b/src/doc/rustc-dev-guide/src/offload/usage.md @@ -57,7 +57,7 @@ fn main() { #[inline(never)] unsafe fn kernel(x: *mut [f64; 256]) { - core::intrinsics::offload(kernel_1, (x,)) + core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x,)) } #[cfg(target_os = "linux")] diff --git a/tests/codegen-llvm/gpu_offload/control_flow.rs b/tests/codegen-llvm/gpu_offload/control_flow.rs index 4a213f5a33a8d..28ee9c85b0edc 100644 --- a/tests/codegen-llvm/gpu_offload/control_flow.rs +++ b/tests/codegen-llvm/gpu_offload/control_flow.rs @@ -21,14 +21,19 @@ // CHECK-NOT define // CHECK: bb3 // CHECK: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null) -// CHECK: %10 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2097152, i32 256, ptr nonnull @.foo.region_id, ptr nonnull %kernel_args) +// CHECK: %10 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.foo.region_id, ptr nonnull %kernel_args) // CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.foo, ptr null, ptr null) #[unsafe(no_mangle)] unsafe fn main() { let A = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; for i in 0..100 { - core::intrinsics::offload::<_, _, ()>(foo, (A.as_ptr() as *const [f32; 6],)); + core::intrinsics::offload::<_, _, ()>( + foo, + [256, 1, 1], + [32, 1, 1], + (A.as_ptr() as *const [f32; 6],), + ); } } diff --git a/tests/codegen-llvm/gpu_offload/gpu_host.rs b/tests/codegen-llvm/gpu_offload/gpu_host.rs index ac179a65828d7..b4d17143720a7 100644 --- a/tests/codegen-llvm/gpu_offload/gpu_host.rs +++ b/tests/codegen-llvm/gpu_offload/gpu_host.rs @@ -82,14 +82,14 @@ fn main() { // CHECK-NEXT: %5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40 // CHECK-NEXT: %6 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72 // CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) %5, i8 0, i64 32, i1 false) -// CHECK-NEXT: store <4 x i32> , ptr %6, align 8 -// CHECK-NEXT: %.fca.1.gep3 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88 -// CHECK-NEXT: store i32 0, ptr %.fca.1.gep3, align 8 -// CHECK-NEXT: %.fca.2.gep4 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92 -// CHECK-NEXT: store i32 0, ptr %.fca.2.gep4, align 4 +// CHECK-NEXT: store <4 x i32> , ptr %6, align 8 +// CHECK-NEXT: %.fca.1.gep5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88 +// CHECK-NEXT: store i32 1, ptr %.fca.1.gep5, align 8 +// CHECK-NEXT: %.fca.2.gep7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92 +// CHECK-NEXT: store i32 1, ptr %.fca.2.gep7, align 4 // CHECK-NEXT: %7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96 // CHECK-NEXT: store i32 0, ptr %7, align 8 -// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 2097152, i32 256, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args) +// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @._kernel_1.region_id, ptr nonnull %kernel_args) // CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1, ptr null, ptr null) // CHECK-NEXT: call void @__tgt_unregister_lib(ptr nonnull %EmptyDesc) // CHECK-NEXT: ret void @@ -98,7 +98,7 @@ fn main() { #[unsafe(no_mangle)] #[inline(never)] pub fn kernel_1(x: &mut [f32; 256]) { - core::intrinsics::offload(_kernel_1, (x,)) + core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x,)) } #[unsafe(no_mangle)]