Skip to content

Commit

Permalink
Rollup merge of #128731 - RalfJung:simd-shuffle-vector, r=workingjubilee
Browse files Browse the repository at this point in the history
simd_shuffle intrinsic: allow argument to be passed as vector

See rust-lang/rust#128738 for context.

I'd like to get rid of [this hack](https://github.com/rust-lang/rust/blob/6c0b89dfac65be9a5be12f938f23098ebc36c635/compiler/rustc_codegen_ssa/src/mir/block.rs#L922-L935). rust-lang/rust#128537 almost lets us do that since constant SIMD vectors will then be passed as immediate arguments. However, simd_shuffle for some reason actually takes an *array* as argument, not a vector, so the hack is still required to ensure that the array becomes an immediate (which then later stages of codegen convert into a vector, as that's what LLVM needs).

This PR prepares simd_shuffle to also support a vector as the `idx` argument. Once this lands, stdarch can hopefully be updated to pass `idx` as a vector, and then support for arrays can be removed, which finally lets us get rid of that hack.
  • Loading branch information
tgross35 authored Aug 27, 2024
2 parents 9d12735 + 7a151d5 commit f1d136b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 21 deletions.
44 changes: 30 additions & 14 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1923,15 +1923,11 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
v2: RValue<'gcc>,
mask: RValue<'gcc>,
) -> RValue<'gcc> {
let struct_type = mask.get_type().is_struct().expect("mask should be of struct type");

// TODO(antoyo): use a recursive unqualified() here.
let vector_type = v1.get_type().unqualified().dyncast_vector().expect("vector type");
let element_type = vector_type.get_element_type();
let vec_num_units = vector_type.get_num_units();

let mask_num_units = struct_type.get_field_count();
let mut vector_elements = vec![];
let mask_element_type = if element_type.is_integral() {
element_type
} else {
Expand All @@ -1942,19 +1938,39 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
#[cfg(not(feature = "master"))]
self.int_type
};
for i in 0..mask_num_units {
let field = struct_type.get_field(i as i32);
vector_elements.push(self.context.new_cast(
self.location,
mask.access_field(self.location, field).to_rvalue(),
mask_element_type,
));
}

let mut mask_elements = if let Some(vector_type) = mask.get_type().dyncast_vector() {
let mask_num_units = vector_type.get_num_units();
let mut mask_elements = vec![];
for i in 0..mask_num_units {
let index = self.context.new_rvalue_from_long(self.cx.type_u32(), i as _);
mask_elements.push(self.context.new_cast(
self.location,
self.extract_element(mask, index).to_rvalue(),
mask_element_type,
));
}
mask_elements
} else {
let struct_type = mask.get_type().is_struct().expect("mask should be of struct type");
let mask_num_units = struct_type.get_field_count();
let mut mask_elements = vec![];
for i in 0..mask_num_units {
let field = struct_type.get_field(i as i32);
mask_elements.push(self.context.new_cast(
self.location,
mask.access_field(self.location, field).to_rvalue(),
mask_element_type,
));
}
mask_elements
};
let mask_num_units = mask_elements.len();

// NOTE: the mask needs to be the same length as the input vectors, so add the missing
// elements in the mask if needed.
for _ in mask_num_units..vec_num_units {
vector_elements.push(self.context.new_rvalue_zero(mask_element_type));
mask_elements.push(self.context.new_rvalue_zero(mask_element_type));
}

let result_type = self.context.new_vector_type(element_type, mask_num_units as u64);
Expand Down Expand Up @@ -1998,7 +2014,7 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {

let new_mask_num_units = std::cmp::max(mask_num_units, vec_num_units);
let mask_type = self.context.new_vector_type(mask_element_type, new_mask_num_units as u64);
let mask = self.context.new_rvalue_from_vector(self.location, mask_type, &vector_elements);
let mask = self.context.new_rvalue_from_vector(self.location, mask_type, &mask_elements);
let result = self.context.new_rvalue_vector_perm(self.location, v1, v2, mask);

if vec_num_units != mask_num_units {
Expand Down
19 changes: 12 additions & 7 deletions src/intrinsic/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,19 +353,24 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(
}

if name == sym::simd_shuffle {
// Make sure this is actually an array, since typeck only checks the length-suffixed
// Make sure this is actually an array or SIMD vector, since typeck only checks the length-suffixed
// version of this intrinsic.
let n: u64 = match *args[2].layout.ty.kind() {
let idx_ty = args[2].layout.ty;
let n: u64 = match idx_ty.kind() {
ty::Array(ty, len) if matches!(*ty.kind(), ty::Uint(ty::UintTy::U32)) => {
len.try_eval_target_usize(bx.cx.tcx, ty::ParamEnv::reveal_all()).unwrap_or_else(
|| span_bug!(span, "could not evaluate shuffle index array length"),
)
}
_ => return_error!(InvalidMonomorphization::SimdShuffle {
span,
name,
ty: args[2].layout.ty
}),
_ if idx_ty.is_simd()
&& matches!(
idx_ty.simd_size_and_type(bx.cx.tcx).1.kind(),
ty::Uint(ty::UintTy::U32)
) =>
{
idx_ty.simd_size_and_type(bx.cx.tcx).0
}
_ => return_error!(InvalidMonomorphization::SimdShuffle { span, name, ty: idx_ty }),
};
require_simd!(ret_ty, InvalidMonomorphization::SimdReturn { span, name, ty: ret_ty });

Expand Down

0 comments on commit f1d136b

Please sign in to comment.