diff --git a/example/std_example.rs b/example/std_example.rs index 33db75f094..408c9b3dad 100644 --- a/example/std_example.rs +++ b/example/std_example.rs @@ -9,6 +9,8 @@ )] #![allow(internal_features)] +#[cfg(target_arch = "x86_64")] +use std::arch::asm; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; use std::hint::black_box; @@ -279,6 +281,17 @@ unsafe fn test_simd() { #[cfg(not(jit))] test_crc32(); + + #[cfg(not(jit))] + test_xmm_roundtrip(); + #[cfg(not(jit))] + if is_x86_feature_detected!("avx") { + test_ymm_roundtrip(); + } + #[cfg(not(jit))] + if is_x86_feature_detected!("avx512f") { + test_zmm_roundtrip(); + } } } @@ -576,6 +589,65 @@ unsafe fn test_mm_cvtps_ph() { assert_eq_m128i(r, e); } +#[cfg(target_arch = "x86_64")] +#[cfg(not(jit))] +unsafe fn test_xmm_roundtrip() { + unsafe { + let input = [1u8; 16]; + let mut output = [0u8; 16]; + + asm!( + "movups {xmm}, [{input}]", + "movups [{output}], {xmm}", + input = in(reg) input.as_ptr(), + output = in(reg) output.as_mut_ptr(), + xmm = out(xmm_reg) _, + ); + + assert_eq!(input, output); + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx")] +#[cfg(not(jit))] +unsafe fn test_ymm_roundtrip() { + unsafe { + let input = [1u8; 32]; + let mut output = [0u8; 32]; + + asm!( + "vmovups {ymm}, [{input}]", + "vmovups [{output}], {ymm}", + input = in(reg) input.as_ptr(), + output = in(reg) output.as_mut_ptr(), + ymm = out(ymm_reg) _, + ); + + assert_eq!(input, output); + } +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx512f")] +#[cfg(not(jit))] +unsafe fn test_zmm_roundtrip() { + unsafe { + let input = [1u8; 64]; + let mut output = [0u8; 64]; + + asm!( + "vmovups {zmm}, [{input}]", + "vmovups [{output}], {zmm}", + input = in(reg) input.as_ptr(), + output = in(reg) output.as_mut_ptr(), + zmm = out(zmm_reg) _, + ); + + assert_eq!(input, output); + } +} + fn test_checked_mul() { let u: Option = u8::from_str_radix("1000", 10).ok(); assert_eq!(u, None); diff --git a/src/inline_asm.rs b/src/inline_asm.rs index ac0da06cbb..8f1bb0b471 100644 --- a/src/inline_asm.rs +++ b/src/inline_asm.rs @@ -548,22 +548,21 @@ impl<'tcx> InlineAssemblyGenerator<'_, 'tcx> { match self.arch { InlineAsmArch::X86_64 => match reg { InlineAsmReg::X86(reg) - if reg as u32 >= X86InlineAsmReg::xmm0 as u32 - && reg as u32 <= X86InlineAsmReg::xmm15 as u32 => + if matches!( + reg.reg_class(), + X86InlineAsmRegClass::xmm_reg + | X86InlineAsmRegClass::ymm_reg + | X86InlineAsmRegClass::zmm_reg + ) => { - // rustc emits x0 rather than xmm0 - let class = match *modifier { - None | Some('x') => "xmm", - Some('y') => "ymm", - Some('z') => "zmm", - _ => unreachable!(), - }; - write!( - generated_asm, - "{class}{}", - reg as u32 - X86InlineAsmReg::xmm0 as u32 - ) - .unwrap(); + // rustc emits x0/y0/z0 rather than xmm0/ymm0/zmm0 + let name = reg.name(); + if let Some(prefix) = modifier { + let index = &name[3..]; + write!(generated_asm, "{prefix}mm{index}").unwrap(); + } else { + write!(generated_asm, "{name}").unwrap(); + } } _ => reg .emit(&mut generated_asm, InlineAsmArch::X86_64, *modifier) @@ -716,12 +715,17 @@ impl<'tcx> InlineAssemblyGenerator<'_, 'tcx> { InlineAsmArch::X86_64 => { match reg { InlineAsmReg::X86(reg) - if reg as u32 >= X86InlineAsmReg::xmm0 as u32 - && reg as u32 <= X86InlineAsmReg::xmm15 as u32 => + if matches!( + reg.reg_class(), + X86InlineAsmRegClass::xmm_reg + | X86InlineAsmRegClass::ymm_reg + | X86InlineAsmRegClass::zmm_reg + ) => { - // rustc emits x0 rather than xmm0 - write!(generated_asm, " movups [rbx+0x{:x}], ", offset.bytes()).unwrap(); - write!(generated_asm, "xmm{}", reg as u32 - X86InlineAsmReg::xmm0 as u32) + // rustc emits x0/y0/z0 rather than xmm0/ymm0/zmm0 + let name = reg.name(); + let mov = if name.starts_with("xmm") { "movups" } else { "vmovups" }; + write!(generated_asm, " {mov} [rbx+0x{:x}], {name}", offset.bytes()) .unwrap(); } _ => { @@ -761,16 +765,17 @@ impl<'tcx> InlineAssemblyGenerator<'_, 'tcx> { InlineAsmArch::X86_64 => { match reg { InlineAsmReg::X86(reg) - if reg as u32 >= X86InlineAsmReg::xmm0 as u32 - && reg as u32 <= X86InlineAsmReg::xmm15 as u32 => + if matches!( + reg.reg_class(), + X86InlineAsmRegClass::xmm_reg + | X86InlineAsmRegClass::ymm_reg + | X86InlineAsmRegClass::zmm_reg + ) => { - // rustc emits x0 rather than xmm0 - write!( - generated_asm, - " movups xmm{}", - reg as u32 - X86InlineAsmReg::xmm0 as u32 - ) - .unwrap(); + // rustc emits x0/y0/z0 rather than xmm0/ymm0/zmm0 + let name = reg.name(); + let mov = if name.starts_with("xmm") { "movups" } else { "vmovups" }; + write!(generated_asm, " {mov} {name}").unwrap(); } _ => { generated_asm.push_str(" mov ");