Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making zero_mem_region() and control regs helpers unsafe #546

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions kernel/src/cpu/control_regs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::platform::SvsmPlatform;
use bitflags::bitflags;
use core::arch::asm;

#[inline]
pub fn cr0_init() {
let mut cr0 = read_cr0();

Expand All @@ -21,6 +22,7 @@ pub fn cr0_init() {
write_cr0(cr0);
}

#[inline]
pub fn cr4_init(platform: &dyn SvsmPlatform) {
let mut cr4 = read_cr4();

Expand All @@ -46,6 +48,7 @@ pub fn cr4_init(platform: &dyn SvsmPlatform) {
write_cr4(cr4);
}

#[inline]
pub fn cr0_sse_enable() {
let mut cr0 = read_cr0();

Expand All @@ -58,6 +61,7 @@ pub fn cr0_sse_enable() {
write_cr0(cr0);
}

#[inline]
pub fn cr4_osfxsr_enable() {
let mut cr4 = read_cr4();

Expand All @@ -66,6 +70,7 @@ pub fn cr4_osfxsr_enable() {
write_cr4(cr4);
}

#[inline]
pub fn cr4_xsave_enable() {
let mut cr4 = read_cr4();

Expand All @@ -91,6 +96,7 @@ bitflags! {
}
}

#[inline]
pub fn read_cr0() -> CR0Flags {
let cr0: u64;

Expand All @@ -103,6 +109,7 @@ pub fn read_cr0() -> CR0Flags {
CR0Flags::from_bits_truncate(cr0)
}

#[inline]
pub fn write_cr0(cr0: CR0Flags) {
let reg = cr0.bits();

Expand All @@ -113,6 +120,7 @@ pub fn write_cr0(cr0: CR0Flags) {
}
}

#[inline]
pub fn read_cr2() -> usize {
let ret: usize;
unsafe {
Expand All @@ -123,6 +131,7 @@ pub fn read_cr2() -> usize {
ret
}

#[inline]
pub fn write_cr2(cr2: usize) {
unsafe {
asm!("mov %rax, %cr2",
Expand All @@ -131,6 +140,7 @@ pub fn write_cr2(cr2: usize) {
}
}

#[inline]
pub fn read_cr3() -> PhysAddr {
let ret: usize;
unsafe {
Expand All @@ -141,6 +151,7 @@ pub fn read_cr3() -> PhysAddr {
PhysAddr::from(ret)
}

#[inline]
pub fn write_cr3(cr3: PhysAddr) {
unsafe {
asm!("mov %rax, %cr3",
Expand Down Expand Up @@ -176,6 +187,7 @@ bitflags! {
}
}

#[inline]
pub fn read_cr4() -> CR4Flags {
let cr4: u64;

Expand All @@ -188,6 +200,7 @@ pub fn read_cr4() -> CR4Flags {
CR4Flags::from_bits_truncate(cr4)
}

#[inline]
pub fn write_cr4(cr4: CR4Flags) {
let reg = cr4.bits();

Expand Down
16 changes: 14 additions & 2 deletions kernel/src/mm/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,12 @@ impl MemoryRegion {
fn allocate_zeroed_page(&mut self) -> Result<VirtAddr, AllocError> {
let vaddr = self.allocate_page()?;

zero_mem_region(vaddr, vaddr + PAGE_SIZE);
// SAFETY: we trust allocate_page() to return a pointer to a valid
// page. vaddr + PAGE_SIZE also correctly points to the end of the
// page.
unsafe {
zero_mem_region(vaddr, vaddr + PAGE_SIZE);
}

Ok(vaddr)
}
Expand Down Expand Up @@ -1091,7 +1096,14 @@ pub fn allocate_zeroed_page() -> Result<VirtAddr, SvsmError> {
/// `SvsmError` if allocation fails.
pub fn allocate_file_page() -> Result<VirtAddr, SvsmError> {
let vaddr = ROOT_MEM.lock().allocate_file_page()?;
zero_mem_region(vaddr, vaddr + PAGE_SIZE);

// SAFETY: we trust allocate_file_page() to return a pointer to a valid
// page. vaddr + PAGE_SIZE also correctly points to the end of the
// page.
unsafe {
zero_mem_region(vaddr, vaddr + PAGE_SIZE);
}

Ok(vaddr)
}

Expand Down
18 changes: 15 additions & 3 deletions kernel/src/platform/snp_fw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,11 @@ fn validate_fw_mem_region(
PageSize::Regular,
)?;

zero_mem_region(vaddr, vaddr + PAGE_SIZE);
// SAFETY: we trust PerCPUPageMappingGuard::create_4k() to return a
// valid pointer to a correctly mapped region of size PAGE_SIZE.
unsafe {
zero_mem_region(vaddr, vaddr + PAGE_SIZE);
}
}

Ok(())
Expand Down Expand Up @@ -322,7 +326,11 @@ fn copy_secrets_page_to_fw(
let start = guard.virt_addr();

// Zero target
zero_mem_region(start, start + PAGE_SIZE);
// SAFETY: we trust PerCPUPageMappingGuard::create_4k() to return a
// valid pointer to a correctly mapped region of size PAGE_SIZE.
unsafe {
zero_mem_region(start, start + PAGE_SIZE);
}

// Copy secrets page
let mut fw_secrets_page = secrets_page().copy_for_vmpl(GUEST_VMPL);
Expand All @@ -345,7 +353,11 @@ fn zero_caa_page(fw_addr: PhysAddr) -> Result<(), SvsmError> {
let guard = PerCPUPageMappingGuard::create_4k(fw_addr)?;
let vaddr = guard.virt_addr();

zero_mem_region(vaddr, vaddr + PAGE_SIZE);
// SAFETY: we trust PerCPUPageMappingGuard::create_4k() to return a
// valid pointer to a correctly mapped region of size PAGE_SIZE.
unsafe {
zero_mem_region(vaddr, vaddr + PAGE_SIZE);
}

Ok(())
}
Expand Down
9 changes: 8 additions & 1 deletion kernel/src/protocols/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,14 @@ fn core_pvalidate_one(entry: u64, flush: &mut bool) -> Result<(), SvsmReqError>
// FIXME: This check leaves a window open for the attack described
// above. Remove the check once OVMF and Linux have been fixed and
// no longer try to pvalidate MMIO memory.
zero_mem_region(vaddr, vaddr + page_size_bytes);

// SAFETY: paddr is validated at the beginning of the function, and
// we trust PerCPUPageMappingGuard::create() to return a valid
// vaddr pointing to a mapped region of at least page_size_bytes
// size.
unsafe {
zero_mem_region(vaddr, vaddr + page_size_bytes);
}
} else {
log::warn!("Not clearing possible read-only page at PA {:#x}", paddr);
}
Expand Down
3 changes: 1 addition & 2 deletions kernel/src/svsm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,9 @@ pub extern "C" fn svsm_start(li: &KernelLaunchInfo, vb_addr: usize) {
// We trust stage 2 to give the value provided by IGVM.
unsafe {
secrets_page_mut().copy_from(secrets_page_virt);
zero_mem_region(secrets_page_virt, secrets_page_virt + PAGE_SIZE);
}

zero_mem_region(secrets_page_virt, secrets_page_virt + PAGE_SIZE);

cr0_init();
cr4_init(&*platform);
determine_cet_support();
Expand Down
20 changes: 16 additions & 4 deletions kernel/src/utils/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,22 @@ where
x1 <= y2 && y1 <= x2
}

pub fn zero_mem_region(start: VirtAddr, end: VirtAddr) {
let size = end - start;
/// # Safety
///
/// Caller should ensure [`core::ptr::write_bytes`] safety rules.
pub unsafe fn zero_mem_region(start: VirtAddr, end: VirtAddr) {
if start.is_null() {
panic!("Attempted to zero out a NULL pointer");
}

let count = end
.checked_sub(start.as_usize())
.expect("Invalid size calculation")
.as_usize();

// Zero region
unsafe { start.as_mut_ptr::<u8>().write_bytes(0, size) }
// SAFETY: the safety rules must be upheld by the caller.
unsafe { start.as_mut_ptr::<u8>().write_bytes(0, count) }
}

/// Obtain bit for a given position
Expand Down Expand Up @@ -117,7 +125,11 @@ mod tests {
let start = VirtAddr::from(data.as_mut_ptr());
let end = start + core::mem::size_of_val(&data);

zero_mem_region(start, end);
// SAFETY: start and end correctly point respectively to the start and
// end of data.
unsafe {
zero_mem_region(start, end);
}

for byte in &data {
assert_eq!(*byte, 0);
Expand Down
Loading