diff --git a/stage2/Cargo.toml b/stage2/Cargo.toml new file mode 100644 index 000000000..ca650067e --- /dev/null +++ b/stage2/Cargo.toml @@ -0,0 +1,53 @@ +[package] +name = "svsm" +version = "0.1.0" +edition = "2021" +rust-version = "1.82.0" + +[[bin]] +name = "stage2" +path = "src/stage2.rs" +test = false + +[[bin]] +name = "svsm" +path = "src/svsm.rs" +test = false + +[lib] +test = true +doctest = true + +[dependencies] +bootlib.workspace = true +cpuarch.workspace = true +elf.workspace = true +syscall.workspace = true + +aes-gcm = { workspace = true, features = ["aes", "alloc"] } +bitfield-struct.workspace = true +bitflags.workspace = true +gdbstub = { workspace = true, optional = true } +gdbstub_arch = { workspace = true, optional = true } +igvm_defs = { workspace = true, features = ["unstable"] } +intrusive-collections.workspace = true +log = { workspace = true, features = ["max_level_info", "release_max_level_info"] } +packit.workspace = true +tdx-tdcall.workspace = true +libmstpm = { workspace = true, optional = true } +zerocopy.workspace = true + +[target."x86_64-unknown-none".dev-dependencies] +test.workspace = true + +[features] +default = [] +enable-gdb = ["dep:gdbstub", "dep:gdbstub_arch"] +mstpm = ["dep:libmstpm"] +nosmep = [] +nosmap = [] + +[dev-dependencies] + +[lints] +workspace = true diff --git a/stage2/build.rs b/stage2/build.rs new file mode 100644 index 000000000..526b10203 --- /dev/null +++ b/stage2/build.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +fn main() { + // Extra cfgs + println!("cargo::rustc-check-cfg=cfg(fuzzing)"); + println!("cargo::rustc-check-cfg=cfg(test_in_svsm)"); + + // Stage 2 + println!("cargo:rustc-link-arg-bin=stage2=-nostdlib"); + println!("cargo:rustc-link-arg-bin=stage2=--build-id=none"); + println!("cargo:rustc-link-arg-bin=stage2=-Tkernel/src/stage2.lds"); + println!("cargo:rustc-link-arg-bin=stage2=-no-pie"); + + // SVSM 2 + println!("cargo:rustc-link-arg-bin=svsm=-nostdlib"); + println!("cargo:rustc-link-arg-bin=svsm=--build-id=none"); + println!("cargo:rustc-link-arg-bin=svsm=--no-relax"); + println!("cargo:rustc-link-arg-bin=svsm=-Tkernel/src/svsm.lds"); + println!("cargo:rustc-link-arg-bin=svsm=-no-pie"); + + // Extra linker args for tests. 
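+ // Example usage (assumption, not taken from this patch): running tests with something
+ // like `LINK_TEST=1 cargo test` enables the `test_in_svsm` cfg and links the test binary
+ // with the SVSM kernel linker script below, presumably so the resulting image can be run
+ // inside the SVSM environment rather than as a host test binary.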
+ println!("cargo:rerun-if-env-changed=LINK_TEST"); + if std::env::var("LINK_TEST").is_ok() { + println!("cargo:rustc-cfg=test_in_svsm"); + println!("cargo:rustc-link-arg=-nostdlib"); + println!("cargo:rustc-link-arg=--build-id=none"); + println!("cargo:rustc-link-arg=--no-relax"); + println!("cargo:rustc-link-arg=-Tkernel/src/svsm.lds"); + println!("cargo:rustc-link-arg=-no-pie"); + } + + println!("cargo:rerun-if-changed=kernel/src/stage2.lds"); + println!("cargo:rerun-if-changed=kernel/src/svsm.lds"); +} diff --git a/stage2/src/acpi/mod.rs b/stage2/src/acpi/mod.rs new file mode 100644 index 000000000..42794add5 --- /dev/null +++ b/stage2/src/acpi/mod.rs @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +pub mod tables; diff --git a/stage2/src/acpi/tables.rs b/stage2/src/acpi/tables.rs new file mode 100644 index 000000000..fc125df5a --- /dev/null +++ b/stage2/src/acpi/tables.rs @@ -0,0 +1,514 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; + +use crate::error::SvsmError; +use crate::fw_cfg::FwCfg; +use crate::string::FixedString; +use alloc::vec::Vec; +use core::mem; +use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout}; + +/// ACPI Root System Description Pointer (RSDP) +/// used by ACPI programming interface +#[derive(Debug, Default, FromBytes, IntoBytes)] +#[repr(C, packed)] +struct RSDPDesc { + /// Signature must contain "RSD PTR" + sig: [u8; 8], + /// Checksum to add to all other bytes + chksum: u8, + /// OEM-supplied string + oem_id: [u8; 6], + /// Revision of the ACPI + rev: u8, + /// Physical address of the RSDT + rsdt_addr: u32, +} + +impl RSDPDesc { + /// Create an RSPDesc instance from FwCfg + /// + /// # Arguments + /// + /// - `fw_cfg`: A reference to the FwCfg instance. + /// + /// # Returns + /// + /// A [`Result`] containing the [`RSDPDesc`] if successful, or an [`SvsmError`] on failure. + fn from_fwcfg(fw_cfg: &FwCfg<'_>) -> Result { + let path = option_env!("ACPI_RSDP_PATH").unwrap_or("etc/acpi/rsdp"); + let file = fw_cfg.file_selector(path)?; + + if (file.size() as usize) < mem::size_of::() { + return Err(SvsmError::Acpi); + } + + fw_cfg.select(file.selector()); + let mut this = Self::new_zeroed(); + fw_cfg.read_bytes(this.as_mut_bytes()); + Ok(this) + } +} + +#[derive(Copy, Clone, Debug, Default, FromBytes, KnownLayout, Immutable)] +#[repr(C, packed)] +/// Raw header of an ACPI table. It corresponds to the beginning +/// portion of ACPI tables, before any specific table data +struct RawACPITableHeader { + /// Signature specificies the type of ACPI table + sig: [u8; 4], + /// Length of the table + len: u32, + /// Revision (signature field) + rev: u8, + /// Checksum for data integrity + chksum: u8, + /// OEM-supplied string to identify OEM + oem_id: [u8; 6], + /// OEM-supplied string to identify tables + oem_table_id: [u8; 8], + /// OEM-supplied version number + oem_rev: u32, + /// ID for compiler + compiler_id: [u8; 4], + /// Revision of compiler used to create the table + compiler_rev: u32, +} + +#[derive(Debug, Default)] +/// Higher level representation of the raw ACPI table header +struct ACPITableHeader { + sig: [u8; 4], + len: u32, + rev: u8, + chksum: u8, + oem_id: [u8; 6], + oem_table_id: [u8; 8], + oem_rev: u32, + compiler_id: [u8; 4], + compiler_rev: u32, +} + +impl ACPITableHeader { + /// Create a new [`ACPITableHeader`] from a raw [`RawACPITableHeader`]. 
+ /// + /// This constructor converts a raw ACPI table header into a higher-level [`ACPITableHeader`]. + /// + /// # Arguments + /// + /// * `raw` - A [`RawACPITableHeader`] containing the raw header data. + /// + /// # Returns + /// + /// A new [`ACPITableHeader`] instance. + const fn new(raw: RawACPITableHeader) -> Self { + Self { + sig: raw.sig, + len: raw.len, + rev: raw.rev, + chksum: raw.chksum, + oem_id: raw.oem_id, + oem_table_id: raw.oem_table_id, + oem_rev: raw.oem_rev, + compiler_id: raw.compiler_id, + compiler_rev: raw.compiler_rev, + } + } + + /// Print a human-readable summary of the ACPI table header's fields + #[expect(dead_code)] + fn print_summary(&self) { + let sig = FixedString::from(self.sig); + let oem_id = FixedString::from(self.oem_id); + let oem_table_id = FixedString::from(self.oem_table_id); + let compiler_id = FixedString::from(self.compiler_id); + log::trace!( + "ACPI: [{} {} {} {} {} {} {} {} {}]", + sig, + self.len, + self.rev, + self.chksum, + oem_id, + oem_table_id, + self.oem_rev, + compiler_id, + self.compiler_rev + ); + } +} + +#[derive(Debug)] +/// ACPI table, both header and contents +struct ACPITable { + header: ACPITableHeader, + /// Raw binary content of ACPI table + buf: Vec, +} + +impl ACPITable { + /// Create a new [`ACPITable`] from raw binary data. + /// + /// This constructor creates an [`ACPITable`] instance by parsing raw binary data. + /// + /// # Arguments + /// + /// * `ptr` - A slice containing the raw binary data of the ACPI table. + /// + /// # Returns + /// + /// A new [`ACPITable`] instance on success, or an [`SvsmError`] if parsing fails. + fn new(ptr: &[u8]) -> Result { + let (raw_header, _) = + RawACPITableHeader::read_from_prefix(ptr).map_err(|_| SvsmError::Acpi)?; + let size = raw_header.len as usize; + let content = ptr.get(..size).ok_or(SvsmError::Acpi)?; + + let mut buf = Vec::::new(); + // Allow for a failable allocation before copying + buf.try_reserve(size).map_err(|_| SvsmError::Mem)?; + buf.extend_from_slice(content); + + let header = ACPITableHeader::new(raw_header); + + Ok(Self { header, buf }) + } + + /// Get the signature of the ACPI table. + /// + /// This method returns the 4-character signature of the ACPI table, such as "APIC." + #[expect(dead_code)] + fn signature(&self) -> FixedString<4> { + FixedString::from(self.header.sig) + } + + /// Get the content of the ACPI table. + /// + /// This method returns a reference to the binary content of the ACPI table, + /// excluding the header. + /// + /// # Returns + /// + /// A reference to the ACPI table content, or [`None`] if the content is empty. + fn content(&self) -> Option<&[u8]> { + let offset = mem::size_of::(); + // Zero-length slices are valid, but we do not want them + self.buf.get(offset..).filter(|b| !b.is_empty()) + } + + /// Get a pointer to the content of the ACPI table at a specific offset. + /// + /// This method returns a pointer to the content of the ACPI table at the specified offset, + /// converted to the desired type `T`. + /// + /// # Arguments + /// + /// * `offset` - The offset at which to obtain the pointer. + /// + /// # Returns + /// + /// A reference to the content of the ACPI table at specified offset as type `T`, + /// or [`None`] if the offset is out of bounds. 
+ fn content_ptr(&self, offset: usize) -> Option<&T> + where + T: FromBytes + KnownLayout + Immutable, + { + let bytes = self.content()?.get(offset..)?; + T::ref_from_prefix(bytes).ok().map(|(value, _rest)| value) + } +} + +/// ACPI Table Metadata +/// Metadata associated with an ACPI, information about signature and offset +#[derive(Debug)] +struct ACPITableMeta { + /// 4-character signature of the table + sig: FixedString<4>, + /// The offset of the table within the table buffer + offset: usize, +} + +impl ACPITableMeta { + /// Create a new [`ACPITableMeta`] instance. + /// + /// This constructor creates an [`ACPITableMeta`] instance with the specified signature and offset. + /// + /// # Arguments + /// + /// * `header` - The raw ACPI table header containing the signature. + /// * `offset` - The offset of the ACPI table within the ACPI table buffer. + /// + /// # Returns + /// + /// A new [`ACPITableMeta`] instance. + fn new(header: &RawACPITableHeader, offset: usize) -> Self { + let sig = FixedString::from(header.sig); + Self { sig, offset } + } +} + +const MAX_ACPI_TABLES_SIZE: usize = 128 * 1024; + +/// ACPI Table Buffer +/// A buffer containing ACPI tables. Responsible for loading the tables +/// from a firmware configuration +#[derive(Debug)] +struct ACPITableBuffer { + buf: Vec, + /// Collection of metadata for ACPI tables, including signatures + tables: Vec, +} + +impl ACPITableBuffer { + /// Create a new [`ACPITableBuffer`] instance from a firmware configuration source. + /// + /// This constructor creates an [`ACPITableBuffer`] instance by reading ACPI tables from the specified FwCfg source. + /// + /// # Arguments + /// + /// * `fw_cfg` - The firmware configuration source (FwCfg) from which ACPI tables will be loaded. + /// + /// # Returns + /// + /// A new [`ACPITableBuffer`] instance containing ACPI tables and their metadata. + fn from_fwcfg(fw_cfg: &FwCfg<'_>) -> Result { + let path = option_env!("ACPI_TABLES_PATH").unwrap_or("etc/acpi/tables"); + let file = fw_cfg.file_selector(path)?; + let size = file.size() as usize; + + let mut buf = Vec::::new(); + if size > MAX_ACPI_TABLES_SIZE { + return Err(SvsmError::Mem); + } + buf.try_reserve(size).map_err(|_| SvsmError::Mem)?; + buf.resize(size, 0); + fw_cfg.select(file.selector()); + fw_cfg.read_bytes(&mut buf); + + let mut acpibuf = Self { + buf, + tables: Vec::new(), + }; + acpibuf.load_tables(fw_cfg)?; + Ok(acpibuf) + } + + /// Load ACPI tables and their metadata from the ACPI Root System Description Pointer (RSDP). + /// + /// This method populates the `tables` field of the [`ACPITableBuffer`] with metadata for ACPI tables + /// found within the ACPI Root System Description Pointer (RSDP) structure. + /// + /// # Arguments + /// + /// * `fw_cfg` - The firmware configuration source (FwCfg) containing ACPI tables. + /// + /// # Returns + /// + /// A [`Result`] indicating success or an error if ACPI tables cannot be loaded. 
+ fn load_tables(&mut self, fw_cfg: &FwCfg<'_>) -> Result<(), SvsmError> { + let desc = RSDPDesc::from_fwcfg(fw_cfg)?; + + let rsdt = self.acpi_table_from_offset(desc.rsdt_addr as usize)?; + let content = rsdt.content().ok_or(SvsmError::Acpi)?; + let offsets = content + .chunks_exact(mem::size_of::()) + .map(|c| u32::from_le_bytes(c.try_into().unwrap()) as usize); + + for offset in offsets { + let raw_header = self.buf.get(offset..).ok_or(SvsmError::Acpi)?; + let (raw_header, _) = + RawACPITableHeader::ref_from_prefix(raw_header).map_err(|_| SvsmError::Acpi)?; + let meta = ACPITableMeta::new(raw_header, offset); + self.tables.push(meta); + } + + Ok(()) + } + + /// Retrieve an ACPI table from a specified offset within the ACPI table buffer. + /// + /// This function attempts to retrieve an ACPI table from the ACPI table buffer starting from the + /// specified offset. It parses the table header and creates an [`ACPITable`] instance representing + /// the ACPI table's content. + /// + /// # Arguments + /// + /// * `offset` - The offset within the ACPI table buffer from which to retrieve the ACPI table. + /// + /// # Returns + /// + /// A [`Result`] containing the [`ACPITable`] instance if successfully retrieved, or an [`SvsmError`] + /// if the table cannot be retrieved or parsed. + fn acpi_table_from_offset(&self, offset: usize) -> Result { + let buf = self.buf.get(offset..).ok_or(SvsmError::Acpi)?; + ACPITable::new(buf) + } + + /// Retrieve an ACPI table by its signature. + /// + /// This method attempts to retrieve an ACPI table by its 4-character signature. + /// + /// # Arguments + /// + /// * `sig` - The signature of the ACPI table to retrieve. + /// + /// # Returns + /// + /// An [`Option`] containing the ACPI table if found, or [`None`] if not found. + fn acp_table_by_sig(&self, sig: &str) -> Option { + let offset = self + .tables + .iter() + .find(|entry| entry.sig == sig) + .map(|entry| entry.offset)?; + + self.acpi_table_from_offset(offset).ok() + } +} + +const MADT_HEADER_SIZE: usize = 8; + +/// Header of an entry within MADT +#[derive(Clone, Copy, Debug, FromBytes, Immutable, KnownLayout)] +#[repr(C, packed)] +struct RawMADTEntryHeader { + entry_type: u8, + entry_len: u8, +} + +/// Entry for a local APIC within MADT +#[derive(Clone, Copy, Debug, FromBytes, Immutable, KnownLayout)] +#[repr(C, packed)] +struct RawMADTEntryLocalApic { + header: RawMADTEntryHeader, + acpi_id: u8, + apic_id: u8, + flags: u32, +} + +/// Entry for a local X2APIC within MADT +#[derive(Clone, Copy, Debug, FromBytes, Immutable, KnownLayout)] +#[repr(C, packed)] +struct RawMADTEntryLocalX2Apic { + header: RawMADTEntryHeader, + reserved: [u8; 2], + apic_id: u32, + flags: u32, + acpi_id: u32, +} + +/// Information about an ACPI CPU +#[derive(Clone, Copy, Debug)] +pub struct ACPICPUInfo { + /// The APIC ID for the CPU + pub apic_id: u32, + /// Indicates whether the CPU is enabled + pub enabled: bool, +} + +/// Loads ACPI CPU information by parsing the ACPI tables. +/// +/// This function retrieves CPU information from the ACPI tables provided by the firmware. +/// It processes the Multiple APIC Description Table (MADT) to extract information about each CPU's +/// APIC ID and enabled status. +/// +/// # Arguments +/// +/// * `fw_cfg`: A reference to the Firmware Configuration (FwCfg) interface for accessing ACPI tables. +/// +/// # Returns +/// +/// A [`Result`] containing a vector of [`ACPICPUInfo`] structs representing CPU information. 
+/// If successful, the vector contains information about each detected CPU; otherwise, an error is returned. +/// +/// # Errors +/// +/// This function returns an error if there are issues with reading or parsing ACPI tables, +/// or if the required ACPI tables are not found. +/// +/// # Example +/// +/// ``` +/// use svsm::acpi::tables::load_acpi_cpu_info; +/// use svsm::fw_cfg::FwCfg; +/// use svsm::io::IOPort; +/// +/// #[derive(Debug)] +/// struct MyIo; +/// +/// impl IOPort for MyIo { +/// // your implementation +/// # fn outb(&self, _port: u16, _value: u8) {} +/// # fn outw(&self, _port: u16, _value: u16) {} +/// # fn inb(&self, _port: u16) -> u8 { 0 } +/// # fn inw(&self, _port: u16) -> u16 { 0 } +/// } +/// +/// let io = MyIo; +/// let fw_cfg = FwCfg::new(&io); +/// match load_acpi_cpu_info(&fw_cfg) { +/// Ok(cpu_info) => { +/// for info in cpu_info { +/// // You can print id (info.apic_id) and whether it is enabled (info.enabled) +/// } +/// } +/// Err(err) => { +/// // Print error +/// } +/// } +/// ``` +pub fn load_acpi_cpu_info(fw_cfg: &FwCfg<'_>) -> Result<Vec<ACPICPUInfo>, SvsmError> { + let buffer = ACPITableBuffer::from_fwcfg(fw_cfg)?; + + let apic_table = buffer.acp_table_by_sig("APIC").ok_or(SvsmError::Acpi)?; + let content = apic_table.content().ok_or(SvsmError::Acpi)?; + + let mut cpus: Vec<ACPICPUInfo> = Vec::new(); + + let mut offset = MADT_HEADER_SIZE; + while offset < content.len() { + let entry_ptr = apic_table + .content_ptr::<RawMADTEntryHeader>(offset) + .ok_or(SvsmError::Acpi)?; + let entry_len = usize::from(entry_ptr.entry_len); + + match entry_ptr.entry_type { + 0 if entry_len == mem::size_of::<RawMADTEntryLocalApic>() => { + let lapic_ptr = apic_table + .content_ptr::<RawMADTEntryLocalApic>(offset) + .ok_or(SvsmError::Acpi)?; + cpus.push(ACPICPUInfo { + apic_id: lapic_ptr.apic_id as u32, + enabled: (lapic_ptr.flags & 1) == 1, + }); + } + 9 if entry_len == mem::size_of::<RawMADTEntryLocalX2Apic>() => { + let x2apic_ptr = apic_table + .content_ptr::<RawMADTEntryLocalX2Apic>(offset) + .ok_or(SvsmError::Acpi)?; + cpus.push(ACPICPUInfo { + apic_id: x2apic_ptr.apic_id, + enabled: (x2apic_ptr.flags & 1) == 1, + }); + } + madt_type if entry_len == 0 => { + log::warn!( + "Found zero-length MADT entry with type {}, stopping", + madt_type + ); + break; + } + madt_type => { + log::info!("Ignoring MADT entry with type {}", madt_type); + } + } + + offset = offset.checked_add(entry_len).ok_or(SvsmError::Acpi)?; + } + + Ok(cpus) +} diff --git a/stage2/src/address.rs b/stage2/src/address.rs new file mode 100644 index 000000000..9936d2950 --- /dev/null +++ b/stage2/src/address.rs @@ -0,0 +1,375 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Carlos López + +use crate::types::{PAGE_SHIFT, PAGE_SIZE}; +use core::fmt; +use core::ops; + +use core::slice; + +// The backing type to represent an address. +type InnerAddr = usize; + +const SIGN_BIT: usize = 47; + +#[inline] +const fn sign_extend(addr: InnerAddr) -> InnerAddr { + let mask = 1usize << SIGN_BIT; + if (addr & mask) == mask { + addr | !((1usize << SIGN_BIT) - 1) + } else { + addr & ((1usize << SIGN_BIT) - 1) + } +} + +pub trait Address: + Copy + From<InnerAddr> + Into<InnerAddr> + PartialEq + Eq + PartialOrd + Ord +{ + /// Transform the address into its inner representation for easier + /// arithmetic manipulation + #[inline] + fn bits(&self) -> InnerAddr { + (*self).into() + } + + #[inline] + fn is_null(&self) -> bool { + self.bits() == 0 + } + + #[inline] + fn align_up(&self, align: InnerAddr) -> Self { + Self::from((self.bits() + (align - 1)) & !(align - 1)) + } + + #[inline] + fn page_align_up(&self) -> Self {
self.align_up(PAGE_SIZE) + } + + #[inline] + fn page_align(&self) -> Self { + Self::from(self.bits() & !(PAGE_SIZE - 1)) + } + + #[inline] + fn is_aligned(&self, align: InnerAddr) -> bool { + (self.bits() & (align - 1)) == 0 + } + + #[inline] + fn is_aligned_to(&self) -> bool { + self.is_aligned(core::mem::align_of::()) + } + + #[inline] + fn is_page_aligned(&self) -> bool { + self.is_aligned(PAGE_SIZE) + } + + #[inline] + fn checked_add(&self, off: InnerAddr) -> Option { + self.bits().checked_add(off).map(|addr| addr.into()) + } + + #[inline] + fn checked_sub(&self, off: InnerAddr) -> Option { + self.bits().checked_sub(off).map(|addr| addr.into()) + } + + #[inline] + fn saturating_add(&self, off: InnerAddr) -> Self { + Self::from(self.bits().saturating_add(off)) + } + + #[inline] + fn page_offset(&self) -> usize { + self.bits() & (PAGE_SIZE - 1) + } + + #[inline] + fn crosses_page(&self, size: usize) -> bool { + let start = self.bits(); + let x1 = start / PAGE_SIZE; + let x2 = (start + size - 1) / PAGE_SIZE; + x1 != x2 + } + + #[inline] + fn pfn(&self) -> InnerAddr { + self.bits() >> PAGE_SHIFT + } +} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct PhysAddr(InnerAddr); + +impl PhysAddr { + #[inline] + pub const fn new(p: InnerAddr) -> Self { + Self(p) + } + + #[inline] + pub const fn null() -> Self { + Self(0) + } +} + +impl fmt::Display for PhysAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl fmt::LowerHex for PhysAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::LowerHex::fmt(&self.0, f) + } +} + +impl From for PhysAddr { + #[inline] + fn from(addr: InnerAddr) -> PhysAddr { + Self(addr) + } +} + +impl From for InnerAddr { + #[inline] + fn from(addr: PhysAddr) -> InnerAddr { + addr.0 + } +} + +impl From for PhysAddr { + #[inline] + fn from(addr: u64) -> PhysAddr { + // The unwrap will get optimized away on 64bit platforms, + // which should be our only target anyway + let addr: usize = addr.try_into().unwrap(); + PhysAddr::from(addr) + } +} + +impl From for u64 { + #[inline] + fn from(addr: PhysAddr) -> u64 { + addr.0 as u64 + } +} + +// Substracting two addresses produces an usize instead of an address, +// since we normally do this to compute the size of a memory region. +impl ops::Sub for PhysAddr { + type Output = InnerAddr; + + #[inline] + fn sub(self, other: PhysAddr) -> Self::Output { + self.0 - other.0 + } +} + +// Adding and subtracting usize to PhysAddr gives a new PhysAddr +impl ops::Sub for PhysAddr { + type Output = Self; + + #[inline] + fn sub(self, other: InnerAddr) -> Self { + PhysAddr::from(self.0 - other) + } +} + +impl ops::Add for PhysAddr { + type Output = Self; + + #[inline] + fn add(self, other: InnerAddr) -> Self { + PhysAddr::from(self.0 + other) + } +} + +impl Address for PhysAddr {} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct VirtAddr(InnerAddr); + +impl VirtAddr { + #[inline] + pub const fn null() -> Self { + Self(0) + } + + // const traits experimental, so for now we need this to make up + // for the lack of VirtAddr::from() in const contexts. + #[inline] + pub const fn new(addr: InnerAddr) -> Self { + Self(sign_extend(addr)) + } + + /// Returns the index into page-table pages of given levels. 
+ pub const fn to_pgtbl_idx(&self) -> usize { + (self.0 >> (12 + L * 9)) & 0x1ffusize + } + + #[inline] + pub fn as_ptr(&self) -> *const T { + self.0 as *const T + } + + #[inline] + pub fn as_mut_ptr(&self) -> *mut T { + self.0 as *mut T + } + + /// Converts the `VirtAddr` to a reference to the given type, checking + /// that the address is not NULL and properly aligned. + /// + /// # Safety + /// + /// All safety requirements for pointers apply, minus alignment and NULL + /// checks, which this function already does. + #[inline] + pub unsafe fn aligned_ref<'a, T>(&self) -> Option<&'a T> { + self.is_aligned_to::() + .then(|| self.as_ptr::().as_ref()) + .flatten() + } + + /// Converts the `VirtAddr` to a reference to the given type, checking + /// that the address is not NULL and properly aligned. + /// + /// # Safety + /// + /// All safety requirements for pointers apply, minus alignment and NULL + /// checks, which this function already does. + #[inline] + pub unsafe fn aligned_mut<'a, T>(&self) -> Option<&'a mut T> { + self.is_aligned_to::() + .then(|| self.as_mut_ptr::().as_mut()) + .flatten() + } + + pub const fn const_add(&self, offset: usize) -> Self { + VirtAddr::new(self.0 + offset) + } + + /// Converts the `VirtAddr` to a slice of a given type + /// + /// # Arguments: + /// + /// * `len` - Number of elements of type `T` in the slice + /// + /// # Returns + /// + /// Slice with `len` elements of type `T` + /// + /// # Safety + /// + /// All Safety requirements from [`core::slice::from_raw_parts`] for the + /// data pointed to by the `VirtAddr` apply here as well. + pub unsafe fn to_slice(&self, len: usize) -> &[T] { + slice::from_raw_parts::(self.as_ptr::(), len) + } +} + +impl fmt::Display for VirtAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl fmt::LowerHex for VirtAddr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::LowerHex::fmt(&self.0, f) + } +} + +impl From for VirtAddr { + #[inline] + fn from(addr: InnerAddr) -> Self { + Self(sign_extend(addr)) + } +} + +impl From for InnerAddr { + #[inline] + fn from(addr: VirtAddr) -> Self { + addr.0 + } +} + +impl From for VirtAddr { + #[inline] + fn from(addr: u64) -> Self { + let addr: usize = addr.try_into().unwrap(); + VirtAddr::from(addr) + } +} + +impl From for u64 { + #[inline] + fn from(addr: VirtAddr) -> Self { + addr.0 as u64 + } +} + +impl From<*const T> for VirtAddr { + #[inline] + fn from(ptr: *const T) -> Self { + Self(ptr as InnerAddr) + } +} + +impl From<*mut T> for VirtAddr { + fn from(ptr: *mut T) -> Self { + Self(ptr as InnerAddr) + } +} + +impl ops::Sub for VirtAddr { + type Output = InnerAddr; + + #[inline] + fn sub(self, other: VirtAddr) -> Self::Output { + sign_extend(self.0 - other.0) + } +} + +impl ops::Sub for VirtAddr { + type Output = Self; + + #[inline] + fn sub(self, other: usize) -> Self { + VirtAddr::from(self.0 - other) + } +} + +impl ops::Add for VirtAddr { + type Output = VirtAddr; + + fn add(self, other: InnerAddr) -> Self { + VirtAddr::from(self.0 + other) + } +} + +impl Address for VirtAddr { + #[inline] + fn checked_add(&self, off: InnerAddr) -> Option { + self.bits() + .checked_add(off) + .map(|addr| sign_extend(addr).into()) + } + + #[inline] + fn checked_sub(&self, off: InnerAddr) -> Option { + self.bits() + .checked_sub(off) + .map(|addr| sign_extend(addr).into()) + } +} diff --git a/stage2/src/boot_stage2.rs b/stage2/src/boot_stage2.rs new file mode 100644 index 000000000..83534e7fe --- /dev/null +++ 
b/stage2/src/boot_stage2.rs @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use core::arch::global_asm; + +use svsm::{ + cpu::{ + efer::EFERFlags, + msr::{EFER, SEV_STATUS}, + }, + mm::PGTABLE_LVL3_IDX_PTE_SELFMAP, +}; + +global_asm!( + r#" + .text + .section ".startup.text","ax" + .code32 + + .org 0 + .globl startup_32 + startup_32: + + /* Save pointer to startup structure in EBP */ + movl %esp, %ebp + /* + * Load a GDT. Despite the naming, it contains valid + * entries for both, "legacy" 32bit and long mode each. + */ + movl $gdt64_desc, %eax + lgdt (%eax) + + movw $0x10, %ax + movw %ax, %ds + movw %ax, %es + movw %ax, %fs + movw %ax, %gs + movw %ax, %ss + + ljmpl $0x8, $.Lon_svsm32_cs + + .Lon_svsm32_cs: + /* + * SEV: %esi is always 0, only BSP running + * TDX: %esi is the TD CPU index + */ + test %esi, %esi + jnz .Lskip_paging_setup + + /* Clear out the static page table pages. */ + movl $pgtable_end, %ecx + subl $pgtable, %ecx + shrl $2, %ecx + xorl %eax, %eax + movl $pgtable, %edi + rep stosl + + /* Determine the C-bit position within PTEs. */ + call get_pte_c_bit + movl %eax, %edx + + /* Populate the static page table pages with an identity mapping. */ + movl $pgtable, %edi + leal 0x1007(%edi), %eax + movl %eax, 0(%edi) + addl %edx, 4(%edi) + + addl $0x1000, %edi + leal 0x1007(%edi), %eax + movl $4, %ecx + 1: movl %eax, 0(%edi) + addl %edx, 4(%edi) + addl $0x1000, %eax + addl $8, %edi + decl %ecx + jnz 1b + andl $0xfffff000, %edi + + addl $0x1000, %edi + movl $0x00000183, %eax + movl $2048, %ecx + 1: movl %eax, 0(%edi) + addl %edx, 4(%edi) + addl $0x00200000, %eax + addl $8, %edi + decl %ecx + jnz 1b + + /* Insert a self-map entry */ + movl $pgtable, %edi + movl %edi, %eax + orl $0x63, %eax + movl %eax, 8*{PGTABLE_LVL3_IDX_PTE_SELFMAP}(%edi) + movl $0x80000000, %eax + orl %edx, %eax + movl %eax, 0xF6C(%edi) + + /* Signal APs */ + movl $setup_flag, %edi + movl $1, (%edi) + jmp 2f + +.Lskip_paging_setup: + movl $setup_flag, %edi +.Lap_wait: + movl (%edi), %eax + test %eax, %eax + jz .Lap_wait + +2: + /* Enable 64bit PTEs, CR4.PAE. */ + movl %cr4, %eax + bts $5, %eax + movl %eax, %cr4 + + /* Enable long mode, EFER.LME. Also ensure NXE is set. */ + movl ${EFER}, %ecx + rdmsr + movl %eax, %ebx + orl $({LME} | {NXE}), %eax + cmp %eax, %ebx + jz 2f + wrmsr + 2: + + /* Load the static page table root. */ + movl $pgtable, %eax + movl %eax, %cr3 + + /* Enable paging, CR0.PG. */ + movl %cr0, %eax + bts $31, %eax + movl %eax, %cr0 + + ljmpl $0x18, $startup_64 + + get_pte_c_bit: + /* + * Check if this is an SNP platform. If not, there is no C bit. + */ + cmpl $1, 8(%ebp) + jnz .Lvtom + + /* + * Check that the SNP_Active bit in the SEV_STATUS MSR is set. + */ + movl ${SEV_STATUS}, %ecx + rdmsr + + testl $0x04, %eax + jz .Lno_sev_snp + + /* + * Check whether VTOM is selected + */ + testl $0x08, %eax + jnz .Lvtom + + /* Determine the PTE C-bit position from the CPUID page. */ + + /* Locate the table. The pointer to the CPUID page is 12 bytes into + * the stage2 startup structure. */ + movl 12(%ebp), %ecx + /* Read the number of entries. */ + movl (%ecx), %eax + /* Create a pointer to the first entry. */ + leal 16(%ecx), %ecx + + .Lcheck_entry: + /* Check that there is another entry. */ + test %eax, %eax + je .Lno_sev_snp + + /* Check the input parameters of the current entry. 
*/ + cmpl $0x8000001f, (%ecx) /* EAX_IN */ + jne .Lwrong_entry + cmpl $0, 4(%ecx) /* ECX_IN */ + jne .Lwrong_entry + cmpl $0, 8(%ecx) /* XCR0_IN (lower half) */ + jne .Lwrong_entry + cmpl $0, 12(%ecx) /* XCR0_IN (upper half) */ + jne .Lwrong_entry + cmpl $0, 16(%ecx) /* XSS_IN (lower half) */ + jne .Lwrong_entry + cmpl $0, 20(%ecx) /* XSS_IN (upper half) */ + jne .Lwrong_entry + + /* All parameters were correct. */ + jmp .Lfound_entry + + .Lwrong_entry: + /* + * The current entry doesn't contain the correct input + * parameters. Try the next one. + */ + decl %eax + addl $0x30, %ecx + jmp .Lcheck_entry + + .Lfound_entry: + /* Extract the c-bit location from the cpuid entry. */ + movl 28(%ecx), %ebx + andl $0x3f, %ebx + + /* + * Verify that the C-bit position is within reasonable bounds: + * >= 32 and < 64. + */ + cmpl $32, %ebx + jl .Lno_sev_snp + cmpl $64, %ebx + jae .Lno_sev_snp + + subl $32, %ebx + xorl %eax, %eax + btsl %ebx, %eax + ret + + .Lvtom: + xorl %eax, %eax + ret + + .Lno_sev_snp: + hlt + jmp .Lno_sev_snp + + .code64 + + startup_64: + /* Reload the data segments with 64bit descriptors. */ + movw $0x20, %ax + movw %ax, %ds + movw %ax, %es + movw %ax, %fs + movw %ax, %gs + movw %ax, %ss + + test %esi, %esi + jz .Lbsp_main + + .Lcheck_command: + /* TODO */ + jmp .Lcheck_command + + .Lbsp_main: + /* Clear out .bss and transfer control to the main stage2 code. */ + xorq %rax, %rax + leaq _bss(%rip), %rdi + leaq _ebss(%rip), %rcx + subq %rdi, %rcx + shrq $3, %rcx + rep stosq + + movl %ebp, %edi + jmp stage2_main + + .data + + .align 4 + setup_flag: + .long 0 + + idt32: + .rept 32 + .quad 0 + .endr + idt32_end: + + idt32_desc: + .word idt32_end - idt32 - 1 + .long idt32 + + idt64: + .rept 32 + .octa 0 + .endr + idt64_end: + + idt64_desc: + .word idt64_end - idt64 - 1 + .quad idt64 + + .align 256 + gdt64: + .quad 0 + .quad 0x00cf9a000000ffff /* 32 bit code segment */ + .quad 0x00cf93000000ffff /* 32 bit data segment */ + .quad 0x00af9a000000ffff /* 64 bit code segment */ + .quad 0x00cf92000000ffff /* 64 bit data segment */ + gdt64_end: + + gdt64_desc: + .word gdt64_end - gdt64 - 1 + .quad gdt64 + + .align 4096 + .globl pgtable + pgtable: + .fill 7 * 4096, 1, 0 + pgtable_end:"#, + PGTABLE_LVL3_IDX_PTE_SELFMAP = const PGTABLE_LVL3_IDX_PTE_SELFMAP, + EFER = const EFER, + LME = const EFERFlags::LME.bits(), + NXE = const EFERFlags::NXE.bits(), + SEV_STATUS = const SEV_STATUS, + options(att_syntax) +); diff --git a/stage2/src/config.rs b/stage2/src/config.rs new file mode 100644 index 000000000..f7eda06e2 --- /dev/null +++ b/stage2/src/config.rs @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange (jlange@microsoft.com) + +extern crate alloc; + +use core::slice; + +use crate::acpi::tables::{load_acpi_cpu_info, ACPICPUInfo}; +use crate::address::PhysAddr; +use crate::error::SvsmError; +use crate::fw_cfg::FwCfg; +use crate::fw_meta::{parse_fw_meta_data, SevFWMetaData}; +use crate::igvm_params::IgvmParams; +use crate::mm::{PerCPUPageMappingGuard, PAGE_SIZE, SIZE_1G}; +use crate::serial::SERIAL_PORT; +use crate::utils::MemoryRegion; +use alloc::vec::Vec; +use cpuarch::vmsa::VMSA; + +fn check_ovmf_regions( + flash_regions: &[MemoryRegion], + kernel_region: &MemoryRegion, +) { + let flash_range = { + let one_gib = 1024 * 1024 * 1024usize; + let start = PhysAddr::from(3 * one_gib); + MemoryRegion::new(start, one_gib) + }; + + // Sanity-check flash regions. 
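+ // Overview (descriptive note): the checks below require every reported flash region to
+ // intersect the 3GiB-4GiB window, to stay clear of the kernel region, and to not overlap
+ // any other flash region; additionally, one region is expected to end exactly at 4GiB.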
+ for region in flash_regions.iter() { + // Make sure that the regions are between 3GiB and 4GiB. + if !region.overlap(&flash_range) { + panic!("flash region in unexpected region"); + } + + // Make sure that no regions overlap with the kernel. + if region.overlap(kernel_region) { + panic!("flash region overlaps with kernel"); + } + } + + // Make sure that regions don't overlap. + for (i, outer) in flash_regions.iter().enumerate() { + for inner in flash_regions[..i].iter() { + if outer.overlap(inner) { + panic!("flash regions overlap"); + } + } + // Make sure that one regions ends at 4GiB. + let one_region_ends_at_4gib = flash_regions + .iter() + .any(|region| region.end() == flash_range.end()); + assert!(one_region_ends_at_4gib); + } +} + +#[derive(Debug)] +pub enum SvsmConfig<'a> { + FirmwareConfig(FwCfg<'a>), + IgvmConfig(IgvmParams<'a>), +} + +impl SvsmConfig<'_> { + pub fn find_kernel_region(&self) -> Result, SvsmError> { + match self { + SvsmConfig::FirmwareConfig(fw_cfg) => fw_cfg.find_kernel_region(), + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.find_kernel_region(), + } + } + pub fn page_state_change_required(&self) -> bool { + match self { + SvsmConfig::FirmwareConfig(_) => true, + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.page_state_change_required(), + } + } + pub fn get_memory_regions(&self) -> Result>, SvsmError> { + match self { + SvsmConfig::FirmwareConfig(fw_cfg) => fw_cfg.get_memory_regions(), + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.get_memory_regions(), + } + } + pub fn write_guest_memory_map(&self, map: &[MemoryRegion]) -> Result<(), SvsmError> { + match self { + SvsmConfig::FirmwareConfig(_) => Ok(()), + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.write_guest_memory_map(map), + } + } + pub fn reserved_kernel_area_size(&self) -> usize { + match self { + SvsmConfig::FirmwareConfig(_) => 0, + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.reserved_kernel_area_size(), + } + } + pub fn load_cpu_info(&self) -> Result, SvsmError> { + match self { + SvsmConfig::FirmwareConfig(fw_cfg) => load_acpi_cpu_info(fw_cfg), + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.load_cpu_info(), + } + } + pub fn should_launch_fw(&self) -> bool { + match self { + SvsmConfig::FirmwareConfig(_) => true, + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.should_launch_fw(), + } + } + + pub fn debug_serial_port(&self) -> u16 { + match self { + SvsmConfig::FirmwareConfig(_) => SERIAL_PORT, + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.debug_serial_port(), + } + } + + pub fn get_fw_metadata(&self) -> Option { + match self { + SvsmConfig::FirmwareConfig(_) => { + // Map the metadata location which is defined by the firmware config + let guard = + PerCPUPageMappingGuard::create_4k(PhysAddr::from(4 * SIZE_1G - PAGE_SIZE)) + .expect("Failed to map FW metadata page"); + let vstart = guard.virt_addr().as_ptr::(); + // Safety: we just mapped a page, so the size must hold. The type + // of the slice elements is `u8` so there are no alignment requirements. 
+ let metadata = unsafe { slice::from_raw_parts(vstart, PAGE_SIZE) }; + Some(parse_fw_meta_data(metadata).expect("Failed to parse FW SEV meta-data")) + } + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.get_fw_metadata(), + } + } + + pub fn get_fw_regions( + &self, + kernel_region: &MemoryRegion, + ) -> Vec> { + match self { + SvsmConfig::FirmwareConfig(fw_cfg) => { + let flash_regions = fw_cfg.iter_flash_regions().collect::>(); + check_ovmf_regions(&flash_regions, kernel_region); + flash_regions + } + SvsmConfig::IgvmConfig(igvm_params) => { + let flash_regions = igvm_params.get_fw_regions(); + if !igvm_params.fw_in_low_memory() { + check_ovmf_regions(&flash_regions, kernel_region); + } + flash_regions + } + } + } + + pub fn fw_in_low_memory(&self) -> bool { + match self { + SvsmConfig::FirmwareConfig(_) => false, + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.fw_in_low_memory(), + } + } + + pub fn invalidate_boot_data(&self) -> bool { + match self { + SvsmConfig::FirmwareConfig(_) => false, + SvsmConfig::IgvmConfig(_) => true, + } + } + + pub fn initialize_guest_vmsa(&self, vmsa: &mut VMSA) -> Result<(), SvsmError> { + match self { + SvsmConfig::FirmwareConfig(_) => Ok(()), + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.initialize_guest_vmsa(vmsa), + } + } + + pub fn use_alternate_injection(&self) -> bool { + match self { + SvsmConfig::FirmwareConfig(_) => false, + SvsmConfig::IgvmConfig(igvm_params) => igvm_params.use_alternate_injection(), + } + } +} diff --git a/stage2/src/console.rs b/stage2/src/console.rs new file mode 100644 index 000000000..92d17f859 --- /dev/null +++ b/stage2/src/console.rs @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::error::SvsmError; +use crate::io::IOPort; +use crate::locking::SpinLock; +use crate::serial::{SerialPort, Terminal, DEFAULT_SERIAL_PORT}; +use crate::utils::immut_after_init::{ImmutAfterInitCell, ImmutAfterInitResult}; +use core::fmt; + +#[derive(Clone, Copy, Debug)] +struct Console { + writer: &'static dyn Terminal, +} + +impl fmt::Write for Console { + fn write_str(&mut self, s: &str) -> fmt::Result { + for ch in s.bytes() { + self.writer.put_byte(ch); + } + Ok(()) + } +} + +static WRITER: SpinLock = SpinLock::new(Console { + writer: &DEFAULT_SERIAL_PORT, +}); +static CONSOLE_INITIALIZED: ImmutAfterInitCell = ImmutAfterInitCell::new(false); +static CONSOLE_SERIAL: ImmutAfterInitCell> = ImmutAfterInitCell::uninit(); + +fn init_console(writer: &'static dyn Terminal) -> ImmutAfterInitResult<()> { + WRITER.lock().writer = writer; + CONSOLE_INITIALIZED.reinit(&true)?; + log::info!("COCONUT Secure Virtual Machine Service Module"); + Ok(()) +} + +pub fn init_svsm_console(writer: &'static dyn IOPort, port: u16) -> Result<(), SvsmError> { + CONSOLE_SERIAL + .init(&SerialPort::new(writer, port)) + .map_err(|_| SvsmError::Console)?; + (*CONSOLE_SERIAL).init(); + init_console(&*CONSOLE_SERIAL).map_err(|_| SvsmError::Console) +} + +#[doc(hidden)] +pub fn _print(args: fmt::Arguments<'_>) { + use core::fmt::Write; + if !*CONSOLE_INITIALIZED { + return; + } + WRITER.lock().write_fmt(args).unwrap(); +} + +#[derive(Clone, Copy, Debug)] +struct ConsoleLogger { + name: &'static str, +} + +impl ConsoleLogger { + const fn new(name: &'static str) -> Self { + Self { name } + } +} + +impl log::Log for ConsoleLogger { + fn enabled(&self, _metadata: &log::Metadata<'_>) -> bool { + true + } + + fn log(&self, record: &log::Record<'_>) { + if 
!self.enabled(record.metadata()) { + return; + } + + // The logger being uninitialized is impossible, as that would mean it + // wouldn't have been registered with the log library. + // Log format/detail depends on the level. + match record.metadata().level() { + log::Level::Error | log::Level::Warn => { + _print(format_args!( + "[{}] {}: {}\n", + self.name, + record.metadata().level().as_str(), + record.args() + )); + } + + log::Level::Info => { + _print(format_args!("[{}] {}\n", self.name, record.args())); + } + + log::Level::Debug | log::Level::Trace => { + _print(format_args!( + "[{}/{}] {} {}\n", + self.name, + record.metadata().target(), + record.metadata().level().as_str(), + record.args() + )); + } + }; + } + + fn flush(&self) {} +} + +static CONSOLE_LOGGER: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +pub fn install_console_logger(component: &'static str) -> ImmutAfterInitResult<()> { + CONSOLE_LOGGER.init(&ConsoleLogger::new(component))?; + + if let Err(e) = log::set_logger(&*CONSOLE_LOGGER) { + // Failed to install the ConsoleLogger, presumably because something had + // installed another logger before. No logs will appear at the console. + // Print an error string. + _print(format_args!( + "[{}]: ERROR: failed to install console logger: {:?}", + component, e, + )); + } + + // Log levels are to be configured via the log's library feature configuration. + log::set_max_level(log::LevelFilter::Trace); + Ok(()) +} + +#[macro_export] +macro_rules! println { + () => (log::info!("")); + ($($arg:tt)*) => (log::info!($($arg)*)); +} diff --git a/stage2/src/cpu/apic.rs b/stage2/src/cpu/apic.rs new file mode 100644 index 000000000..7d65e7e4a --- /dev/null +++ b/stage2/src/cpu/apic.rs @@ -0,0 +1,869 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange (jlange@microsoft.com) + +use crate::address::VirtAddr; +use crate::cpu::idt::common::INT_INJ_VECTOR; +use crate::cpu::percpu::{current_ghcb, this_cpu, PerCpuShared, PERCPU_AREAS}; +use crate::error::ApicError::Emulation; +use crate::error::SvsmError; +use crate::mm::GuestPtr; +use crate::platform::guest_cpu::GuestCpuState; +use crate::platform::SVSM_PLATFORM; +use crate::requests::SvsmCaa; +use crate::sev::hv_doorbell::HVExtIntStatus; +use crate::types::GUEST_VMPL; + +use bitfield_struct::bitfield; +use core::sync::atomic::Ordering; + +const APIC_REGISTER_APIC_ID: u64 = 0x802; +const APIC_REGISTER_TPR: u64 = 0x808; +const APIC_REGISTER_PPR: u64 = 0x80A; +const APIC_REGISTER_EOI: u64 = 0x80B; +const APIC_REGISTER_ISR_0: u64 = 0x810; +const APIC_REGISTER_ISR_7: u64 = 0x817; +const APIC_REGISTER_TMR_0: u64 = 0x818; +const APIC_REGISTER_TMR_7: u64 = 0x81F; +const APIC_REGISTER_IRR_0: u64 = 0x820; +const APIC_REGISTER_IRR_7: u64 = 0x827; +const APIC_REGISTER_ICR: u64 = 0x830; +const APIC_REGISTER_SELF_IPI: u64 = 0x83F; + +#[derive(Debug, PartialEq)] +enum IcrDestFmt { + Dest = 0, + OnlySelf = 1, + AllWithSelf = 2, + AllButSelf = 3, +} + +impl IcrDestFmt { + const fn into_bits(self) -> u64 { + self as _ + } + const fn from_bits(value: u64) -> Self { + match value { + 3 => Self::AllButSelf, + 2 => Self::AllWithSelf, + 1 => Self::OnlySelf, + _ => Self::Dest, + } + } +} + +#[derive(Debug, PartialEq)] +enum IcrMessageType { + Fixed = 0, + Unknown = 3, + Nmi = 4, + Init = 5, + Sipi = 6, + ExtInt = 7, +} + +impl IcrMessageType { + const fn into_bits(self) -> u64 { + self as _ + } + const fn from_bits(value: u64) -> Self { + match value { + 7 => Self::ExtInt, + 6 => Self::Sipi, + 5 
=> Self::Init, + 4 => Self::Nmi, + 0 => Self::Fixed, + _ => Self::Unknown, + } + } +} + +#[bitfield(u64)] +struct ApicIcr { + pub vector: u8, + #[bits(3)] + pub message_type: IcrMessageType, + pub destination_mode: bool, + pub delivery_status: bool, + rsvd_13: bool, + pub assert: bool, + pub trigger_mode: bool, + #[bits(2)] + pub remote_read_status: usize, + #[bits(2)] + pub destination_shorthand: IcrDestFmt, + #[bits(12)] + rsvd_31_20: u64, + pub destination: u32, +} + +// This structure must never be copied because a silent copy will cause APIC +// state to be lost. +#[expect(missing_copy_implementations)] +#[derive(Default, Debug)] +pub struct LocalApic { + irr: [u32; 8], + allowed_irr: [u32; 8], + isr_stack_index: usize, + isr_stack: [u8; 16], + tmr: [u32; 8], + host_tmr: [u32; 8], + update_required: bool, + interrupt_delivered: bool, + interrupt_queued: bool, + lazy_eoi_pending: bool, + nmi_pending: bool, +} + +impl LocalApic { + pub const fn new() -> Self { + Self { + irr: [0; 8], + allowed_irr: [0; 8], + isr_stack_index: 0, + isr_stack: [0; 16], + tmr: [0; 8], + host_tmr: [0; 8], + update_required: false, + interrupt_delivered: false, + interrupt_queued: false, + lazy_eoi_pending: false, + nmi_pending: false, + } + } + + fn scan_irr(&self) -> u8 { + // Scan to find the highest pending IRR vector. + for (i, irr) in self.irr.into_iter().enumerate().rev() { + if irr != 0 { + let bit_index = 31 - irr.leading_zeros(); + let vector = (i as u32) * 32 + bit_index; + return vector.try_into().unwrap(); + } + } + 0 + } + + fn remove_vector_register(register: &mut [u32; 8], irq: u8) { + register[irq as usize >> 5] &= !(1 << (irq & 31)); + } + + fn insert_vector_register(register: &mut [u32; 8], irq: u8) { + register[irq as usize >> 5] |= 1 << (irq & 31); + } + + fn test_vector_register(register: &[u32; 8], irq: u8) -> bool { + (register[irq as usize >> 5] & 1 << (irq & 31)) != 0 + } + + fn rewind_pending_interrupt(&mut self, irq: u8) { + let new_index = self.isr_stack_index.checked_sub(1).unwrap(); + assert!(self.isr_stack.get(new_index) == Some(&irq)); + Self::insert_vector_register(&mut self.irr, irq); + self.isr_stack_index = new_index; + self.update_required = true; + } + + pub fn check_delivered_interrupts( + &mut self, + cpu_state: &mut T, + caa_addr: Option, + ) { + // Check to see if a previously delivered interrupt is still pending. + // If so, move it back to the IRR. + if self.interrupt_delivered { + let irq = cpu_state.check_and_clear_pending_interrupt_event(); + if irq != 0 { + self.rewind_pending_interrupt(irq); + self.lazy_eoi_pending = false; + } + self.interrupt_delivered = false; + } + + // Check to see if a previously queued interrupt is still pending. + // If so, move it back to the IRR. + if self.interrupt_queued { + let irq = cpu_state.check_and_clear_pending_virtual_interrupt(); + if irq != 0 { + self.rewind_pending_interrupt(irq); + self.lazy_eoi_pending = false; + } + self.interrupt_queued = false; + } + + // If a lazy EOI is pending, then check to see whether an EOI has been + // requested by the guest. Note that if a lazy EOI was dismissed + // above, the guest lazy EOI flag need not be cleared here, since + // dismissal of any interrupt above will require reprocessing of + // interrupt state prior to guest reentry, and that reprocessing will + // reset the guest lazy EOI flag. 
+ if self.lazy_eoi_pending { + if let Some(virt_addr) = caa_addr { + let calling_area = GuestPtr::::new(virt_addr); + // SAFETY: guest vmsa and ca are always validated before beeing updated + // (core_remap_ca(), core_create_vcpu() or prepare_fw_launch()) + // so they're safe to use. + if let Ok(caa) = unsafe { calling_area.read() } { + if caa.no_eoi_required == 0 { + assert!(self.isr_stack_index != 0); + self.perform_eoi(); + } + } + } + } + } + + fn get_ppr_with_tpr(&self, tpr: u8) -> u8 { + // Determine the priority of the current in-service interrupt, if any. + let ppr = if self.isr_stack_index != 0 { + self.isr_stack[self.isr_stack_index] + } else { + 0 + }; + + // The PPR is the higher of the in-service interrupt priority and the + // task priority. + if (ppr >> 4) > (tpr >> 4) { + ppr + } else { + tpr + } + } + + fn get_ppr(&self, cpu_state: &T) -> u8 { + self.get_ppr_with_tpr(cpu_state.get_tpr()) + } + + fn clear_guest_eoi_pending(caa_addr: Option) -> Option> { + let virt_addr = caa_addr?; + let calling_area = GuestPtr::::new(virt_addr); + // Ignore errors here, since nothing can be done if an error occurs. + // SAFETY: guest vmsa and ca are always validated before beeing updated + // (core_remap_ca(), core_create_vcpu() or prepare_fw_launch()) so + // they're safe to use. + if let Ok(caa) = unsafe { calling_area.read() } { + let _ = unsafe { calling_area.write(caa.update_no_eoi_required(0)) }; + } + Some(calling_area) + } + + /// Attempts to deliver the specified IRQ into the specified guest CPU + /// so that it will be immediately observed upon guest entry. + /// Returns `true` if the interrupt request was delivered, or `false` + /// if the guest cannot immediately receive an interrupt. + fn deliver_interrupt_immediately(&self, irq: u8, cpu_state: &mut T) -> bool { + if !cpu_state.interrupts_enabled() || cpu_state.in_intr_shadow() { + false + } else { + // This interrupt can only be delivered if it is a higher priority + // than the processor's current priority. + let ppr = self.get_ppr(cpu_state); + if (irq >> 4) <= (ppr >> 4) { + false + } else { + cpu_state.try_deliver_interrupt_immediately(irq) + } + } + } + + fn consume_pending_ipis(&mut self, cpu_shared: &PerCpuShared) { + // Scan the IPI IRR vector and transfer any pending IPIs into the local + // IRR vector. + for (i, irr) in self.irr.iter_mut().enumerate() { + *irr |= cpu_shared.ipi_irr_vector(i); + } + if cpu_shared.nmi_pending() { + self.nmi_pending = true; + } + self.update_required = true; + } + + pub fn present_interrupts( + &mut self, + cpu_shared: &PerCpuShared, + cpu_state: &mut T, + caa_addr: Option, + ) { + // Make sure any interrupts being presented by the host have been + // consumed. + self.consume_host_interrupts(); + + // Consume any pending IPIs. + if cpu_shared.ipi_pending() { + self.consume_pending_ipis(cpu_shared); + } + + if self.update_required { + // Make sure that all previously delivered interrupts have been + // processed before attempting to process any more. + self.check_delivered_interrupts(cpu_state, caa_addr); + self.update_required = false; + + // If an NMI is pending, then present it first. + if self.nmi_pending { + cpu_state.request_nmi(); + self.nmi_pending = false; + } + + let irq = self.scan_irr(); + let current_priority = if self.isr_stack_index != 0 { + self.isr_stack[self.isr_stack_index - 1] + } else { + 0 + }; + + // Assume no lazy EOI can be attempted unless it is recalculated + // below. 
+ self.lazy_eoi_pending = false; + let guest_caa = Self::clear_guest_eoi_pending(caa_addr); + + // This interrupt is a candidate for delivery only if its priority + // exceeds the priority of the highest priority interrupt currently + // in service. This check does not consider TPR, because an + // interrupt lower in priority than TPR must be queued for delivery + // as soon as TPR is lowered. + if (irq & 0xF0) <= (current_priority & 0xF0) { + return; + } + + // Determine whether this interrupt can be injected + // immediately. If not, queue it for delivery when possible. + let try_lazy_eoi = if self.deliver_interrupt_immediately(irq, cpu_state) { + self.interrupt_delivered = true; + + // Use of lazy EOI can safely be attempted, because the + // highest priority interrupt in service is unambiguous. + true + } else { + cpu_state.queue_interrupt(irq); + self.interrupt_queued = true; + + // A lazy EOI can only be attempted if there is no lower + // priority interrupt in service. If a lower priority + // interrupt is in service, then the lazy EOI handler + // won't know whether the lazy EOI is for the one that + // is already in service or the one that is being queued + // here. + self.isr_stack_index == 0 + }; + + // Mark this interrupt in-service. It will be recalled if + // the ISR is examined again before the interrupt is actually + // delivered. + Self::remove_vector_register(&mut self.irr, irq); + self.isr_stack[self.isr_stack_index] = irq; + self.isr_stack_index += 1; + + // Configure a lazy EOI if possible. Lazy EOI is not possible + // for level-sensitive interrupts, because an explicit EOI + // is required to acknowledge the interrupt at the source. + if try_lazy_eoi && !Self::test_vector_register(&self.tmr, irq) { + // A lazy EOI is possible only if there is no other + // interrupt pending. If another interrupt is pending, + // then an explicit EOI will be required to prompt + // delivery of the next interrupt. + if self.scan_irr() == 0 { + if let Some(calling_area) = guest_caa { + // SAFETY: guest vmsa and ca are always validated before beeing upated + // (core_remap_ca(), core_create_vcpu() or prepare_fw_launch()) + // so they're safe to use. + if let Ok(caa) = unsafe { calling_area.read() } { + if unsafe { calling_area.write(caa.update_no_eoi_required(1)).is_ok() } + { + // Only track a pending lazy EOI if the + // calling area page could successfully be + // updated. + self.lazy_eoi_pending = true; + } + } + } + } + } + } + } + + fn perform_host_eoi(vector: u8) { + // Errors from the host are not expected and cannot be meaningfully + // handled, so simply ignore them. + let _r = current_ghcb().specific_eoi(vector, GUEST_VMPL.try_into().unwrap()); + assert!(_r.is_ok()); + } + + fn perform_eoi(&mut self) { + // Pop any in-service interrupt from the stack. If there is no + // interrupt in service, then there is nothing to do. + if self.isr_stack_index == 0 { + return; + } + + self.isr_stack_index -= 1; + let vector = self.isr_stack[self.isr_stack_index]; + if Self::test_vector_register(&self.tmr, vector) { + if Self::test_vector_register(&self.host_tmr, vector) { + Self::perform_host_eoi(vector); + Self::remove_vector_register(&mut self.host_tmr, vector); + } else { + // FIXME: should do something with locally generated + // level-sensitive interrupts. + } + Self::remove_vector_register(&mut self.tmr, vector); + } + + // Schedule the APIC for reevaluation so any additional pending + // interrupt can be processed. 
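+ // Note: `lazy_eoi_pending` is cleared here as well so that a stale lazy-EOI hint is not
+ // honored after the ISR stack has just changed; present_interrupts() re-arms it when a
+ // new candidate interrupt qualifies.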
+ self.update_required = true; + self.lazy_eoi_pending = false; + } + + fn get_isr(&self, index: usize) -> u32 { + let mut value = 0; + for isr in self.isr_stack.into_iter().take(self.isr_stack_index) { + if (usize::from(isr >> 5)) == index { + value |= 1 << (isr & 0x1F) + } + } + value + } + + fn post_interrupt(&mut self, irq: u8, level_sensitive: bool) { + // Set the appropriate bit in the IRR. Once set, signal that interrupt + // processing is required before returning to the guest. + Self::insert_vector_register(&mut self.irr, irq); + if level_sensitive { + Self::insert_vector_register(&mut self.tmr, irq); + } + self.update_required = true; + } + + fn post_icr_interrupt(&mut self, icr: ApicIcr) { + if icr.message_type() == IcrMessageType::Nmi { + self.nmi_pending = true; + self.update_required = true; + } else { + self.post_interrupt(icr.vector(), false); + } + } + + fn post_ipi_one_target(cpu: &PerCpuShared, icr: ApicIcr) { + if icr.message_type() == IcrMessageType::Nmi { + cpu.request_nmi(); + } else { + cpu.request_ipi(icr.vector()); + } + } + + /// Sends an IPI using the APIC logical destination mode. Returns `true` if + /// the host needs to be notified. + fn send_logical_ipi(&mut self, icr: ApicIcr) -> bool { + let mut signal = false; + + // Check whether the current CPU matches the destination. + let destination = icr.destination(); + let apic_id = this_cpu().get_apic_id(); + if Self::logical_destination_match(destination, apic_id) { + self.post_icr_interrupt(icr); + } + + // Enumerate all CPUs to see which have APIC IDs that match the + // requested destination. Skip the current CPU, since it was checked + // above. + for cpu_ref in PERCPU_AREAS.iter() { + let cpu = cpu_ref.as_cpu_ref(); + let this_apic_id = cpu.apic_id(); + if (this_apic_id != apic_id) + && Self::logical_destination_match(destination, this_apic_id) + { + Self::post_ipi_one_target(cpu, icr); + signal = true; + } + } + + signal + } + + /// Returns `true` if the specified APIC ID matches the given logical destination. + fn logical_destination_match(destination: u32, apic_id: u32) -> bool { + // CHeck for a cluster match. + if (destination >> 16) != (apic_id >> 4) { + false + } else { + let bit = 1u32 << (apic_id & 0xF); + (destination & bit) != 0 + } + } + + /// Send an IPI using the APIC physical destination mode. Returns `true` if + /// the host needs to be notified. + fn send_physical_ipi(&mut self, icr: ApicIcr) -> bool { + // If the target APIC ID matches the current processor, then treat this + // as a self-IPI. Otherwise, locate the target processor by APIC ID. + let destination = icr.destination(); + if destination == this_cpu().get_apic_id() { + self.post_interrupt(icr.vector(), false); + false + } else { + // If the target CPU cannot be located, then simply drop the + // request. + if let Some(cpu) = PERCPU_AREAS.get(destination) { + cpu.request_ipi(icr.vector()); + true + } else { + false + } + } + } + + /// Sends an IPI using the specified ICR. + fn send_ipi(&mut self, icr: ApicIcr) { + let (signal_host, include_others, include_self) = match icr.destination_shorthand() { + IcrDestFmt::Dest => { + if icr.destination() == 0xFFFF_FFFF { + // This is a broadcast, so treat it as all with self. + (true, true, true) + } else { + let signal_host = if icr.destination_mode() { + self.send_logical_ipi(icr) + } else { + self.send_physical_ipi(icr) + }; + + // Any possible self-IPI was handled above as part of + // delivery to the correct destination. 
+ (signal_host, false, false) + } + } + IcrDestFmt::OnlySelf => (false, false, true), + IcrDestFmt::AllButSelf => (true, true, false), + IcrDestFmt::AllWithSelf => (true, true, true), + }; + + if include_others { + // Enumerate all processors in the system except for the + // current CPU and indicate that an IPI has been requested. + let apic_id = this_cpu().get_apic_id(); + for cpu_ref in PERCPU_AREAS.iter() { + let cpu = cpu_ref.as_cpu_ref(); + if cpu.apic_id() != apic_id { + Self::post_ipi_one_target(cpu, icr); + } + } + } + + if include_self { + self.post_icr_interrupt(icr); + } + + if signal_host { + // Calculate an ICR value to use for a host IPI request. This will + // be a fixed interrupt on the interrupt notification vector using + // the destination format specified in the ICR value. + let mut hv_icr = ApicIcr::new() + .with_vector(INT_INJ_VECTOR as u8) + .with_message_type(IcrMessageType::Fixed) + .with_destination_mode(icr.destination_mode()) + .with_destination_shorthand(icr.destination_shorthand()) + .with_destination(icr.destination()); + + // Avoid a self interrupt if the target is all-including-self, + // because the self IPI was delivered above. In the case of + // a logical cluster IPI, it is impractical to avoid the self + // interrupt, but such cases should be rare. + if hv_icr.destination_shorthand() == IcrDestFmt::AllWithSelf { + hv_icr.set_destination_shorthand(IcrDestFmt::AllButSelf); + } + + SVSM_PLATFORM.post_irq(hv_icr.into()).unwrap(); + } + } + + /// Reads an APIC register, returning its value, or an error if an invalid + /// register is requested. + pub fn read_register( + &mut self, + cpu_shared: &PerCpuShared, + cpu_state: &mut T, + caa_addr: Option, + register: u64, + ) -> Result { + // Rewind any undelivered interrupt so it is reflected in any register + // read. + self.check_delivered_interrupts(cpu_state, caa_addr); + + match register { + APIC_REGISTER_APIC_ID => Ok(u64::from(cpu_shared.apic_id())), + APIC_REGISTER_IRR_0..=APIC_REGISTER_IRR_7 => { + let offset = register - APIC_REGISTER_IRR_0; + let index: usize = offset.try_into().unwrap(); + Ok(self.irr[index] as u64) + } + APIC_REGISTER_ISR_0..=APIC_REGISTER_ISR_7 => { + let offset = register - APIC_REGISTER_ISR_0; + Ok(self.get_isr(offset.try_into().unwrap()) as u64) + } + APIC_REGISTER_TMR_0..=APIC_REGISTER_TMR_7 => { + let offset = register - APIC_REGISTER_TMR_0; + let index: usize = offset.try_into().unwrap(); + Ok(self.tmr[index] as u64) + } + APIC_REGISTER_TPR => Ok(cpu_state.get_tpr() as u64), + APIC_REGISTER_PPR => Ok(self.get_ppr(cpu_state) as u64), + _ => Err(SvsmError::Apic(Emulation)), + } + } + + fn handle_icr_write(&mut self, value: u64) -> Result<(), SvsmError> { + let icr = ApicIcr::from(value); + + // Verify that this message type is supported. + let valid_type = match icr.message_type() { + IcrMessageType::Fixed => { + // Only asserted edge-triggered interrupts can be handled. + !icr.trigger_mode() && icr.assert() + } + IcrMessageType::Nmi => true, + _ => false, + }; + + if !valid_type { + return Err(SvsmError::Apic(Emulation)); + } + + self.send_ipi(icr); + + Ok(()) + } + + /// Writes a value to the specified APIC register. Returns an error if an + /// invalid register or value is specified. + pub fn write_register( + &mut self, + cpu_state: &mut T, + caa_addr: Option, + register: u64, + value: u64, + ) -> Result<(), SvsmError> { + // Rewind any undelivered interrupt so it is correctly processed by + // any register write. 
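As an aside (illustration only, not part of the patch), the ICR validation in handle_icr_write() above accepts NMI requests without the trigger-mode and assert checks required for fixed interrupts. Assuming the u64 conversions generated for the bitfield type, a guest request for an NMI IPI to the CPU with APIC ID 3 could be modeled roughly as:

let icr_value: u64 = ApicIcr::new()
    .with_message_type(IcrMessageType::Nmi)
    .with_destination(3)
    .into();
// handle_icr_write(icr_value) would accept this and route it via send_ipi().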
+ self.check_delivered_interrupts(cpu_state, caa_addr); + + match register { + APIC_REGISTER_TPR => { + // TPR must be an 8-bit value. + let tpr = u8::try_from(value).map_err(|_| Emulation)?; + cpu_state.set_tpr(tpr); + Ok(()) + } + APIC_REGISTER_EOI => { + self.perform_eoi(); + Ok(()) + } + APIC_REGISTER_ICR => self.handle_icr_write(value), + APIC_REGISTER_SELF_IPI => { + let vector = u8::try_from(value).map_err(|_| Emulation)?; + self.post_interrupt(vector, false); + Ok(()) + } + _ => Err(SvsmError::Apic(Emulation)), + } + } + + pub fn configure_vector(&mut self, vector: u8, allowed: bool) { + let index = (vector >> 5) as usize; + let mask = 1 << (vector & 31); + if allowed { + self.allowed_irr[index] |= mask; + } else { + self.allowed_irr[index] &= !mask; + } + } + + fn signal_one_host_interrupt(&mut self, vector: u8, level_sensitive: bool) -> bool { + let index = (vector >> 5) as usize; + let mask = 1 << (vector & 31); + if (self.allowed_irr[index] & mask) != 0 { + self.post_interrupt(vector, level_sensitive); + true + } else { + false + } + } + + fn signal_several_interrupts(&mut self, group: usize, mut bits: u32) { + let vector = (group as u8) << 5; + while bits != 0 { + let index = 31 - bits.leading_zeros(); + bits &= !(1 << index); + self.post_interrupt(vector + index as u8, false); + } + } + + fn consume_host_interrupts(&mut self) { + let hv_doorbell = this_cpu().hv_doorbell().unwrap(); + let vmpl_event_mask = hv_doorbell.per_vmpl_events.swap(0, Ordering::Relaxed); + // Ignore events other than for the guest VMPL. + if vmpl_event_mask & (1 << (GUEST_VMPL - 1)) == 0 { + return; + } + + let descriptor = &hv_doorbell.per_vmpl[GUEST_VMPL - 1]; + + // First consume any level-sensitive vector that is present. + let mut flags = HVExtIntStatus::from(descriptor.status.load(Ordering::Relaxed)); + if flags.level_sensitive() { + let mut vector; + // Consume the correct vector atomically. + loop { + vector = flags.pending_vector(); + let new_flags = flags.with_pending_vector(0).with_level_sensitive(false); + if let Err(fail_flags) = descriptor.status.compare_exchange( + flags.into(), + new_flags.into(), + Ordering::Relaxed, + Ordering::Relaxed, + ) { + flags = fail_flags.into(); + } else { + flags = new_flags; + break; + } + } + + if self.signal_one_host_interrupt(vector, true) { + Self::insert_vector_register(&mut self.host_tmr, vector); + } + } + + // If a single vector is present, then signal it, otherwise + // process the entire IRR. + if flags.multiple_vectors() { + // Clear the multiple vectors flag first so that additional + // interrupts are presented via the 8-bit vector. This must + // be done before the IRR is scanned so that if additional + // vectors are presented later, the multiple vectors flag + // will be set again. + let multiple_vectors_mask: u32 = + HVExtIntStatus::new().with_multiple_vectors(true).into(); + descriptor + .status + .fetch_and(!multiple_vectors_mask, Ordering::Relaxed); + + // Handle the special case of vector 31. + if flags.vector_31() { + descriptor + .status + .fetch_and(!(1u32 << 31), Ordering::Relaxed); + self.signal_one_host_interrupt(31, false); + } + + for i in 1..8 { + let bits = descriptor.irr[i - 1].swap(0, Ordering::Relaxed); + self.signal_several_interrupts(i, bits & self.allowed_irr[i]); + } + } else if flags.pending_vector() != 0 { + // Atomically consume this interrupt. If it cannot be consumed + // atomically, then it must be because some other interrupt + // has been presented, and that can be consumed in another + // pass. 
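The same lock-free consume pattern in isolation (a sketch under the assumption of a simple status word whose low byte holds the pending value; field layout here is purely illustrative):

use core::sync::atomic::{AtomicU32, Ordering};

/// Attempt to claim the low byte of a shared status word exactly once.
/// Returns the claimed value, or None if another consumer raced ahead.
fn try_consume_low_byte(status: &AtomicU32, observed: u32) -> Option<u32> {
    status
        .compare_exchange(observed, observed & !0xFF, Ordering::Relaxed, Ordering::Relaxed)
        .ok()
        .map(|old| old & 0xFF)
}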
+ let new_flags = flags.with_pending_vector(0); + if descriptor + .status + .compare_exchange( + flags.into(), + new_flags.into(), + Ordering::Relaxed, + Ordering::Relaxed, + ) + .is_ok() + { + self.signal_one_host_interrupt(flags.pending_vector(), false); + } + } + } + + fn handoff_to_host(&mut self) { + let hv_doorbell = this_cpu().hv_doorbell().unwrap(); + let descriptor = &hv_doorbell.per_vmpl[GUEST_VMPL - 1]; + // Establish the IRR as holding multiple vectors regardless of the + // number of active vectors, as this makes transferring IRR state + // simpler. + let multiple_vectors_mask: u32 = HVExtIntStatus::new().with_multiple_vectors(true).into(); + descriptor + .status + .fetch_or(multiple_vectors_mask, Ordering::Relaxed); + + // Indicate whether an NMI is pending. + if self.nmi_pending { + let nmi_mask: u32 = HVExtIntStatus::new().with_nmi_pending(true).into(); + descriptor.status.fetch_or(nmi_mask, Ordering::Relaxed); + } + + // If a single, edge-triggered interrupt is present in the interrupt + // descriptor, then transfer it to the local IRR. Level-sensitive + // interrupts can be left alone since the host must be prepared to + // consume those directly. Note that consuming the interrupt does not + // require zeroing the vector, since the host is supposed to ignore the + // vector field when multiple vectors are present (except for the case + // of level-sensitive interrupts). + let flags = HVExtIntStatus::from(descriptor.status.load(Ordering::Relaxed)); + if flags.pending_vector() >= 31 && !flags.level_sensitive() { + Self::insert_vector_register(&mut self.irr, flags.pending_vector()); + } + + // Copy vector 31 if required, and then insert all of the additional + // IRR fields into the host IRR. + if self.irr[0] & 0x8000_0000 != 0 { + let irr_31_mask: u32 = HVExtIntStatus::new().vector_31().into(); + descriptor.status.fetch_or(irr_31_mask, Ordering::Relaxed); + } + + for i in 1..8 { + descriptor.irr[i - 1].fetch_or(self.irr[i], Ordering::Relaxed); + } + + // Now transfer the contents of the ISR stack into the host ISR. + let mut new_isr = [0u32; 8]; + for i in 0..self.isr_stack_index { + let index = (self.isr_stack[i] >> 5) as usize; + let bit = 1u32 << (self.isr_stack[i] & 31); + new_isr[index] |= bit; + } + + for (host_isr, temp_isr) in descriptor.isr.iter().zip(new_isr.iter()) { + host_isr.store(*temp_isr, Ordering::Relaxed); + } + } + + pub fn disable_apic_emulation( + &mut self, + cpu_state: &mut T, + caa_addr: Option, + ) { + // Ensure that any previous interrupt delivery is complete. + self.check_delivered_interrupts(cpu_state, caa_addr); + + // Rewind any pending NMI. + if cpu_state.check_and_clear_pending_nmi() { + self.nmi_pending = true; + } + + // Hand the current APIC state off to the host. + self.handoff_to_host(); + + let _ = Self::clear_guest_eoi_pending(caa_addr); + + // Disable alternate injection altogether. + cpu_state.disable_alternate_injection(); + + // Finally, ask the host to take over APIC + // emulation. 
+ current_ghcb() + .disable_alternate_injection( + cpu_state.get_tpr(), + cpu_state.in_intr_shadow(), + cpu_state.interrupts_enabled(), + ) + .expect("Failed to disable alterate injection"); + } +} diff --git a/stage2/src/cpu/control_regs.rs b/stage2/src/cpu/control_regs.rs new file mode 100644 index 000000000..f5bfb2a17 --- /dev/null +++ b/stage2/src/cpu/control_regs.rs @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::features::cpu_has_pge; +use crate::address::{Address, PhysAddr}; +use crate::cpu::features::{cpu_has_smap, cpu_has_smep}; +use crate::platform::SvsmPlatform; +use bitflags::bitflags; +use core::arch::asm; + +pub fn cr0_init() { + let mut cr0 = read_cr0(); + + cr0.insert(CR0Flags::WP); // Enable Write Protection + cr0.remove(CR0Flags::NW); // Enable caches ... + cr0.remove(CR0Flags::CD); // ... if not already happened + + write_cr0(cr0); +} + +pub fn cr4_init(platform: &dyn SvsmPlatform) { + let mut cr4 = read_cr4(); + + cr4.insert(CR4Flags::PSE); // Enable Page Size Extensions + + // All processors that are capable of virtualization will support global + // page table entries, so there is no reason to support any processor that + // does not enumerate PGE capability. + assert!(cpu_has_pge(platform), "CPU does not support PGE"); + + cr4.insert(CR4Flags::PGE); // Enable Global Pages + + if !cfg!(feature = "nosmep") { + assert!(cpu_has_smep(platform), "CPU does not support SMEP"); + cr4.insert(CR4Flags::SMEP); + } + + if !cfg!(feature = "nosmap") { + assert!(cpu_has_smap(platform), "CPU does not support SMAP"); + cr4.insert(CR4Flags::SMAP); + } + + write_cr4(cr4); +} + +pub fn cr0_sse_enable() { + let mut cr0 = read_cr0(); + + cr0.insert(CR0Flags::MP); + cr0.remove(CR0Flags::EM); + + // No Lazy context switching + cr0.remove(CR0Flags::TS); + + write_cr0(cr0); +} + +pub fn cr4_osfxsr_enable() { + let mut cr4 = read_cr4(); + + cr4.insert(CR4Flags::OSFXSR); + + write_cr4(cr4); +} + +pub fn cr4_xsave_enable() { + let mut cr4 = read_cr4(); + + cr4.insert(CR4Flags::OSXSAVE); + + write_cr4(cr4); +} + +bitflags! { + #[derive(Debug, Clone, Copy)] + pub struct CR0Flags: u64 { + const PE = 1 << 0; // Protection Enabled + const MP = 1 << 1; // Monitor Coprocessor + const EM = 1 << 2; // Emulation + const TS = 1 << 3; // Task Switched + const ET = 1 << 4; // Extension Type + const NE = 1 << 5; // Numeric Error + const WP = 1 << 16; // Write Protect + const AM = 1 << 18; // Alignment Mask + const NW = 1 << 29; // Not Writethrough + const CD = 1 << 30; // Cache Disable + const PG = 1 << 31; // Paging + } +} + +pub fn read_cr0() -> CR0Flags { + let cr0: u64; + + unsafe { + asm!("mov %cr0, %rax", + out("rax") cr0, + options(att_syntax)); + } + + CR0Flags::from_bits_truncate(cr0) +} + +pub fn write_cr0(cr0: CR0Flags) { + let reg = cr0.bits(); + + unsafe { + asm!("mov %rax, %cr0", + in("rax") reg, + options(att_syntax)); + } +} + +pub fn read_cr2() -> usize { + let ret: usize; + unsafe { + asm!("mov %cr2, %rax", + out("rax") ret, + options(att_syntax)); + } + ret +} + +pub fn write_cr2(cr2: usize) { + unsafe { + asm!("mov %rax, %cr2", + in("rax") cr2, + options(att_syntax)); + } +} + +pub fn read_cr3() -> PhysAddr { + let ret: usize; + unsafe { + asm!("mov %cr3, %rax", + out("rax") ret, + options(att_syntax)); + } + PhysAddr::from(ret) +} + +pub fn write_cr3(cr3: PhysAddr) { + unsafe { + asm!("mov %rax, %cr3", + in("rax") cr3.bits(), + options(att_syntax)); + } +} + +bitflags! 
{ + #[derive(Debug, Clone, Copy)] + pub struct CR4Flags: u64 { + const VME = 1 << 0; // Virtual-8086 Mode Extensions + const PVI = 1 << 1; // Protected-Mode Virtual Interrupts + const TSD = 1 << 2; // Time Stamp Disable + const DE = 1 << 3; // Debugging Extensions + const PSE = 1 << 4; // Page Size Extensions + const PAE = 1 << 5; // Physical-Address Extension + const MCE = 1 << 6; // Machine Check Enable + const PGE = 1 << 7; // Page-Global Enable + const PCE = 1 << 8; // Performance-Monitoring Counter Enable + const OSFXSR = 1 << 9; // Operating System FXSAVE/FXRSTOR Support + const OSXMMEXCPT = 1 << 10; // Operating System Unmasked Exception Support + const UMIP = 1 << 11; // User Mode Instruction Prevention + const LA57 = 1 << 12; // 57-bit linear address + const FSGSBASE = 1 << 16; // Enable RDFSBASE, RDGSBASE, WRFSBASE, and + // WRGSBASE instructions + const PCIDE = 1 << 17; // Process Context Identifier Enable + const OSXSAVE = 1 << 18; // XSAVE and Processor Extended States Enable Bit + const SMEP = 1 << 20; // Supervisor Mode Execution Prevention + const SMAP = 1 << 21; // Supervisor Mode Access Protection + const PKE = 1 << 22; // Protection Key Enable + const CET = 1 << 23; // Control-flow Enforcement Technology + } +} + +pub fn read_cr4() -> CR4Flags { + let cr4: u64; + + unsafe { + asm!("mov %cr4, %rax", + out("rax") cr4, + options(att_syntax)); + } + + CR4Flags::from_bits_truncate(cr4) +} + +pub fn write_cr4(cr4: CR4Flags) { + let reg = cr4.bits(); + + unsafe { + asm!("mov %rax, %cr4", + in("rax") reg, + options(att_syntax)); + } +} diff --git a/stage2/src/cpu/cpuid.rs b/stage2/src/cpu/cpuid.rs new file mode 100644 index 000000000..57ec51da4 --- /dev/null +++ b/stage2/src/cpu/cpuid.rs @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::utils::immut_after_init::ImmutAfterInitRef; +use cpuarch::snp_cpuid::SnpCpuidTable; + +use core::arch::asm; + +static CPUID_PAGE: ImmutAfterInitRef<'_, SnpCpuidTable> = ImmutAfterInitRef::uninit(); + +pub fn register_cpuid_table(table: &'static SnpCpuidTable) { + CPUID_PAGE + .init_from_ref(table) + .expect("Could not initialize CPUID page"); +} + +#[derive(Clone, Copy, Debug)] +#[repr(C, packed)] +pub struct CpuidLeaf { + pub cpuid_fn: u32, + pub cpuid_subfn: u32, + pub eax: u32, + pub ebx: u32, + pub ecx: u32, + pub edx: u32, +} + +impl CpuidLeaf { + pub fn new(cpuid_fn: u32, cpuid_subfn: u32) -> Self { + CpuidLeaf { + cpuid_fn, + cpuid_subfn, + eax: 0, + ebx: 0, + ecx: 0, + edx: 0, + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct CpuidResult { + pub eax: u32, + pub ebx: u32, + pub ecx: u32, + pub edx: u32, +} + +impl CpuidResult { + pub fn get(cpuid_fn: u32, cpuid_subfn: u32) -> Self { + let mut result_eax: u32; + let mut result_ebx: u32; + let mut result_ecx: u32; + let mut result_edx: u32; + unsafe { + asm!("push %rbx", + "cpuid", + "movl %ebx, %edi", + "pop %rbx", + in("eax") cpuid_fn, + in("ecx") cpuid_subfn, + lateout("eax") result_eax, + lateout("edi") result_ebx, + lateout("ecx") result_ecx, + lateout("edx") result_edx, + options(att_syntax)); + } + Self { + eax: result_eax, + ebx: result_ebx, + ecx: result_ecx, + edx: result_edx, + } + } +} + +pub fn cpuid_table_raw(eax: u32, ecx: u32, xcr0: u64, xss: u64) -> Option { + let count: usize = CPUID_PAGE.count as usize; + + for i in 0..count { + if eax == CPUID_PAGE.func[i].eax_in + && ecx == CPUID_PAGE.func[i].ecx_in + && xcr0 == CPUID_PAGE.func[i].xcr0_in + && xss == 
CPUID_PAGE.func[i].xss_in + { + return Some(CpuidResult { + eax: CPUID_PAGE.func[i].eax_out, + ebx: CPUID_PAGE.func[i].ebx_out, + ecx: CPUID_PAGE.func[i].ecx_out, + edx: CPUID_PAGE.func[i].edx_out, + }); + } + } + + None +} + +pub fn cpuid_table(eax: u32) -> Option { + cpuid_table_raw(eax, 0, 0, 0) +} + +pub fn dump_cpuid_table() { + let count = CPUID_PAGE.count as usize; + + log::trace!("CPUID Table entry count: {}", count); + + for i in 0..count { + let eax_in = CPUID_PAGE.func[i].eax_in; + let ecx_in = CPUID_PAGE.func[i].ecx_in; + let xcr0_in = CPUID_PAGE.func[i].xcr0_in; + let xss_in = CPUID_PAGE.func[i].xss_in; + let eax_out = CPUID_PAGE.func[i].eax_out; + let ebx_out = CPUID_PAGE.func[i].ebx_out; + let ecx_out = CPUID_PAGE.func[i].ecx_out; + let edx_out = CPUID_PAGE.func[i].edx_out; + log::trace!("EAX_IN: {:#010x} ECX_IN: {:#010x} XCR0_IN: {:#010x} XSS_IN: {:#010x} EAX_OUT: {:#010x} EBX_OUT: {:#010x} ECX_OUT: {:#010x} EDX_OUT: {:#010x}", + eax_in, ecx_in, xcr0_in, xss_in, eax_out, ebx_out, ecx_out, edx_out); + } +} diff --git a/stage2/src/cpu/efer.rs b/stage2/src/cpu/efer.rs new file mode 100644 index 000000000..681db39a1 --- /dev/null +++ b/stage2/src/cpu/efer.rs @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::msr::{read_msr, write_msr, EFER}; +use bitflags::bitflags; + +bitflags! { + #[derive(Clone, Copy, Debug)] + pub struct EFERFlags: u64 { + const SCE = 1 << 0; // System Call Extensions + const LME = 1 << 8; // Long Mode Enable + const LMA = 1 << 10; // Long Mode Active + const NXE = 1 << 11; // No-Execute Enable + const SVME = 1 << 12; // Secure Virtual Machine Enable + const LMSLE = 1 << 13; // Long Mode Segment Limit Enable + const FFXSR = 1 << 14; // Fast FXSAVE/FXRSTOR + const TCE = 1 << 15; // Translation Cache Extension + const MCOMMIT = 1 << 17; // Enable MCOMMIT instruction + const INTWB = 1 << 18; // Interruptible WBINVD/WBNOINVD enable + const UAIE = 1 << 20; // Upper Address Ignore Enable + } +} + +pub fn read_efer() -> EFERFlags { + EFERFlags::from_bits_truncate(read_msr(EFER)) +} + +pub fn write_efer(efer: EFERFlags) { + let val = efer.bits(); + write_msr(EFER, val); +} diff --git a/stage2/src/cpu/extable.rs b/stage2/src/cpu/extable.rs new file mode 100644 index 000000000..e8925ad8e --- /dev/null +++ b/stage2/src/cpu/extable.rs @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern "C" { + pub static exception_table_start: u8; + pub static exception_table_end: u8; +} + +use super::idt::common::X86ExceptionContext; +use crate::address::{Address, VirtAddr}; +use core::mem; + +#[repr(C, packed)] +struct ExceptionTableEntry { + start: VirtAddr, + end: VirtAddr, +} + +fn check_exception_table(rip: VirtAddr) -> VirtAddr { + unsafe { + let ex_table_start = VirtAddr::from(&raw const exception_table_start); + let ex_table_end = VirtAddr::from(&raw const exception_table_end); + let mut current = ex_table_start; + + loop { + let addr = current.as_ptr::(); + + let start = (*addr).start; + let end = (*addr).end; + + if rip >= start && rip < end { + return end; + } + + current = current + mem::size_of::(); + if current >= ex_table_end { + break; + } + } + } + + rip +} + +pub fn dump_exception_table() { + unsafe { + let ex_table_start = VirtAddr::from(&raw const exception_table_start); + let ex_table_end = VirtAddr::from(&raw const exception_table_end); + let mut current = ex_table_start; + + 
loop { + let addr = current.as_ptr::(); + + let start = (*addr).start; + let end = (*addr).end; + + log::info!("Extable Entry {:#018x}-{:#018x}", start, end); + + current = current + mem::size_of::(); + if current >= ex_table_end { + break; + } + } + } +} + +pub fn handle_exception_table(ctx: &mut X86ExceptionContext) -> bool { + let ex_rip = VirtAddr::from(ctx.frame.rip); + let new_rip = check_exception_table(ex_rip); + + // If an exception hit in an area covered by the exception table, set rcx to -1 + if new_rip != ex_rip { + ctx.regs.rcx = !0usize; + ctx.frame.rip = new_rip.bits(); + return true; + } + + false +} diff --git a/stage2/src/cpu/features.rs b/stage2/src/cpu/features.rs new file mode 100644 index 000000000..6ff7d3bdc --- /dev/null +++ b/stage2/src/cpu/features.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::platform::SvsmPlatform; + +const X86_FEATURE_PGE: u32 = 13; +const X86_FEATURE_SMEP: u32 = 7; +const X86_FEATURE_SMAP: u32 = 20; + +pub fn cpu_has_pge(platform: &dyn SvsmPlatform) -> bool { + let ret = platform.cpuid(0x00000001); + + match ret { + None => false, + Some(c) => (c.edx >> X86_FEATURE_PGE) & 1 == 1, + } +} + +pub fn cpu_has_smep(platform: &dyn SvsmPlatform) -> bool { + let ret = platform.cpuid(0x0000_0007); + + match ret { + None => false, + Some(c) => (c.ebx >> X86_FEATURE_SMEP & 1) == 1, + } +} + +pub fn cpu_has_smap(platform: &dyn SvsmPlatform) -> bool { + let ret = platform.cpuid(0x0000_0007); + + match ret { + None => false, + Some(c) => (c.ebx >> X86_FEATURE_SMAP & 1) == 1, + } +} diff --git a/stage2/src/cpu/gdt.rs b/stage2/src/cpu/gdt.rs new file mode 100644 index 000000000..ea2c4e8ef --- /dev/null +++ b/stage2/src/cpu/gdt.rs @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::tss::X86Tss; +use crate::address::VirtAddr; +use crate::locking::{RWLock, ReadLockGuard, WriteLockGuard}; +use crate::types::{SVSM_CS, SVSM_DS, SVSM_TSS}; +use core::arch::asm; +use core::mem; + +#[repr(C, packed(2))] +#[derive(Clone, Copy, Debug, Default)] +struct GDTDesc { + size: u16, + addr: VirtAddr, +} + +#[derive(Copy, Clone, Debug, Default)] +pub struct GDTEntry(u64); + +impl GDTEntry { + pub const fn from_raw(entry: u64) -> Self { + Self(entry) + } + + pub fn to_raw(&self) -> u64 { + self.0 + } + + pub const fn null() -> Self { + Self(0u64) + } + + pub const fn code_64_kernel() -> Self { + Self(0x00af9a000000ffffu64) + } + + pub const fn data_64_kernel() -> Self { + Self(0x00cf92000000ffffu64) + } + + pub const fn code_64_user() -> Self { + Self(0x00affb000000ffffu64) + } + + pub const fn data_64_user() -> Self { + Self(0x00cff2000000ffffu64) + } +} + +const GDT_SIZE: u16 = 8; + +#[derive(Copy, Clone, Debug, Default)] +pub struct GDT { + entries: [GDTEntry; GDT_SIZE as usize], +} + +impl GDT { + pub const fn new() -> Self { + Self { + entries: [ + GDTEntry::null(), + GDTEntry::code_64_kernel(), + GDTEntry::data_64_kernel(), + GDTEntry::code_64_user(), + GDTEntry::data_64_user(), + GDTEntry::null(), + GDTEntry::null(), + GDTEntry::null(), + ], + } + } + + unsafe fn set_tss_entry(&mut self, desc0: GDTEntry, desc1: GDTEntry) { + let idx = (SVSM_TSS / 8) as usize; + + let tss_entries = &self.entries[idx..idx + 1].as_mut_ptr(); + + tss_entries.add(0).write_volatile(desc0); + tss_entries.add(1).write_volatile(desc1); + } + + unsafe fn clear_tss_entry(&mut self) { + 
self.set_tss_entry(GDTEntry::null(), GDTEntry::null()); + } + + pub fn load_tss(&mut self, tss: &X86Tss) { + let (desc0, desc1) = tss.to_gdt_entry(); + + unsafe { + self.set_tss_entry(desc0, desc1); + asm!("ltr %ax", in("ax") SVSM_TSS, options(att_syntax)); + self.clear_tss_entry() + } + } + + pub fn kernel_cs(&self) -> GDTEntry { + self.entries[(SVSM_CS / 8) as usize] + } + + pub fn kernel_ds(&self) -> GDTEntry { + self.entries[(SVSM_DS / 8) as usize] + } +} + +impl ReadLockGuard<'static, GDT> { + /// Load a GDT. Its lifetime must be static so that its entries are + /// always available to the CPU. + pub fn load(&self) { + let gdt_desc = self.descriptor(); + unsafe { + asm!(r#" /* Load GDT */ + lgdt (%rax) + + /* Reload data segments */ + movw %cx, %ds + movw %cx, %es + movw %cx, %fs + movw %cx, %gs + movw %cx, %ss + + /* Reload code segment */ + pushq %rdx + leaq 1f(%rip), %rax + pushq %rax + lretq + 1: + "#, + in("rax") &gdt_desc, + in("rdx") SVSM_CS, + in("rcx") SVSM_DS, + options(att_syntax)); + } + } + + fn descriptor(&self) -> GDTDesc { + GDTDesc { + size: (GDT_SIZE * 8) - 1, + addr: VirtAddr::from(self.entries.as_ptr()), + } + } + + pub fn base_limit(&self) -> (u64, u32) { + let gdt_entries = GDT_SIZE as usize; + let base: *const GDT = core::ptr::from_ref(self); + let limit = ((mem::size_of::() * gdt_entries) - 1) as u32; + (base as u64, limit) + } +} + +static GDT: RWLock = RWLock::new(GDT::new()); + +pub fn gdt() -> ReadLockGuard<'static, GDT> { + GDT.lock_read() +} + +pub fn gdt_mut() -> WriteLockGuard<'static, GDT> { + GDT.lock_write() +} diff --git a/stage2/src/cpu/idt/common.rs b/stage2/src/cpu/idt/common.rs new file mode 100644 index 000000000..f88b72446 --- /dev/null +++ b/stage2/src/cpu/idt/common.rs @@ -0,0 +1,421 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; + +use crate::address::{Address, VirtAddr}; +use crate::cpu::control_regs::{read_cr0, read_cr4}; +use crate::cpu::efer::read_efer; +use crate::cpu::gdt::gdt; +use crate::cpu::registers::{X86GeneralRegs, X86InterruptFrame}; +use crate::insn_decode::{InsnError, InsnMachineCtx, InsnMachineMem, Register, SegRegister}; +use crate::locking::{RWLock, ReadLockGuard, WriteLockGuard}; +use crate::mm::GuestPtr; +use crate::platform::SVSM_PLATFORM; +use crate::types::{Bytes, SVSM_CS}; +use alloc::boxed::Box; +use core::arch::{asm, global_asm}; +use core::mem; + +pub const DE_VECTOR: usize = 0; +pub const DB_VECTOR: usize = 1; +pub const NMI_VECTOR: usize = 2; +pub const BP_VECTOR: usize = 3; +pub const OF_VECTOR: usize = 4; +pub const BR_VECTOR: usize = 5; +pub const UD_VECTOR: usize = 6; +pub const NM_VECTOR: usize = 7; +pub const DF_VECTOR: usize = 8; +pub const CSO_VECTOR: usize = 9; +pub const TS_VECTOR: usize = 10; +pub const NP_VECTOR: usize = 11; +pub const SS_VECTOR: usize = 12; +pub const GP_VECTOR: usize = 13; +pub const PF_VECTOR: usize = 14; +pub const MF_VECTOR: usize = 16; +pub const AC_VECTOR: usize = 17; +pub const MCE_VECTOR: usize = 18; +pub const XF_VECTOR: usize = 19; +pub const CP_VECTOR: usize = 21; +pub const HV_VECTOR: usize = 28; +pub const VC_VECTOR: usize = 29; +pub const SX_VECTOR: usize = 30; + +pub const INT_INJ_VECTOR: usize = 0x50; + +bitflags::bitflags! { + /// Page fault error code flags. 
+    #[derive(Clone, Copy, Debug, PartialEq)]
+    pub struct PageFaultError: u32 {
+        const P = 1 << 0;
+        const W = 1 << 1;
+        const U = 1 << 2;
+        const R = 1 << 3;
+        const I = 1 << 4;
+    }
+}
+
+#[repr(C, packed)]
+#[derive(Default, Debug, Clone, Copy)]
+pub struct X86ExceptionContext {
+    pub regs: X86GeneralRegs,
+    pub error_code: usize,
+    pub frame: X86InterruptFrame,
+}
+
+impl InsnMachineCtx for X86ExceptionContext {
+    fn read_efer(&self) -> u64 {
+        read_efer().bits()
+    }
+
+    fn read_seg(&self, seg: SegRegister) -> u64 {
+        match seg {
+            SegRegister::CS => gdt().kernel_cs().to_raw(),
+            _ => gdt().kernel_ds().to_raw(),
+        }
+    }
+
+    fn read_cr0(&self) -> u64 {
+        read_cr0().bits()
+    }
+
+    fn read_cr4(&self) -> u64 {
+        read_cr4().bits()
+    }
+
+    fn read_reg(&self, reg: Register) -> usize {
+        match reg {
+            Register::Rax => self.regs.rax,
+            Register::Rdx => self.regs.rdx,
+            Register::Rcx => self.regs.rcx,
+            Register::Rbx => self.regs.rbx,
+            Register::Rsp => self.frame.rsp,
+            Register::Rbp => self.regs.rbp,
+            Register::Rdi => self.regs.rdi,
+            Register::Rsi => self.regs.rsi,
+            Register::R8 => self.regs.r8,
+            Register::R9 => self.regs.r9,
+            Register::R10 => self.regs.r10,
+            Register::R11 => self.regs.r11,
+            Register::R12 => self.regs.r12,
+            Register::R13 => self.regs.r13,
+            Register::R14 => self.regs.r14,
+            Register::R15 => self.regs.r15,
+            Register::Rip => self.frame.rip,
+        }
+    }
+
+    fn read_flags(&self) -> usize {
+        self.frame.flags
+    }
+
+    fn write_reg(&mut self, reg: Register, val: usize) {
+        match reg {
+            Register::Rax => self.regs.rax = val,
+            Register::Rdx => self.regs.rdx = val,
+            Register::Rcx => self.regs.rcx = val,
+            Register::Rbx => self.regs.rbx = val,
+            Register::Rsp => self.frame.rsp = val,
+            Register::Rbp => self.regs.rbp = val,
+            Register::Rdi => self.regs.rdi = val,
+            Register::Rsi => self.regs.rsi = val,
+            Register::R8 => self.regs.r8 = val,
+            Register::R9 => self.regs.r9 = val,
+            Register::R10 => self.regs.r10 = val,
+            Register::R11 => self.regs.r11 = val,
+            Register::R12 => self.regs.r12 = val,
+            Register::R13 => self.regs.r13 = val,
+            Register::R14 => self.regs.r14 = val,
+            Register::R15 => self.regs.r15 = val,
+            Register::Rip => self.frame.rip = val,
+        }
+    }
+
+    fn read_cpl(&self) -> usize {
+        self.frame.cs & 3
+    }
+
+    fn map_linear_addr<T: Copy + 'static>(
+        &self,
+        la: usize,
+        _write: bool,
+        _fetch: bool,
+    ) -> Result<Box<dyn InsnMachineMem<Item = T>>, InsnError> {
+        if user_mode(self) {
+            todo!();
+        } else {
+            Ok(Box::new(GuestPtr::<T>::new(VirtAddr::from(la))))
+        }
+    }
+
+    fn ioio_perm(&self, _port: u16, _size: Bytes, _io_read: bool) -> bool {
+        // Check if the IO port can be supported by user mode
+        todo!();
+    }
+
+    fn ioio_in(&self, port: u16, size: Bytes) -> Result<u64, InsnError> {
+        let io_port = SVSM_PLATFORM.get_io_port();
+        let data = match size {
+            Bytes::One => io_port.inb(port) as u64,
+            Bytes::Two => io_port.inw(port) as u64,
+            Bytes::Four => io_port.inl(port) as u64,
+            _ => return Err(InsnError::IoIoIn),
+        };
+        Ok(data)
+    }
+
+    fn ioio_out(&mut self, port: u16, size: Bytes, data: u64) -> Result<(), InsnError> {
+        let io_port = SVSM_PLATFORM.get_io_port();
+        match size {
+            Bytes::One => io_port.outb(port, data as u8),
+            Bytes::Two => io_port.outw(port, data as u16),
+            Bytes::Four => io_port.outl(port, data as u32),
+            _ => return Err(InsnError::IoIoOut),
+        }
+        Ok(())
+    }
+}
+
+pub fn user_mode(ctxt: &X86ExceptionContext) -> bool {
+    (ctxt.frame.cs & 3) == 3
+}
+
+#[derive(Copy, Clone, Default, Debug)]
+#[repr(C, packed)]
+pub struct IdtEntry {
+    low: u64,
+    high: u64,
+}
+
+const IDT_TARGET_MASK_1: u64 = 0x0000_0000_0000_ffff;
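Illustrative sketch (not part of the patch): the mask and shift constants defined here and just below describe how IdtEntry::create() scatters a 64-bit handler address across the 16-byte gate descriptor (bits 0..16 in the low word, bits 16..32 in bits 48..64 of the low qword, bits 32..64 in the high qword). The inverse operation, with a hypothetical helper name, can be used to sanity-check the packing:

fn idt_entry_target(low: u64, high: u64) -> u64 {
    // Reassemble the handler address from a raw (low, high) gate descriptor.
    (low & 0xffff) | ((low >> 32) & 0xffff_0000) | (high << 32)
}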
+const IDT_TARGET_MASK_2: u64 = 0x0000_0000_ffff_0000; +const IDT_TARGET_MASK_3: u64 = 0xffff_ffff_0000_0000; + +const IDT_TARGET_MASK_1_SHIFT: u64 = 0; +const IDT_TARGET_MASK_2_SHIFT: u64 = 48 - 16; +const IDT_TARGET_MASK_3_SHIFT: u64 = 32; + +const IDT_TYPE_MASK: u8 = 0x0f; +const IDT_TYPE_SHIFT: u64 = 40; +const IDT_TYPE_CALL: u8 = 0x0c; +const IDT_TYPE_INT: u8 = 0x0e; +const IDT_TYPE_TRAP: u8 = 0x0f; + +fn idt_type_mask(t: u8) -> u64 { + ((t & IDT_TYPE_MASK) as u64) << IDT_TYPE_SHIFT +} + +const IDT_DPL_MASK: u8 = 0x03; +const IDT_DPL_SHIFT: u64 = 45; + +fn idt_dpl_mask(dpl: u8) -> u64 { + ((dpl & IDT_DPL_MASK) as u64) << IDT_DPL_SHIFT +} + +const IDT_PRESENT_MASK: u64 = 0x1u64 << 47; +const IDT_CS_SHIFT: u64 = 16; + +const IDT_IST_MASK: u64 = 0x7; +const IDT_IST_SHIFT: u64 = 32; + +impl IdtEntry { + fn create(target: VirtAddr, cs: u16, desc_type: u8, dpl: u8, ist: u8) -> Self { + let vaddr = target.bits() as u64; + let cs_mask = (cs as u64) << IDT_CS_SHIFT; + let ist_mask = ((ist as u64) & IDT_IST_MASK) << IDT_IST_SHIFT; + let low = (vaddr & IDT_TARGET_MASK_1) << IDT_TARGET_MASK_1_SHIFT + | (vaddr & IDT_TARGET_MASK_2) << IDT_TARGET_MASK_2_SHIFT + | idt_type_mask(desc_type) + | IDT_PRESENT_MASK + | idt_dpl_mask(dpl) + | cs_mask + | ist_mask; + let high = (vaddr & IDT_TARGET_MASK_3) >> IDT_TARGET_MASK_3_SHIFT; + + IdtEntry { low, high } + } + + pub fn raw_entry(target: VirtAddr) -> Self { + IdtEntry::create(target, SVSM_CS, IDT_TYPE_INT, 0, 0) + } + + pub fn entry(handler: unsafe extern "C" fn()) -> Self { + let target = VirtAddr::from(handler as *const ()); + IdtEntry::create(target, SVSM_CS, IDT_TYPE_INT, 0, 0) + } + + pub fn user_entry(handler: unsafe extern "C" fn()) -> Self { + let target = VirtAddr::from(handler as *const ()); + IdtEntry::create(target, SVSM_CS, IDT_TYPE_INT, 3, 0) + } + + pub fn ist_entry(handler: unsafe extern "C" fn(), ist: u8) -> Self { + let target = VirtAddr::from(handler as *const ()); + IdtEntry::create(target, SVSM_CS, IDT_TYPE_INT, 0, ist) + } + + pub fn trap_entry(handler: unsafe extern "C" fn()) -> Self { + let target = VirtAddr::from(handler as *const ()); + IdtEntry::create(target, SVSM_CS, IDT_TYPE_TRAP, 0, 0) + } + + pub fn call_entry(handler: unsafe extern "C" fn()) -> Self { + let target = VirtAddr::from(handler as *const ()); + IdtEntry::create(target, SVSM_CS, IDT_TYPE_CALL, 3, 0) + } + + pub const fn no_handler() -> Self { + IdtEntry { low: 0, high: 0 } + } +} + +const IDT_ENTRIES: usize = 256; + +#[repr(C, packed)] +#[derive(Default, Clone, Copy, Debug)] +struct IdtDesc { + size: u16, + address: VirtAddr, +} + +#[derive(Copy, Clone, Debug)] +pub struct IDT { + entries: [IdtEntry; IDT_ENTRIES], +} + +impl IDT { + pub const fn new() -> Self { + IDT { + entries: [IdtEntry::no_handler(); IDT_ENTRIES], + } + } + + pub fn init(&mut self, handler_array: *const u8, size: usize) -> &mut Self { + // Set IDT handlers + let handlers = VirtAddr::from(handler_array); + + for idx in 0..size { + self.set_entry(idx, IdtEntry::raw_entry(handlers + (32 * idx))); + } + + self + } + + pub fn set_entry(&mut self, idx: usize, entry: IdtEntry) -> &mut Self { + self.entries[idx] = entry; + + self + } +} + +impl Default for IDT { + fn default() -> Self { + Self::new() + } +} + +impl WriteLockGuard<'static, IDT> { + /// Load an IDT. Its lifetime must be static so that its entries are + /// always available to the CPU. 
+ pub fn load(&self) { + let desc: IdtDesc = IdtDesc { + size: (IDT_ENTRIES * 16) as u16, + address: VirtAddr::from(self.entries.as_ptr()), + }; + + unsafe { + asm!("lidt (%rax)", in("rax") &desc, options(att_syntax)); + } + } +} + +impl ReadLockGuard<'static, IDT> { + pub fn base_limit(&self) -> (u64, u32) { + let base: *const IDT = core::ptr::from_ref(self); + let limit = (IDT_ENTRIES * mem::size_of::()) as u32; + (base as u64, limit) + } +} + +static IDT: RWLock = RWLock::new(IDT::new()); + +pub fn idt() -> ReadLockGuard<'static, IDT> { + IDT.lock_read() +} + +pub fn idt_mut() -> WriteLockGuard<'static, IDT> { + IDT.lock_write() +} + +pub fn triple_fault() { + let desc: IdtDesc = IdtDesc { + size: 0, + address: VirtAddr::from(0u64), + }; + + unsafe { + asm!("lidt (%rax) + int3", in("rax") &desc, options(att_syntax)); + } +} + +extern "C" { + static entry_code_start: u8; + static entry_code_end: u8; +} + +pub fn is_exception_handler_return_site(rip: VirtAddr) -> bool { + let start = VirtAddr::from(&raw const entry_code_start); + let end = VirtAddr::from(&raw const entry_code_end); + (start..end).contains(&rip) +} + +global_asm!( + r#" + /* Needed by the stack unwinder to recognize exception frames. */ + .globl generic_idt_handler_return + generic_idt_handler_return: + + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %r11 + popq %r10 + popq %r9 + popq %r8 + popq %rbp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %rbx + popq %rax + + addq $8, %rsp /* Skip error code */ + + iretq + "#, + options(att_syntax) +); + +#[repr(u32)] +#[derive(Debug, Clone, Copy)] +pub enum IdtEventType { + Unknown = 0, + External, + Software, +} + +impl IdtEventType { + pub fn is_external_interrupt(&self, vector: usize) -> bool { + match self { + Self::External => true, + Self::Software => false, + Self::Unknown => SVSM_PLATFORM.is_external_interrupt(vector), + } + } +} diff --git a/stage2/src/cpu/idt/entry.S b/stage2/src/cpu/idt/entry.S new file mode 100644 index 000000000..cfd04c37f --- /dev/null +++ b/stage2/src/cpu/idt/entry.S @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Authors: Joerg Roedel + +.code64 + +.section .data +.globl HV_DOORBELL_ADDR +HV_DOORBELL_ADDR: + .quad 0 + +.pushsection .entry.text, "ax" + +.macro push_regs + pushq %rax + pushq %rbx + pushq %rcx + pushq %rdx + pushq %rsi + pushq %rdi + pushq %rbp + pushq %r8 + pushq %r9 + pushq %r10 + pushq %r11 + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 +.endm + +.macro default_entry_no_ist name: req handler:req error_code:req vector:req + .globl asm_entry_\name +asm_entry_\name: + asm_clac + + .if \error_code == 0 + pushq $0 + .endif + push_regs + movl $\vector, %esi + movq %rsp, %rdi + xorl %edx, %edx + call ex_handler_\handler + jmp default_return +.endm + +.macro irq_entry name:req vector:req + .globl asm_entry_irq_\name +asm_entry_irq_\name: + asm_clac + + pushq $0 + push_regs + movl $\vector, %edi + call common_isr_handler + jmp default_return +.endm + +// The #HV handler is coded specially in order to deal with control flow +// alterations that may be required based on when the #HV arrives. If the #HV +// arrives from a context in which interrupts are enabled, then the #HV can +// be handled immediately. In general, if the #HV arrives from a context in +// which interrupts are disabled, processing is postponed to a point in time +// when interrupt processing is safe. However, there are two cases in which +// #HV processing is required even when interrupts are disabled. +// 1. 
The #HV arrives just before a return to the guest VMPL. In this case, +// the return to the guest VMPL must be cancelled so the #HV can be handled +// immediately. Otherwise, if the return to the guest occurs while the #HV +// remains pending, it will remain pending until the next time the SVSM +// is reentered, which could block delivery of critical events while the +// guest is executing. +// 2. The #HV arrives while preparing to execute IRET to return to a context +// in which interrupts are enabled. If such an #HV is not handled, then +// it will remain pending indefinitely, which could block delivery of +// critical events. When an #HV arrives at a time that the IRET is +// is committed to complete, the #HV handler will "take over" the +// exception context established previously (the one from which the IRET +// intends to return). In this case, the #HV handler will complete +// processing and will perform the IRET to the point of the original +// exception. +.globl asm_entry_hv +asm_entry_hv: + asm_clac + + // Push a dummy error code, and only three registers. If no #HV + // processing is required, then only these three registers will need to + // be popped. + pushq $0 + pushq %rax + pushq %rbx + pushq %rcx + // Check whether interrupts were enabled at the time of #HV. If so, + // commit to processing all #HV events immediately. + testl ${IF}, 0x30(%rsp) + jnz continue_hv + // Check whether the trap RIP is within the guest VMPL return window. + movq 0x20(%rsp), %rax // fetch RIP from the trap frame. + leaq switch_vmpl_window_start(%rip), %rbx + leaq switch_vmpl_window_end(%rip), %rcx + cmp %rbx, %rax + jb hv_not_vmpl_switch + cmp %rcx, %rax + jae hv_not_vmpl_switch + // RIP is in the return window, so update RIP to the cancel point. + leaq switch_vmpl_cancel(%rip), %rbx + movq %rbx, 0x20(%rsp) + // Defer any further processing until interrupts can be processed. + jmp postpone_hv +hv_not_vmpl_switch: + // Load the RSP value that was live at the time of the #HV. + movq 0x38(%rsp), %rcx + // Check to see whether this interrupt occurred on the IRET path + leaq iret_return_window(%rip), %rbx + cmp %rbx, %rax + jb postpone_hv + leaq default_iret(%rip), %rbx + cmp %rbx, %rax + ja postpone_hv + // RIP is within the IRET sequence, so the IRET should be aborted, and + // the previous exception should be handled as if it were #HV. At this + // point, there are two possibilities. If RIP is before the IRET + // instruction itself, then the RSP at the time of #HV exception + // points to the register context that was established for the previous + // exception. In that case, the current RSP can be changed to point + // to that exception context, and the #HV can be handled using that + // register context, and when #HV processing completes, the subsequent + // end-of-interrupt flow will restore the context at the time of the + // previous exception. On the other hand, if RIP has advanced to the + // point of the IRET instruction itself, then all of the registers + // have already been reloaded with the previous exception context, + // and the RSP at the time of #HV points at the stack frame that + // would be consumed by the IRET instruction. In that case, a new + // exception context will need to be constructed. At this point, + // EFLAGS.ZF=1 if the previous RIP was at the IRET instruction. + jz restart_hv + // Check to see whether interrupts were enabled at the time the + // previous exception was taken. If not, no further processing is + // required. 
This could not be performed before the RIP check because + // the previous RIP determines where to find the previous EFLAGS.IF + // value on the stack. + testl ${IF}, 18*8(%rcx) + jz postpone_hv + // Switch to the stack pointer from the previous exception, which + // points to the register save area, and continue with #HV + // processing. + movq %rcx, %rsp + jmp handle_as_hv + +postpone_hv: + popq %rcx + popq %rbx + popq %rax + addq $8, %rsp + iretq + +restart_hv: + // The previous RIP was on an IRET instruction. Before moving forward + // with #HV processing, check to see whether interrupts were enabled at + // the time the previous exception was taken. If not, no further + // processing is required. This could not be done when RIP was + // checked because the stack location of the previous EFLAGS.IF value + // was not known until RIP was determined to be at the IRET + // instruction. + testl ${IF}, 0x10(%rcx) + jz postpone_hv + // Since interrupts were enabled in the previous exception frame, + // #HV processing is now required. The previous RSP points to the + // exception frame (minus error code) as it would be consumed by + // IRET. In order to set up a new exception context, the three + // registers that were saved upon entry to the #HV handler will need to + // be copied to the top of the stack (adjacent to the space for a + // dummy erro code). Then, the stack pointer will be loaded with + // the previous RSP and the remaining register state will be pushed + // normally to create a complete exception context reflecting the + // register state at the time of the exception that was returning at + // the time the #HV arrived. + // At this point, RCX holds the stack pointer at the time of the + // IRET that was aborted. The first QWORD below that pointer is + // reserved for the dummy error code, then the three QWORDS below that + // will hold the RAX, RBX, and RCX values, which are presently stored + // in the top three QWORDs of the current stack. + movq 0*8(%rsp), %rax + movq %rax, -4*8(%rcx) + movq 1*8(%rsp), %rax + movq %rax, -3*8(%rcx) + movq 2*8(%rsp), %rax + movq %rax, -2*8(%rcx) + leaq -4*8(%rcx), %rsp + +continue_hv: + // At this point, only the dummy error code and first three registers + // have been pushed onto the stack. Push the remainder to construct a + // full exception context. + pushq %rdx + pushq %rsi + pushq %rdi + pushq %rbp + pushq %r8 + pushq %r9 + pushq %r10 + pushq %r11 + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 +handle_as_hv: + // Load the address of the #HV doorbell page. The global address + // might not yet be configured, and the per-CPU page might also not + // yet be configured, so only process events if there is a valid + // doorbell page. + movq HV_DOORBELL_ADDR(%rip), %rsi + testq %rsi, %rsi + jz default_return + movq (%rsi), %rdi + testq %rdi, %rdi + jz default_return +handle_as_hv_with_doorbell: + call process_hv_events + // fall through to default_return + +.globl default_return +default_return: + // Ensure that interrupts are disabled before attempting any return. + cli + testb $3, 17*8(%rsp) // Check CS in exception frame + jnz return_user +return_all_paths: + // If interrupts were previously available, then check whether any #HV + // events are pending. If so, proceed as if the original trap was + // #HV. 
+ testl ${IF}, 18*8(%rsp) // check EFLAGS.IF in exception frame + jz begin_iret_return + movq HV_DOORBELL_ADDR(%rip), %rdi + test %rdi, %rdi + jz begin_iret_return + movq (%rdi), %rdi + test %rdi, %rdi + jz begin_iret_return + testw $0x8000, (%rdi) + // The memory access to the NoFurtherSignal bit must be the last + // instruction prior to the IRET RIP window checked by the #HV entry + // code above. After this point, all code must execute within this + // instruction range to ensure that the #HV handler will be able to + // detect any #HV that arrives after the check above, except for + // the specific case of processing pending #HV events. +iret_return_window: + jnz handle_as_hv_with_doorbell +begin_iret_return: + // Reload registers without modifying the stack pointer so that if #HV + // occurs within this window, the saved registers are still intact. + movq 0*8(%rsp), %r15 + movq 1*8(%rsp), %r14 + movq 2*8(%rsp), %r13 + movq 3*8(%rsp), %r12 + movq 4*8(%rsp), %r11 + movq 5*8(%rsp), %r10 + movq 6*8(%rsp), %r9 + movq 7*8(%rsp), %r8 + movq 8*8(%rsp), %rbp + movq 9*8(%rsp), %rdi + movq 10*8(%rsp), %rsi + movq 11*8(%rsp), %rdx + movq 12*8(%rsp), %rcx + movq 13*8(%rsp), %rbx + movq 14*8(%rsp), %rax + + addq $16*8, %rsp + +default_iret: + iretq + +return_user: + // Put user-mode specific return code here + jmp return_all_paths + +.globl return_new_task +return_new_task: + call setup_new_task + jmp default_return + +// #DE Divide-by-Zero-Error Exception (Vector 0) +default_entry_no_ist name=de handler=panic error_code=0 vector=0 + +// #DB Debug Exception (Vector 1) +default_entry_no_ist name=db handler=debug error_code=0 vector=1 + +// NMI Non-Maskable-Interrupt Exception (Vector 2) +default_entry_no_ist name=nmi handler=panic error_code=0 vector=2 + +// #BP Breakpoint Exception (Vector 3) +default_entry_no_ist name=bp handler=breakpoint error_code=0 vector=3 + +// #OF Overflow Exception (Vector 4) +default_entry_no_ist name=of handler=panic error_code=0 vector=4 + +// #BR Bound-Range Exception (Vector 5) +default_entry_no_ist name=br handler=panic error_code=0 vector=5 + +// #UD Invalid-Opcode Exception (Vector 6) +default_entry_no_ist name=ud handler=panic error_code=0 vector=6 + +// #NM Device-Not-Available Exception (Vector 7) +default_entry_no_ist name=nm handler=panic error_code=0 vector=7 + +// #DF Double-Fault Exception (Vector 8) +default_entry_no_ist name=df handler=double_fault error_code=1 vector=8 + +// Coprocessor-Segment-Overrun Exception (Vector 9) +// No handler - reserved vector + +// #TS Invalid-TSS Exception (Vector 10) +default_entry_no_ist name=ts handler=panic error_code=1 vector=10 + +// #NP Segment-Not-Present Exception (Vector 11) +default_entry_no_ist name=np handler=panic error_code=1 vector=11 + +// #SS Stack Exception (Vector 12) +default_entry_no_ist name=ss handler=panic error_code=1 vector=12 + +// #GP General-Protection Exception (Vector 13) +default_entry_no_ist name=gp handler=general_protection error_code=1 vector=13 + +// #PF Page-Fault Exception (Vector 14) +default_entry_no_ist name=pf handler=page_fault error_code=1 vector=14 + +// Vector 15 not defined + +// #MF x87 Floating-Point Exception-Pending (Vector 16) +default_entry_no_ist name=mf handler=panic error_code=0 vector=16 + +// #AC Alignment-Check Exception (Vector 17) +default_entry_no_ist name=ac handler=panic error_code=1 vector=17 + +// #MC Machine-Check Exception (Vector 18) +default_entry_no_ist name=mce handler=panic error_code=0 vector=18 + +// #XF SIMD Floating-Point Exception (Vector 19) 
+default_entry_no_ist name=xf handler=panic error_code=0 vector=19 + +// Vector 20 not defined + +// #CP Control-Protection Exception (Vector 21) +default_entry_no_ist name=cp handler=panic error_code=1 vector=21 + +// Vectors 22-27 not defined + +// #VC VMM Communication Exception (Vector 29) +default_entry_no_ist name=vc handler=vmm_communication error_code=1 vector=29 + +// #SX Security Exception (Vector 30) +default_entry_no_ist name=sx handler=panic error_code=1 vector=30 + +// INT 0x80 system call handler +default_entry_no_ist name=int80 handler=system_call error_code=0 vector=0x80 + +// Interrupt injection vector +irq_entry name=int_inj vector=0x50 + +.popsection diff --git a/stage2/src/cpu/idt/mod.rs b/stage2/src/cpu/idt/mod.rs new file mode 100644 index 000000000..c80def1c5 --- /dev/null +++ b/stage2/src/cpu/idt/mod.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Thomas Leroy + +pub mod common; +pub mod stage2; +pub mod svsm; + +pub use common::{idt, idt_mut}; diff --git a/stage2/src/cpu/idt/stage2.rs b/stage2/src/cpu/idt/stage2.rs new file mode 100644 index 000000000..07da54d90 --- /dev/null +++ b/stage2/src/cpu/idt/stage2.rs @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::common::{idt_mut, DF_VECTOR, HV_VECTOR, VC_VECTOR}; +use crate::cpu::control_regs::read_cr2; +use crate::cpu::vc::{stage2_handle_vc_exception, stage2_handle_vc_exception_no_ghcb}; +use crate::cpu::X86ExceptionContext; +use core::arch::global_asm; + +pub fn early_idt_init_no_ghcb() { + let mut idt = idt_mut(); + idt.init(&raw const stage2_idt_handler_array_no_ghcb, 32); + idt.load(); +} + +pub fn early_idt_init() { + let mut idt = idt_mut(); + idt.init(&raw const stage2_idt_handler_array, 32); + idt.load(); +} + +#[no_mangle] +pub extern "C" fn stage2_generic_idt_handler(ctx: &mut X86ExceptionContext, vector: usize) { + match vector { + DF_VECTOR => { + let cr2 = read_cr2(); + let rip = ctx.frame.rip; + let rsp = ctx.frame.rsp; + panic!( + "Double-Fault at RIP {:#018x} RSP: {:#018x} CR2: {:#018x}", + rip, rsp, cr2 + ); + } + VC_VECTOR => stage2_handle_vc_exception(ctx).expect("Failed to handle #VC"), + HV_VECTOR => + // #HV does not require processing during stage 2 and can be + // completely ignored. 
+ {} + _ => { + let err = ctx.error_code; + let rip = ctx.frame.rip; + + panic!( + "Unhandled exception {} RIP {:#018x} error code: {:#018x}", + vector, rip, err + ); + } + } +} + +#[no_mangle] +pub extern "C" fn stage2_generic_idt_handler_no_ghcb(ctx: &mut X86ExceptionContext, vector: usize) { + match vector { + DF_VECTOR => { + let cr2 = read_cr2(); + let rip = ctx.frame.rip; + let rsp = ctx.frame.rsp; + panic!( + "Double-Fault at RIP {:#018x} RSP: {:#018x} CR2: {:#018x}", + rip, rsp, cr2 + ); + } + VC_VECTOR => stage2_handle_vc_exception_no_ghcb(ctx).expect("Failed to handle #VC"), + _ => { + let err = ctx.error_code; + let rip = ctx.frame.rip; + + panic!( + "Unhandled exception {} RIP {:#018x} error code: {:#018x}", + vector, rip, err + ); + } + } +} + +extern "C" { + static stage2_idt_handler_array: u8; + static stage2_idt_handler_array_no_ghcb: u8; +} + +global_asm!( + r#" + /* Early tage 2 handler array setup */ + .text + push_regs_no_ghcb: + pushq %rbx + pushq %rcx + pushq %rdx + pushq %rsi + movq 0x20(%rsp), %rsi + movq %rax, 0x20(%rsp) + pushq %rdi + pushq %rbp + pushq %r8 + pushq %r9 + pushq %r10 + pushq %r11 + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + + movq %rsp, %rdi + call stage2_generic_idt_handler_no_ghcb + + jmp generic_idt_handler_return + + .align 32 + .globl stage2_idt_handler_array_no_ghcb + stage2_idt_handler_array_no_ghcb: + i = 0 + .rept 32 + .align 32 + .if ((0x20027d00 >> i) & 1) == 0 + pushq $0 + .endif + pushq $i /* Vector Number */ + jmp push_regs_no_ghcb + i = i + 1 + .endr + + /* Stage 2 handler array setup */ + .text + push_regs_stage2: + pushq %rbx + pushq %rcx + pushq %rdx + pushq %rsi + movq 0x20(%rsp), %rsi + movq %rax, 0x20(%rsp) + pushq %rdi + pushq %rbp + pushq %r8 + pushq %r9 + pushq %r10 + pushq %r11 + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + + movq %rsp, %rdi + call stage2_generic_idt_handler + + jmp generic_idt_handler_return + + .align 32 + .globl stage2_idt_handler_array + stage2_idt_handler_array: + i = 0 + .rept 32 + .align 32 + .if ((0x20027d00 >> i) & 1) == 0 + pushq $0 + .endif + pushq $i /* Vector Number */ + jmp push_regs_stage2 + i = i + 1 + .endr + "#, + options(att_syntax) +); diff --git a/stage2/src/cpu/idt/svsm.rs b/stage2/src/cpu/idt/svsm.rs new file mode 100644 index 000000000..d6137ee1f --- /dev/null +++ b/stage2/src/cpu/idt/svsm.rs @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Authors: Joerg Roedel + +use super::super::control_regs::read_cr2; +use super::super::extable::handle_exception_table; +use super::super::percpu::{current_task, this_cpu}; +use super::super::tss::IST_DF; +use super::super::vc::handle_vc_exception; +use super::common::{ + idt_mut, user_mode, IdtEntry, IdtEventType, PageFaultError, AC_VECTOR, BP_VECTOR, BR_VECTOR, + CP_VECTOR, DB_VECTOR, DE_VECTOR, DF_VECTOR, GP_VECTOR, HV_VECTOR, INT_INJ_VECTOR, MCE_VECTOR, + MF_VECTOR, NMI_VECTOR, NM_VECTOR, NP_VECTOR, OF_VECTOR, PF_VECTOR, SS_VECTOR, SX_VECTOR, + TS_VECTOR, UD_VECTOR, VC_VECTOR, XF_VECTOR, +}; +use crate::address::VirtAddr; +use crate::cpu::registers::RFlags; +use crate::cpu::X86ExceptionContext; +use crate::debug::gdbstub::svsm_gdbstub::handle_debug_exception; +use crate::platform::SVSM_PLATFORM; +use crate::task::{is_task_fault, terminate}; +use core::arch::global_asm; + +use crate::syscall::*; +use syscall::*; + +extern "C" { + pub fn return_new_task(); + pub fn default_return(); + fn asm_entry_de(); + fn asm_entry_db(); + fn asm_entry_nmi(); + fn asm_entry_bp(); + 
fn asm_entry_of(); + fn asm_entry_br(); + fn asm_entry_ud(); + fn asm_entry_nm(); + fn asm_entry_df(); + fn asm_entry_ts(); + fn asm_entry_np(); + fn asm_entry_ss(); + fn asm_entry_gp(); + fn asm_entry_pf(); + fn asm_entry_mf(); + fn asm_entry_ac(); + fn asm_entry_mce(); + fn asm_entry_xf(); + fn asm_entry_cp(); + fn asm_entry_hv(); + fn asm_entry_vc(); + fn asm_entry_sx(); + fn asm_entry_int80(); + fn asm_entry_irq_int_inj(); + + pub static mut HV_DOORBELL_ADDR: usize; +} + +fn init_ist_vectors() { + idt_mut().set_entry(DF_VECTOR, IdtEntry::ist_entry(asm_entry_df, IST_DF.get())); +} + +pub fn early_idt_init() { + let mut idt = idt_mut(); + idt.set_entry(DE_VECTOR, IdtEntry::entry(asm_entry_de)); + idt.set_entry(DB_VECTOR, IdtEntry::entry(asm_entry_db)); + idt.set_entry(NMI_VECTOR, IdtEntry::entry(asm_entry_nmi)); + idt.set_entry(BP_VECTOR, IdtEntry::entry(asm_entry_bp)); + idt.set_entry(OF_VECTOR, IdtEntry::entry(asm_entry_of)); + idt.set_entry(BR_VECTOR, IdtEntry::entry(asm_entry_br)); + idt.set_entry(UD_VECTOR, IdtEntry::entry(asm_entry_ud)); + idt.set_entry(NM_VECTOR, IdtEntry::entry(asm_entry_nm)); + idt.set_entry(DF_VECTOR, IdtEntry::entry(asm_entry_df)); + idt.set_entry(TS_VECTOR, IdtEntry::entry(asm_entry_ts)); + idt.set_entry(NP_VECTOR, IdtEntry::entry(asm_entry_np)); + idt.set_entry(SS_VECTOR, IdtEntry::entry(asm_entry_ss)); + idt.set_entry(GP_VECTOR, IdtEntry::entry(asm_entry_gp)); + idt.set_entry(PF_VECTOR, IdtEntry::entry(asm_entry_pf)); + idt.set_entry(MF_VECTOR, IdtEntry::entry(asm_entry_mf)); + idt.set_entry(AC_VECTOR, IdtEntry::entry(asm_entry_ac)); + idt.set_entry(MCE_VECTOR, IdtEntry::entry(asm_entry_mce)); + idt.set_entry(XF_VECTOR, IdtEntry::entry(asm_entry_xf)); + idt.set_entry(CP_VECTOR, IdtEntry::entry(asm_entry_cp)); + idt.set_entry(HV_VECTOR, IdtEntry::entry(asm_entry_hv)); + idt.set_entry(VC_VECTOR, IdtEntry::entry(asm_entry_vc)); + idt.set_entry(SX_VECTOR, IdtEntry::entry(asm_entry_sx)); + idt.set_entry(INT_INJ_VECTOR, IdtEntry::entry(asm_entry_irq_int_inj)); + + // Interupts + idt.set_entry(0x80, IdtEntry::user_entry(asm_entry_int80)); + + // Load IDT + idt.load(); +} + +pub fn idt_init() { + // Set IST vectors + init_ist_vectors(); + + // Capture an address that can be used by assembly code to read the #HV + // doorbell page. The address of each CPU's doorbell page may be + // different, but the address of the field in the PerCpu structure that + // holds the actual pointer is constant across all CPUs, so that is the + // pointer that is actually captured. The address that is captured is + // stored as a usize instead of a typed value, because the declarations + // required for type safety here are cumbersome, and the assembly code + // that uses the value is not type safe in any case, so enforcing type + // safety on the pointer would offer no meaningful value. 
+ unsafe { + HV_DOORBELL_ADDR = this_cpu().hv_doorbell_addr() as usize; + }; +} + +// Debug handler +#[no_mangle] +extern "C" fn ex_handler_debug(ctx: &mut X86ExceptionContext) { + handle_debug_exception(ctx, DB_VECTOR); +} + +// Breakpoint handler +#[no_mangle] +extern "C" fn ex_handler_breakpoint(ctx: &mut X86ExceptionContext) { + handle_debug_exception(ctx, BP_VECTOR); +} + +// Doube-Fault handler +#[no_mangle] +extern "C" fn ex_handler_double_fault(ctxt: &mut X86ExceptionContext) { + let cr2 = read_cr2(); + let rip = ctxt.frame.rip; + let rsp = ctxt.frame.rsp; + + if user_mode(ctxt) { + log::error!( + "Double-Fault at RIP {:#018x} RSP: {:#018x} CR2: {:#018x} - Terminating task", + rip, + rsp, + cr2 + ); + terminate(); + } else { + panic!( + "Double-Fault at RIP {:#018x} RSP: {:#018x} CR2: {:#018x}", + rip, rsp, cr2 + ); + } +} + +// General-Protection handler +#[no_mangle] +extern "C" fn ex_handler_general_protection(ctxt: &mut X86ExceptionContext) { + let rip = ctxt.frame.rip; + let err = ctxt.error_code; + let rsp = ctxt.frame.rsp; + + if user_mode(ctxt) { + log::error!( + "Unhandled General-Protection-Fault at RIP {:#018x} error code: {:#018x} rsp: {:#018x} - Terminating task", + rip, err, rsp); + terminate(); + } else if !handle_exception_table(ctxt) { + panic!( + "Unhandled General-Protection-Fault at RIP {:#018x} error code: {:#018x} rsp: {:#018x}", + rip, err, rsp + ); + } +} + +// Page-Fault handler +#[no_mangle] +extern "C" fn ex_handler_page_fault(ctxt: &mut X86ExceptionContext, vector: usize) { + let cr2 = read_cr2(); + let rip = ctxt.frame.rip; + let err = ctxt.error_code; + let vaddr = VirtAddr::from(cr2); + + if user_mode(ctxt) { + let kill_task: bool = if is_task_fault(vaddr) { + current_task() + .fault(vaddr, (err & PageFaultError::W.bits() as usize) != 0) + .is_err() + } else { + true + }; + + if kill_task { + log::error!("Unexpected user-mode page-fault at RIP {:#018x} CR2: {:#018x} error code: {:#018x} - Terminating task", + rip, cr2, err); + terminate(); + } + } else if this_cpu() + .handle_pf( + VirtAddr::from(cr2), + (err & PageFaultError::W.bits() as usize) != 0, + ) + .is_err() + && !handle_exception_table(ctxt) + { + handle_debug_exception(ctxt, vector); + panic!( + "Unhandled Page-Fault at RIP {:#018x} CR2: {:#018x} error code: {:#018x}", + rip, cr2, err + ); + } +} + +// VMM Communication handler +#[no_mangle] +extern "C" fn ex_handler_vmm_communication(ctxt: &mut X86ExceptionContext, vector: usize) { + let rip = ctxt.frame.rip; + let code = ctxt.error_code; + + if let Err(err) = handle_vc_exception(ctxt, vector) { + log::error!("#VC handling error: {:?}", err); + if user_mode(ctxt) { + log::error!("Failed to handle #VC from user-mode at RIP {:#018x} code: {:#018x} - Terminating task", rip, code); + terminate(); + } else { + panic!( + "Failed to handle #VC from kernel-mode at RIP {:#018x} code: {:#018x}", + rip, code + ); + } + } +} + +// System Call SoftIRQ handler +#[no_mangle] +extern "C" fn ex_handler_system_call( + ctxt: &mut X86ExceptionContext, + vector: usize, + event_type: IdtEventType, +) { + // Ensure that this vector was not invoked as a hardware interrupt vector. 
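+    // Without this check, an externally injected interrupt arriving on the
+    // syscall vector would be handled as if user space had deliberately
+    // requested a system call, which must never be honored.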
+ if event_type.is_external_interrupt(vector) { + panic!("Syscall handler invoked as external interrupt!"); + } + + if !user_mode(ctxt) { + panic!("Syscall handler called from kernel mode!"); + } + + let Ok(input) = TryInto::::try_into(ctxt.regs.rax) else { + ctxt.regs.rax = !0; + return; + }; + + ctxt.regs.rax = match input { + SYS_HELLO => sys_hello(), + SYS_EXIT => sys_exit(), + _ => !0, + }; +} + +#[no_mangle] +pub extern "C" fn ex_handler_panic(ctx: &mut X86ExceptionContext, vector: usize) { + let rip = ctx.frame.rip; + let err = ctx.error_code; + let rsp = ctx.frame.rsp; + let ss = ctx.frame.ss; + panic!( + "Unhandled exception {} RIP {:#018x} error code: {:#018x} RSP: {:#018x} SS: {:#x}", + vector, rip, err, rsp, ss + ); +} + +#[no_mangle] +pub extern "C" fn common_isr_handler(_vector: usize) { + // Interrupt injection requests currently require no processing; they occur + // simply to ensure an exit from the guest. + + // Treat any unhandled interrupt as a spurious interrupt. + SVSM_PLATFORM.eoi(); +} + +global_asm!( + r#" + .set const_false, 0 + .set const_true, 1 + "#, + concat!(".set CFG_NOSMAP, const_", cfg!(feature = "nosmap")), + include_str!("../x86/smap.S"), + include_str!("entry.S"), + IF = const RFlags::IF.bits(), + options(att_syntax) +); diff --git a/stage2/src/cpu/irq_state.rs b/stage2/src/cpu/irq_state.rs new file mode 100644 index 000000000..429f99039 --- /dev/null +++ b/stage2/src/cpu/irq_state.rs @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Joerg Roedel + +use crate::cpu::percpu::this_cpu; +use crate::cpu::{irqs_disable, irqs_enable}; +use core::arch::asm; +use core::marker::PhantomData; +use core::sync::atomic::{AtomicBool, AtomicIsize, Ordering}; + +/// Interrupt flag in RFLAGS register +const EFLAGS_IF: u64 = 1 << 9; + +/// Unconditionally disable IRQs +/// +/// # Safety +/// +/// Callers need to take care of re-enabling IRQs. +#[inline(always)] +pub unsafe fn raw_irqs_disable() { + asm!("cli", options(att_syntax, preserves_flags, nomem)); +} + +/// Unconditionally enable IRQs +/// +/// # Safety +/// +/// Callers need to make sure it is safe to enable IRQs. e.g. that no data +/// structures or locks which are accessed in IRQ handlers are used after IRQs +/// have been enabled. +#[inline(always)] +pub unsafe fn raw_irqs_enable() { + asm!("sti", options(att_syntax, preserves_flags, nomem)); + + // Now that interrupts are enabled, process any #HV events that may be + // pending. + if let Some(doorbell) = this_cpu().hv_doorbell() { + doorbell.process_if_required(); + } +} + +/// Query IRQ state on current CPU +/// +/// # Returns +/// +/// `true` when IRQs are enabled, `false` otherwise +#[inline(always)] +#[must_use = "Unused irqs_enabled() result - meant to be irq_enable()?"] +pub fn irqs_enabled() -> bool { + // SAFETY: The inline assembly just reads the processors RFLAGS register + // and does not change any state. + let state: u64; + unsafe { + asm!("pushfq", + "popq {}", + out(reg) state, + options(att_syntax, preserves_flags)); + }; + + (state & EFLAGS_IF) == EFLAGS_IF +} + +/// Query IRQ state on current CPU +/// +/// # Returns +/// +/// `false` when IRQs are enabled, `true` otherwise +#[inline(always)] +#[must_use = "Unused irqs_disabled() result - meant to be irq_disable()?"] +pub fn irqs_disabled() -> bool { + !irqs_enabled() +} + +/// This structure keeps track of PerCpu IRQ states. It tracks the original IRQ +/// state and how deep IRQ-disable calls have been nested. 
The use of atomics +/// is necessary for interior mutability and to make state modifications safe +/// wrt. to IRQs. +/// +/// The original state needs to be stored to not accidentially enable IRQs in +/// contexts which have IRQs disabled by other means, e.g. in an exception or +/// NMI/HV context. +#[derive(Debug, Default)] +pub struct IrqState { + /// IRQ state when count was `0` + state: AtomicBool, + /// Depth of IRQ-disabled nesting + count: AtomicIsize, + /// Make the type !Send + !Sync + phantom: PhantomData<*const ()>, +} + +impl IrqState { + /// Create a new instance of `IrqState` + pub fn new() -> Self { + Self { + state: AtomicBool::new(false), + count: AtomicIsize::new(0), + phantom: PhantomData, + } + } + + /// Increase IRQ-disable nesting level by 1. The method will disable IRQs. + /// + /// # Safety + /// + /// The caller needs to make sure to match the number of `disable` calls + /// with the number of `enable` calls. + #[inline(always)] + pub unsafe fn disable(&self) { + let state = irqs_enabled(); + + raw_irqs_disable(); + let val = self.count.fetch_add(1, Ordering::Relaxed); + + assert!(val >= 0); + + if val == 0 { + self.state.store(state, Ordering::Relaxed) + } + } + + /// Decrease IRQ-disable nesting level by 1. The method will restore the + /// original IRQ state when the nesting level reaches 0. + /// + /// # Safety + /// + /// The caller needs to make sure to match the number of `disable` calls + /// with the number of `enable` calls. + #[inline(always)] + pub unsafe fn enable(&self) { + debug_assert!(irqs_disabled()); + + let val = self.count.fetch_sub(1, Ordering::Relaxed); + + assert!(val > 0); + + if val == 1 { + let state = self.state.load(Ordering::Relaxed); + if state { + raw_irqs_enable(); + } + } + } + + /// Returns the current nesting count + /// + /// # Returns + /// + /// Levels of IRQ-disable nesting currently active + pub fn count(&self) -> isize { + self.count.load(Ordering::Relaxed) + } + + /// Changes whether interrupts will be enabled when the nesting count + /// drops to zero. + /// + /// # Safety + /// + /// The caller must ensure that the current nesting count is non-zero, + /// and must ensure that the specified value is appropriate for the + /// current environment. + pub unsafe fn set_restore_state(&self, enabled: bool) { + assert!(self.count.load(Ordering::Relaxed) != 0); + self.state.store(enabled, Ordering::Relaxed); + } +} + +impl Drop for IrqState { + /// This struct should never be dropped. Add a debug check in case it is + /// dropped anyway. + fn drop(&mut self) { + let count = self.count.load(Ordering::Relaxed); + assert_eq!(count, 0); + } +} + +/// And IRQ guard which saves the current IRQ state and disabled interrupts +/// upon creation. When the guard goes out of scope the previous IRQ state is +/// restored. +/// +/// The struct implements the `Default` and `Drop` traits for easy use. +#[derive(Debug)] +#[must_use = "if unused previous IRQ state will be immediatly restored"] +pub struct IrqGuard { + /// Make the type !Send + !Sync + phantom: PhantomData<*const ()>, +} + +impl IrqGuard { + pub fn new() -> Self { + // SAFETY: Safe because the struct implements `Drop`, which + // restores the IRQ state saved here. + unsafe { + irqs_disable(); + } + + Self { + phantom: PhantomData, + } + } +} + +impl Default for IrqGuard { + fn default() -> Self { + IrqGuard::new() + } +} + +impl Drop for IrqGuard { + fn drop(&mut self) { + // SAFETY: Safe because the irqs_enabled() call matches the + // irqs_disabled() call during struct creation. 
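+        //
+        // Typical usage (sketch): `let _guard = IrqGuard::new();` keeps IRQs
+        // disabled for the guard's scope; this drop restores the prior state.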
+ unsafe { + irqs_enable(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn irq_enable_disable() { + unsafe { + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + assert!(irqs_enabled()); + raw_irqs_disable(); + assert!(irqs_disabled()); + if was_enabled { + raw_irqs_enable(); + } + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn irq_state() { + unsafe { + let state = IrqState::new(); + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + state.disable(); + assert!(irqs_disabled()); + state.disable(); + state.enable(); + assert!(irqs_disabled()); + state.enable(); + assert!(irqs_enabled()); + if !was_enabled { + raw_irqs_disable(); + } + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn irq_guard_test() { + unsafe { + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + assert!(irqs_enabled()); + let g1 = IrqGuard::new(); + assert!(irqs_disabled()); + drop(g1); + assert!(irqs_enabled()); + if !was_enabled { + raw_irqs_disable(); + } + } + } +} diff --git a/stage2/src/cpu/mem.rs b/stage2/src/cpu/mem.rs new file mode 100644 index 000000000..ffed5f2bd --- /dev/null +++ b/stage2/src/cpu/mem.rs @@ -0,0 +1,39 @@ +use core::arch::asm; + +/// Copy `size` bytes from `src` to `dst`. +/// +/// # Safety +/// +/// This function has all the safety requirements of `core::ptr::copy` except +/// that data races (both on `src` and `dst`) are explicitly permitted. +#[inline(always)] +pub unsafe fn copy_bytes(src: usize, dst: usize, size: usize) { + unsafe { + asm!( + "rep movsb", + inout("rsi") src => _, + inout("rdi") dst => _, + inout("rcx") size => _, + options(nostack), + ); + } +} + +/// Set `size` bytes at `dst` to `val`. +/// +/// # Safety +/// +/// This function has all the safety requirements of `core::ptr::write_bytes` except +/// that data races are explicitly permitted. 
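+///
+/// # Example
+///
+/// ```ignore
+/// // Hedged sketch: zero one 4 KiB page at a hypothetical, already-mapped
+/// // virtual address `dst`.
+/// unsafe { write_bytes(dst, 4096, 0) };
+/// ```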
+#[inline(always)] +pub unsafe fn write_bytes(dst: usize, size: usize, value: u8) { + unsafe { + asm!( + "rep stosb", + inout("rdi") dst => _, + inout("rcx") size => _, + in("al") value, + options(nostack), + ); + } +} diff --git a/stage2/src/cpu/mod.rs b/stage2/src/cpu/mod.rs new file mode 100644 index 000000000..4cb4882dc --- /dev/null +++ b/stage2/src/cpu/mod.rs @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +pub mod apic; +pub mod control_regs; +pub mod cpuid; +pub mod efer; +pub mod extable; +pub mod features; +pub mod gdt; +pub mod idt; +pub mod irq_state; +pub mod mem; +pub mod msr; +pub mod percpu; +pub mod registers; +pub mod smp; +pub mod sse; +pub mod tlb; +pub mod tss; +pub mod vc; +pub mod vmsa; +pub mod x86; + +pub use apic::LocalApic; +pub use gdt::{gdt, gdt_mut}; +pub use idt::common::X86ExceptionContext; +pub use irq_state::{irqs_disabled, irqs_enabled, IrqGuard, IrqState}; +pub use percpu::{irq_nesting_count, irqs_disable, irqs_enable}; +pub use registers::{X86GeneralRegs, X86InterruptFrame, X86SegmentRegs}; +pub use tlb::*; diff --git a/stage2/src/cpu/msr.rs b/stage2/src/cpu/msr.rs new file mode 100644 index 000000000..64dfbc1e3 --- /dev/null +++ b/stage2/src/cpu/msr.rs @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use core::arch::asm; + +pub const EFER: u32 = 0xC000_0080; +pub const SEV_STATUS: u32 = 0xC001_0131; +pub const SEV_GHCB: u32 = 0xC001_0130; +pub const MSR_GS_BASE: u32 = 0xC000_0101; + +pub fn read_msr(msr: u32) -> u64 { + let eax: u32; + let edx: u32; + + unsafe { + asm!("rdmsr", + in("ecx") msr, + out("eax") eax, + out("edx") edx, + options(att_syntax)); + } + (eax as u64) | (edx as u64) << 32 +} + +pub fn write_msr(msr: u32, val: u64) { + let eax = (val & 0x0000_0000_ffff_ffff) as u32; + let edx = (val >> 32) as u32; + + unsafe { + asm!("wrmsr", + in("ecx") msr, + in("eax") eax, + in("edx") edx, + options(att_syntax)); + } +} + +pub fn rdtsc() -> u64 { + let eax: u32; + let edx: u32; + + unsafe { + asm!("rdtsc", + out("eax") eax, + out("edx") edx, + options(att_syntax, nomem, nostack)); + } + (eax as u64) | (edx as u64) << 32 +} + +#[derive(Debug, Clone, Copy)] +pub struct RdtscpOut { + pub timestamp: u64, + pub pid: u32, +} + +pub fn rdtscp() -> RdtscpOut { + let eax: u32; + let edx: u32; + let ecx: u32; + + unsafe { + asm!("rdtscp", + out("eax") eax, + out("ecx") ecx, + out("edx") edx, + options(att_syntax, nomem, nostack)); + } + RdtscpOut { + timestamp: (eax as u64) | (edx as u64) << 32, + pid: ecx, + } +} + +pub fn read_flags() -> u64 { + let rax: u64; + unsafe { + asm!( + r#" + pushfq + pop %rax + "#, + out("rax") rax, + options(att_syntax)); + } + rax +} diff --git a/stage2/src/cpu/percpu.rs b/stage2/src/cpu/percpu.rs new file mode 100644 index 000000000..bf5ea49b8 --- /dev/null +++ b/stage2/src/cpu/percpu.rs @@ -0,0 +1,1027 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; + +use super::gdt_mut; +use super::tss::{X86Tss, IST_DF}; +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::cpu::idt::common::INT_INJ_VECTOR; +use crate::cpu::tss::TSS_LIMIT; +use crate::cpu::vmsa::{init_guest_vmsa, init_svsm_vmsa}; +use crate::cpu::{IrqState, LocalApic}; +use crate::error::{ApicError, SvsmError}; +use crate::locking::{LockGuard, RWLock, RWLockIrqSafe, SpinLock}; +use 
crate::mm::pagetable::{PTEntryFlags, PageTable}; +use crate::mm::virtualrange::VirtualRange; +use crate::mm::vm::{Mapping, VMKernelStack, VMPhysMem, VMRMapping, VMReserved, VMR}; +use crate::mm::{ + virt_to_phys, PageBox, SVSM_PERCPU_BASE, SVSM_PERCPU_CAA_BASE, SVSM_PERCPU_END, + SVSM_PERCPU_TEMP_BASE_2M, SVSM_PERCPU_TEMP_BASE_4K, SVSM_PERCPU_TEMP_END_2M, + SVSM_PERCPU_TEMP_END_4K, SVSM_PERCPU_VMSA_BASE, SVSM_STACKS_INIT_TASK, SVSM_STACK_IST_DF_BASE, +}; +use crate::platform::{SvsmPlatform, SVSM_PLATFORM}; +use crate::sev::ghcb::{GhcbPage, GHCB}; +use crate::sev::hv_doorbell::{allocate_hv_doorbell_page, HVDoorbell}; +use crate::sev::msr_protocol::{hypervisor_ghcb_features, GHCBHvFeatures}; +use crate::sev::utils::RMPFlags; +use crate::sev::vmsa::{VMSAControl, VmsaPage}; +use crate::task::{schedule, schedule_task, RunQueue, Task, TaskPointer, WaitQueue}; +use crate::types::{PAGE_SHIFT, PAGE_SHIFT_2M, PAGE_SIZE, PAGE_SIZE_2M, SVSM_TR_FLAGS, SVSM_TSS}; +use crate::utils::MemoryRegion; +use alloc::sync::Arc; +use alloc::vec::Vec; +use core::cell::{Cell, OnceCell, Ref, RefCell, RefMut, UnsafeCell}; +use core::mem::size_of; +use core::ptr; +use core::slice::Iter; +use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use cpuarch::vmsa::{VMSASegment, VMSA}; + +#[derive(Copy, Clone, Debug)] +pub struct PerCpuInfo { + apic_id: u32, + cpu_shared: &'static PerCpuShared, +} + +impl PerCpuInfo { + const fn new(apic_id: u32, cpu_shared: &'static PerCpuShared) -> Self { + Self { + apic_id, + cpu_shared, + } + } + + pub fn as_cpu_ref(&self) -> &'static PerCpuShared { + self.cpu_shared + } +} + +// PERCPU areas virtual addresses into shared memory +pub static PERCPU_AREAS: PerCpuAreas = PerCpuAreas::new(); + +// We use an UnsafeCell to allow for a static with interior +// mutability. Normally, we would need to guarantee synchronization +// on the backing datatype, but this is not needed because writes to +// the structure only occur at initialization, from CPU 0, and reads +// should only occur after all writes are done. +#[derive(Debug)] +pub struct PerCpuAreas { + areas: UnsafeCell>, +} + +unsafe impl Sync for PerCpuAreas {} + +impl PerCpuAreas { + const fn new() -> Self { + Self { + areas: UnsafeCell::new(Vec::new()), + } + } + + unsafe fn push(&self, info: PerCpuInfo) { + let ptr = self.areas.get().as_mut().unwrap(); + ptr.push(info); + } + + pub fn iter(&self) -> Iter<'_, PerCpuInfo> { + let ptr = unsafe { self.areas.get().as_ref().unwrap() }; + ptr.iter() + } + + // Fails if no such area exists or its address is NULL + pub fn get(&self, apic_id: u32) -> Option<&'static PerCpuShared> { + // For this to not produce UB the only invariant we must + // uphold is that there are no mutations or mutable aliases + // going on when casting via as_ref(). This only happens via + // Self::push(), which is intentionally unsafe and private. 
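+        // By the time lookups happen, CPU bring-up is expected to have pushed
+        // all entries, so a miss here simply means an unknown APIC ID.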
+ let ptr = unsafe { self.areas.get().as_ref().unwrap() }; + ptr.iter() + .find(|info| info.apic_id == apic_id) + .map(|info| info.cpu_shared) + } +} + +#[derive(Debug)] +struct IstStacks { + double_fault_stack: Cell>, +} + +impl IstStacks { + const fn new() -> Self { + IstStacks { + double_fault_stack: Cell::new(None), + } + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub struct GuestVmsaRef { + vmsa: Option, + caa: Option, + generation: u64, + gen_in_use: u64, +} + +impl GuestVmsaRef { + pub const fn new() -> Self { + GuestVmsaRef { + vmsa: None, + caa: None, + generation: 1, + gen_in_use: 0, + } + } + + pub fn needs_update(&self) -> bool { + self.generation != self.gen_in_use + } + + pub fn update_vmsa(&mut self, paddr: Option) { + self.vmsa = paddr; + self.generation += 1; + } + + pub fn update_caa(&mut self, paddr: Option) { + self.caa = paddr; + self.generation += 1; + } + + pub fn update_vmsa_caa(&mut self, vmsa: Option, caa: Option) { + self.vmsa = vmsa; + self.caa = caa; + self.generation += 1; + } + + pub fn set_updated(&mut self) { + self.gen_in_use = self.generation; + } + + pub fn vmsa_phys(&self) -> Option { + self.vmsa + } + + pub fn caa_phys(&self) -> Option { + self.caa + } + + pub fn vmsa(&mut self) -> &mut VMSA { + assert!(self.vmsa.is_some()); + // SAFETY: this function takes &mut self, so only one mutable + // reference to the underlying VMSA can exist. + unsafe { SVSM_PERCPU_VMSA_BASE.as_mut_ptr::().as_mut().unwrap() } + } + + pub fn caa_addr(&self) -> Option { + let caa_phys = self.caa_phys()?; + let offset = caa_phys.page_offset(); + + Some(SVSM_PERCPU_CAA_BASE + offset) + } +} + +#[derive(Debug)] +pub struct PerCpuShared { + apic_id: u32, + guest_vmsa: SpinLock, + online: AtomicBool, + ipi_irr: [AtomicU32; 8], + ipi_pending: AtomicBool, + nmi_pending: AtomicBool, +} + +impl PerCpuShared { + fn new(apic_id: u32) -> Self { + PerCpuShared { + apic_id, + guest_vmsa: SpinLock::new(GuestVmsaRef::new()), + online: AtomicBool::new(false), + ipi_irr: core::array::from_fn(|_| AtomicU32::new(0)), + ipi_pending: AtomicBool::new(false), + nmi_pending: AtomicBool::new(false), + } + } + + pub const fn apic_id(&self) -> u32 { + self.apic_id + } + + pub fn update_guest_vmsa_caa(&self, vmsa: PhysAddr, caa: PhysAddr) { + let mut locked = self.guest_vmsa.lock(); + locked.update_vmsa_caa(Some(vmsa), Some(caa)); + } + + pub fn update_guest_vmsa(&self, vmsa: PhysAddr) { + let mut locked = self.guest_vmsa.lock(); + locked.update_vmsa(Some(vmsa)); + } + + pub fn update_guest_caa(&self, caa: PhysAddr) { + let mut locked = self.guest_vmsa.lock(); + locked.update_caa(Some(caa)); + } + + pub fn clear_guest_vmsa_if_match(&self, paddr: PhysAddr) { + let mut locked = self.guest_vmsa.lock(); + if locked.vmsa.is_none() { + return; + } + + let vmsa_phys = locked.vmsa_phys(); + if vmsa_phys.unwrap() == paddr { + locked.update_vmsa(None); + } + } + + pub fn set_online(&self) { + self.online.store(true, Ordering::Release); + } + + pub fn is_online(&self) -> bool { + self.online.load(Ordering::Acquire) + } + + pub fn request_ipi(&self, vector: u8) { + let index = vector >> 5; + let bit = 1u32 << (vector & 31); + // Request the IPI via the IRR vector before signaling that an IPI has + // been requested. 
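+        // Each IRR word covers 32 vectors; e.g. vector 0x41 (65) lands in
+        // word 65 >> 5 = 2 as bit 65 & 31 = 1.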
+ self.ipi_irr[index as usize].fetch_or(bit, Ordering::Relaxed); + self.ipi_pending.store(true, Ordering::Release); + } + + pub fn request_nmi(&self) { + self.nmi_pending.store(true, Ordering::Relaxed); + self.ipi_pending.store(true, Ordering::Release); + } + + pub fn ipi_pending(&self) -> bool { + self.ipi_pending.swap(false, Ordering::Acquire) + } + + pub fn ipi_irr_vector(&self, index: usize) -> u32 { + self.ipi_irr[index].swap(0, Ordering::Relaxed) + } + + pub fn nmi_pending(&self) -> bool { + self.nmi_pending.swap(false, Ordering::Relaxed) + } +} + +const _: () = assert!(size_of::() <= PAGE_SIZE); + +/// CPU-local data. +/// +/// This type is not [`Sync`], as its contents will only be accessed from the +/// local CPU, much like thread-local data in an std environment. The only +/// part of the struct that may be accessed from a different CPU is the +/// `shared` field, a reference to which will be stored in [`PERCPU_AREAS`]. +#[derive(Debug)] +pub struct PerCpu { + /// Per-CPU storage that might be accessed from other CPUs. + shared: PerCpuShared, + + /// PerCpu IRQ state tracking + irq_state: IrqState, + + pgtbl: RefCell>, + tss: Cell, + svsm_vmsa: OnceCell, + reset_ip: Cell, + /// PerCpu Virtual Memory Range + vm_range: VMR, + /// Address allocator for per-cpu 4k temporary mappings + pub vrange_4k: RefCell, + /// Address allocator for per-cpu 2m temporary mappings + pub vrange_2m: RefCell, + /// Task list that has been assigned for scheduling on this CPU + runqueue: RWLockIrqSafe, + /// WaitQueue for request processing + request_waitqueue: RefCell, + /// Local APIC state for APIC emulation if enabled + apic: RefCell>, + + /// GHCB page for this CPU. + ghcb: OnceCell, + + /// `#HV` doorbell page for this CPU. + hv_doorbell: Cell>, + + init_stack: Cell>, + ist: IstStacks, + + /// Stack boundaries of the currently running task. + current_stack: Cell>, +} + +impl PerCpu { + /// Creates a new default [`PerCpu`] struct. + fn new(apic_id: u32) -> Self { + Self { + pgtbl: RefCell::new(None), + irq_state: IrqState::new(), + tss: Cell::new(X86Tss::new()), + svsm_vmsa: OnceCell::new(), + reset_ip: Cell::new(0xffff_fff0), + vm_range: { + let mut vmr = VMR::new(SVSM_PERCPU_BASE, SVSM_PERCPU_END, PTEntryFlags::GLOBAL); + vmr.set_per_cpu(true); + vmr + }, + + vrange_4k: RefCell::new(VirtualRange::new()), + vrange_2m: RefCell::new(VirtualRange::new()), + runqueue: RWLockIrqSafe::new(RunQueue::new()), + request_waitqueue: RefCell::new(WaitQueue::new()), + apic: RefCell::new(None), + + shared: PerCpuShared::new(apic_id), + ghcb: OnceCell::new(), + hv_doorbell: Cell::new(None), + init_stack: Cell::new(None), + ist: IstStacks::new(), + current_stack: Cell::new(MemoryRegion::new(VirtAddr::null(), 0)), + } + } + + /// Creates a new default [`PerCpu`] struct, allocates it via the page + /// allocator and adds it to the global per-cpu area list. + pub fn alloc(apic_id: u32) -> Result<&'static Self, SvsmError> { + let page = PageBox::try_new(Self::new(apic_id))?; + let percpu = PageBox::leak(page); + unsafe { PERCPU_AREAS.push(PerCpuInfo::new(apic_id, &percpu.shared)) }; + Ok(percpu) + } + + pub fn shared(&self) -> &PerCpuShared { + &self.shared + } + + /// Disables IRQs on the current CPU. Keeps track of the nesting level and + /// the original IRQ state. + /// + /// # Safety + /// + /// Caller needs to make sure to match every `disable()` call with an + /// `enable()` call. 
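+    ///
+    /// # Example (sketch)
+    ///
+    /// ```ignore
+    /// // Hedged sketch of the expected pairing; callers usually prefer the
+    /// // `IrqGuard` wrapper over raw calls.
+    /// unsafe {
+    ///     this_cpu().irqs_disable();
+    ///     // ... code that must not observe interrupts ...
+    ///     this_cpu().irqs_enable();
+    /// }
+    /// ```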
+ #[inline(always)] + pub unsafe fn irqs_disable(&self) { + self.irq_state.disable(); + } + + /// Reduces IRQ-disable nesting level on the current CPU and restores the + /// original IRQ state when the level reaches 0. + /// + /// # Safety + /// + /// Caller needs to make sure to match every `disable()` call with an + /// `enable()` call. + #[inline(always)] + pub unsafe fn irqs_enable(&self) { + self.irq_state.enable(); + } + + /// Get IRQ-disable nesting count on the current CPU + /// + /// # Returns + /// + /// Current nesting depth of irq_disable() calls. + pub fn irq_nesting_count(&self) -> isize { + self.irq_state.count() + } + + /// Sets up the CPU-local GHCB page. + pub fn setup_ghcb(&self) -> Result<(), SvsmError> { + let page = GhcbPage::new()?; + self.ghcb + .set(page) + .expect("Attempted to reinitialize the GHCB"); + Ok(()) + } + + fn ghcb(&self) -> Option<&GhcbPage> { + self.ghcb.get() + } + + pub fn hv_doorbell(&self) -> Option<&'static HVDoorbell> { + self.hv_doorbell.get() + } + + /// Gets a pointer to the location of the HV doorbell pointer in the + /// PerCpu structure. Pointers and references have the same layout, so + /// the return type is equivalent to `*const *const HVDoorbell`. + pub fn hv_doorbell_addr(&self) -> *const &'static HVDoorbell { + self.hv_doorbell.as_ptr().cast() + } + + pub fn get_top_of_stack(&self) -> VirtAddr { + self.init_stack.get().unwrap() + } + + pub fn get_top_of_df_stack(&self) -> VirtAddr { + self.ist.double_fault_stack.get().unwrap() + } + + pub fn get_current_stack(&self) -> MemoryRegion { + self.current_stack.get() + } + + pub fn get_apic_id(&self) -> u32 { + self.shared().apic_id() + } + + pub fn init_page_table(&self, pgtable: PageBox) -> Result<(), SvsmError> { + self.vm_range.initialize()?; + self.set_pgtable(PageBox::leak(pgtable)); + + Ok(()) + } + + pub fn set_pgtable(&self, pgtable: &'static mut PageTable) { + *self.pgtbl.borrow_mut() = Some(pgtable); + } + + fn allocate_stack(&self, base: VirtAddr) -> Result { + let stack = VMKernelStack::new()?; + let top_of_stack = stack.top_of_stack(base); + let mapping = Arc::new(Mapping::new(stack)); + + self.vm_range.insert_at(base, mapping)?; + + Ok(top_of_stack) + } + + fn allocate_init_stack(&self) -> Result<(), SvsmError> { + let init_stack = Some(self.allocate_stack(SVSM_STACKS_INIT_TASK)?); + self.init_stack.set(init_stack); + Ok(()) + } + + fn allocate_ist_stacks(&self) -> Result<(), SvsmError> { + let double_fault_stack = self.allocate_stack(SVSM_STACK_IST_DF_BASE)?; + self.ist.double_fault_stack.set(Some(double_fault_stack)); + Ok(()) + } + + pub fn get_pgtable(&self) -> RefMut<'_, PageTable> { + RefMut::map(self.pgtbl.borrow_mut(), |pgtbl| { + &mut **pgtbl.as_mut().unwrap() + }) + } + + /// Registers an already set up GHCB page for this CPU. + /// + /// # Panics + /// + /// Panics if the GHCB for this CPU has not been set up via + /// [`PerCpu::setup_ghcb()`]. + pub fn register_ghcb(&self) -> Result<(), SvsmError> { + self.ghcb().unwrap().register() + } + + fn setup_hv_doorbell(&self) -> Result<(), SvsmError> { + let doorbell = allocate_hv_doorbell_page(current_ghcb())?; + assert!( + self.hv_doorbell.get().is_none(), + "Attempted to reinitialize the HV doorbell page" + ); + self.hv_doorbell.set(Some(doorbell)); + Ok(()) + } + + /// Configures the HV doorbell page if restricted injection is enabled. + /// + /// # Panics + /// + /// Panics if this function is called more than once for a given CPU and + /// restricted injection is enabled. 
+ pub fn configure_hv_doorbell(&self) -> Result<(), SvsmError> { + // #HV doorbell configuration is only required if this system will make + // use of restricted injection. + if hypervisor_ghcb_features().contains(GHCBHvFeatures::SEV_SNP_RESTR_INJ) { + self.setup_hv_doorbell()?; + } + Ok(()) + } + + fn setup_tss(&self) { + let double_fault_stack = self.get_top_of_df_stack(); + let mut tss = self.tss.get(); + tss.set_ist_stack(IST_DF, double_fault_stack); + self.tss.set(tss); + } + + pub fn map_self_stage2(&self) -> Result<(), SvsmError> { + let vaddr = VirtAddr::from(ptr::from_ref(self)); + let paddr = virt_to_phys(vaddr); + let flags = PTEntryFlags::data(); + self.get_pgtable().map_4k(SVSM_PERCPU_BASE, paddr, flags) + } + + pub fn map_self(&self) -> Result<(), SvsmError> { + let vaddr = VirtAddr::from(ptr::from_ref(self)); + let paddr = virt_to_phys(vaddr); + let self_mapping = Arc::new(VMPhysMem::new_mapping(paddr, PAGE_SIZE, true)); + self.vm_range.insert_at(SVSM_PERCPU_BASE, self_mapping)?; + Ok(()) + } + + fn initialize_vm_ranges(&self) -> Result<(), SvsmError> { + let size_4k = SVSM_PERCPU_TEMP_END_4K - SVSM_PERCPU_TEMP_BASE_4K; + let temp_mapping_4k = Arc::new(VMReserved::new_mapping(size_4k)); + self.vm_range + .insert_at(SVSM_PERCPU_TEMP_BASE_4K, temp_mapping_4k)?; + + let size_2m = SVSM_PERCPU_TEMP_END_2M - SVSM_PERCPU_TEMP_BASE_2M; + let temp_mapping_2m = Arc::new(VMReserved::new_mapping(size_2m)); + self.vm_range + .insert_at(SVSM_PERCPU_TEMP_BASE_2M, temp_mapping_2m)?; + + Ok(()) + } + + fn finish_page_table(&self) { + let mut pgtable = self.get_pgtable(); + self.vm_range.populate(&mut pgtable); + } + + pub fn dump_vm_ranges(&self) { + self.vm_range.dump_ranges(); + } + + pub fn setup( + &self, + platform: &dyn SvsmPlatform, + pgtable: PageBox, + ) -> Result<(), SvsmError> { + self.init_page_table(pgtable)?; + + // Map PerCpu data in own page-table + self.map_self()?; + + // Reserve ranges for temporary mappings + self.initialize_vm_ranges()?; + + // Allocate per-cpu init stack + self.allocate_init_stack()?; + + // Allocate IST stacks + self.allocate_ist_stacks()?; + + // Setup TSS + self.setup_tss(); + + // Initialize allocator for temporary mappings + self.virt_range_init(); + + self.finish_page_table(); + + // Complete platform-specific initialization. + platform.setup_percpu(self)?; + + Ok(()) + } + + // Setup code which needs to run on the target CPU + pub fn setup_on_cpu(&self, platform: &dyn SvsmPlatform) -> Result<(), SvsmError> { + platform.setup_percpu_current(self) + } + + pub fn setup_idle_task(&self, entry: extern "C" fn()) -> Result<(), SvsmError> { + let idle_task = Task::create(self, entry)?; + self.runqueue.lock_read().set_idle_task(idle_task); + Ok(()) + } + + pub fn load_pgtable(&self) { + self.get_pgtable().load(); + } + + pub fn load_tss(&self) { + // SAFETY: this can only produce UB if someone else calls self.tss.set + // () while this new reference is alive, which cannot happen as this + // data is local to this CPU. We need to get a reference to the value + // inside the Cell because the address of the TSS will be used. If we + // did self.tss.get(), then the address of a temporary copy would be + // used. + let tss = unsafe { &*self.tss.as_ptr() }; + gdt_mut().load_tss(tss); + } + + pub fn load(&self) { + self.load_pgtable(); + self.load_tss(); + } + + pub fn set_reset_ip(&self, reset_ip: u64) { + self.reset_ip.set(reset_ip); + } + + /// Allocates and initializes a new VMSA for this CPU. Returns its + /// physical address and SEV features. 
Returns an error if allocation + /// fails of this CPU's VMSA was already initialized. + pub fn alloc_svsm_vmsa(&self, vtom: u64, start_rip: u64) -> Result<(PhysAddr, u64), SvsmError> { + if self.svsm_vmsa.get().is_some() { + // FIXME: add a more explicit error variant for this condition + return Err(SvsmError::Mem); + } + + let mut vmsa = VmsaPage::new(RMPFlags::GUEST_VMPL)?; + let paddr = vmsa.paddr(); + + // Initialize VMSA + init_svsm_vmsa(&mut vmsa, vtom); + vmsa.tr = self.vmsa_tr_segment(); + vmsa.rip = start_rip; + vmsa.rsp = self.get_top_of_stack().into(); + vmsa.cr3 = self.get_pgtable().cr3_value().into(); + vmsa.enable(); + + let sev_features = vmsa.sev_features; + + // We already checked that the VMSA is unset + self.svsm_vmsa.set(vmsa).unwrap(); + + Ok((paddr, sev_features)) + } + + pub fn unmap_guest_vmsa(&self) { + assert!(self.shared().apic_id == this_cpu().get_apic_id()); + // Ignore errors - the mapping might or might not be there + let _ = self.vm_range.remove(SVSM_PERCPU_VMSA_BASE); + } + + pub fn map_guest_vmsa(&self, paddr: PhysAddr) -> Result<(), SvsmError> { + assert!(self.shared().apic_id == this_cpu().get_apic_id()); + let vmsa_mapping = Arc::new(VMPhysMem::new_mapping(paddr, PAGE_SIZE, true)); + self.vm_range + .insert_at(SVSM_PERCPU_VMSA_BASE, vmsa_mapping)?; + + Ok(()) + } + + pub fn guest_vmsa_ref(&self) -> LockGuard<'_, GuestVmsaRef> { + self.shared().guest_vmsa.lock() + } + + pub fn alloc_guest_vmsa(&self) -> Result<(), SvsmError> { + // Enable alternate injection if the hypervisor supports it. + let use_alternate_injection = SVSM_PLATFORM.query_apic_registration_state(); + if use_alternate_injection { + self.apic.replace(Some(LocalApic::new())); + + // Configure the interrupt injection vector. + let ghcb = self.ghcb().unwrap(); + ghcb.configure_interrupt_injection(INT_INJ_VECTOR)?; + } + + let mut vmsa = VmsaPage::new(RMPFlags::GUEST_VMPL)?; + let paddr = vmsa.paddr(); + + init_guest_vmsa(&mut vmsa, self.reset_ip.get(), use_alternate_injection); + + self.shared().update_guest_vmsa(paddr); + let _ = VmsaPage::leak(vmsa); + + Ok(()) + } + + /// Returns a shared reference to the local APIC, or `None` if APIC + /// emulation is not enabled. + fn apic(&self) -> Option> { + let apic = self.apic.borrow(); + Ref::filter_map(apic, Option::as_ref).ok() + } + + /// Returns a mutable reference to the local APIC, or `None` if APIC + /// emulation is not enabled. 
+ fn apic_mut(&self) -> Option> { + let apic = self.apic.borrow_mut(); + RefMut::filter_map(apic, Option::as_mut).ok() + } + + pub fn unmap_caa(&self) { + // Ignore errors - the mapping might or might not be there + let _ = self.vm_range.remove(SVSM_PERCPU_CAA_BASE); + } + + pub fn map_guest_caa(&self, paddr: PhysAddr) -> Result<(), SvsmError> { + self.unmap_caa(); + + let caa_mapping = Arc::new(VMPhysMem::new_mapping(paddr, PAGE_SIZE, true)); + self.vm_range.insert_at(SVSM_PERCPU_CAA_BASE, caa_mapping)?; + + Ok(()) + } + + pub fn disable_apic_emulation(&self) { + if let Some(mut apic) = self.apic_mut() { + let mut vmsa_ref = self.guest_vmsa_ref(); + let caa_addr = vmsa_ref.caa_addr(); + let vmsa = vmsa_ref.vmsa(); + apic.disable_apic_emulation(vmsa, caa_addr); + } + } + + pub fn clear_pending_interrupts(&self) { + if let Some(mut apic) = self.apic_mut() { + let mut vmsa_ref = self.guest_vmsa_ref(); + let caa_addr = vmsa_ref.caa_addr(); + let vmsa = vmsa_ref.vmsa(); + apic.check_delivered_interrupts(vmsa, caa_addr); + } + } + + pub fn update_apic_emulation(&self, vmsa: &mut VMSA, caa_addr: Option) { + if let Some(mut apic) = self.apic_mut() { + apic.present_interrupts(self.shared(), vmsa, caa_addr); + } + } + + pub fn use_apic_emulation(&self) -> bool { + self.apic().is_some() + } + + pub fn read_apic_register(&self, register: u64) -> Result { + let mut vmsa_ref = self.guest_vmsa_ref(); + let caa_addr = vmsa_ref.caa_addr(); + let vmsa = vmsa_ref.vmsa(); + self.apic_mut() + .ok_or(SvsmError::Apic(ApicError::Disabled))? + .read_register(self.shared(), vmsa, caa_addr, register) + } + + pub fn write_apic_register(&self, register: u64, value: u64) -> Result<(), SvsmError> { + let mut vmsa_ref = self.guest_vmsa_ref(); + let caa_addr = vmsa_ref.caa_addr(); + let vmsa = vmsa_ref.vmsa(); + self.apic_mut() + .ok_or(SvsmError::Apic(ApicError::Disabled))? + .write_register(vmsa, caa_addr, register, value) + } + + pub fn configure_apic_vector(&self, vector: u8, allowed: bool) -> Result<(), SvsmError> { + self.apic_mut() + .ok_or(SvsmError::Apic(ApicError::Disabled))? + .configure_vector(vector, allowed); + Ok(()) + } + + fn vmsa_tr_segment(&self) -> VMSASegment { + VMSASegment { + selector: SVSM_TSS, + flags: SVSM_TR_FLAGS, + limit: TSS_LIMIT as u32, + base: &raw const self.tss as u64, + } + } + + fn virt_range_init(&self) { + // Initialize 4k range + let page_count = (SVSM_PERCPU_TEMP_END_4K - SVSM_PERCPU_TEMP_BASE_4K) / PAGE_SIZE; + assert!(page_count <= VirtualRange::CAPACITY); + self.vrange_4k + .borrow_mut() + .init(SVSM_PERCPU_TEMP_BASE_4K, page_count, PAGE_SHIFT); + + // Initialize 2M range + let page_count = (SVSM_PERCPU_TEMP_END_2M - SVSM_PERCPU_TEMP_BASE_2M) / PAGE_SIZE_2M; + assert!(page_count <= VirtualRange::CAPACITY); + self.vrange_2m + .borrow_mut() + .init(SVSM_PERCPU_TEMP_BASE_2M, page_count, PAGE_SHIFT_2M); + } + + /// Create a new virtual memory mapping in the PerCpu VMR + /// + /// # Arguments + /// + /// * `mapping` - The mapping to insert into the PerCpu VMR + /// + /// # Returns + /// + /// On success, a new ['VMRMapping'} that provides a virtual memory address for + /// the mapping which remains valid until the ['VRMapping'] is dropped. + /// + /// On error, an ['SvsmError']. 
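+    ///
+    /// # Example (sketch)
+    ///
+    /// ```ignore
+    /// // Hypothetical sketch: map a page-sized reserved region into this
+    /// // CPU's VMR; the mapping is removed again when `m` is dropped.
+    /// let m = this_cpu().new_mapping(Arc::new(VMReserved::new_mapping(PAGE_SIZE)))?;
+    /// ```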
+ pub fn new_mapping(&self, mapping: Arc) -> Result, SvsmError> { + VMRMapping::new(&self.vm_range, mapping) + } + + /// Add the PerCpu virtual range into the provided pagetable + /// + /// # Arguments + /// + /// * `pt` - The page table to populate the the PerCpu range into + pub fn populate_page_table(&self, pt: &mut PageTable) { + self.vm_range.populate(pt); + } + + pub fn handle_pf(&self, vaddr: VirtAddr, write: bool) -> Result<(), SvsmError> { + self.vm_range.handle_page_fault(vaddr, write) + } + + pub fn schedule_init(&self) -> TaskPointer { + // If the platform permits the use of interrupts, then ensure that + // interrupts will be enabled on the current CPU when leaving the + // scheduler environment. This is done after disabling interrupts + // for scheduler initialization so that the first interrupt that can + // be received will always observe that there is a current task and + // not the boot thread. + if SVSM_PLATFORM.use_interrupts() { + unsafe { + self.irq_state.set_restore_state(true); + } + } + let task = self.runqueue.lock_write().schedule_init(); + self.current_stack.set(task.stack_bounds()); + task + } + + pub fn schedule_prepare(&self) -> Option<(TaskPointer, TaskPointer)> { + let ret = self.runqueue.lock_write().schedule_prepare(); + if let Some((_, ref next)) = ret { + self.current_stack.set(next.stack_bounds()); + }; + ret + } + + pub fn runqueue(&self) -> &RWLockIrqSafe { + &self.runqueue + } + + pub fn current_task(&self) -> TaskPointer { + self.runqueue.lock_read().current_task() + } + + pub fn set_tss_rsp0(&self, addr: VirtAddr) { + let mut tss = self.tss.get(); + tss.stacks[0] = addr; + self.tss.set(tss); + } +} + +pub fn this_cpu() -> &'static PerCpu { + unsafe { &*SVSM_PERCPU_BASE.as_ptr::() } +} + +pub fn this_cpu_shared() -> &'static PerCpuShared { + this_cpu().shared() +} + +/// Disables IRQs on the current CPU. Keeps track of the nesting level and +/// the original IRQ state. +/// +/// # Safety +/// +/// Caller needs to make sure to match every `irqs_disable()` call with an +/// `irqs_enable()` call. +#[inline(always)] +pub unsafe fn irqs_disable() { + this_cpu().irqs_disable(); +} + +/// Reduces IRQ-disable nesting level on the current CPU and restores the +/// original IRQ state when the level reaches 0. +/// +/// # Safety +/// +/// Caller needs to make sure to match every `irqs_disable()` call with an +/// `irqs_enable()` call. +#[inline(always)] +pub unsafe fn irqs_enable() { + this_cpu().irqs_enable(); +} + +/// Get IRQ-disable nesting count on the current CPU +/// +/// # Returns +/// +/// Current nesting depth of irq_disable() calls. +pub fn irq_nesting_count() -> isize { + this_cpu().irq_nesting_count() +} + +/// Gets the GHCB for this CPU. +/// +/// # Panics +/// +/// Panics if the GHCB for this CPU has not been set up via +/// [`PerCpu::setup_ghcb()`]. 
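+///
+/// # Example (sketch)
+///
+/// ```ignore
+/// // Hedged sketch: forward a RDTSC through this CPU's GHCB, as the #VC
+/// // handlers in `cpu::vc` do.
+/// let mut regs = crate::cpu::X86GeneralRegs::default();
+/// current_ghcb().rdtsc_regs(&mut regs)?;
+/// ```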
+pub fn current_ghcb() -> &'static GHCB { + this_cpu().ghcb().unwrap() +} + +#[derive(Debug, Clone, Copy)] +pub struct VmsaRegistryEntry { + pub paddr: PhysAddr, + pub apic_id: u32, + pub guest_owned: bool, + pub in_use: bool, +} + +impl VmsaRegistryEntry { + pub const fn new(paddr: PhysAddr, apic_id: u32, guest_owned: bool) -> Self { + VmsaRegistryEntry { + paddr, + apic_id, + guest_owned, + in_use: false, + } + } +} + +// PERCPU VMSAs to apic_id map +pub static PERCPU_VMSAS: PerCpuVmsas = PerCpuVmsas::new(); + +#[derive(Debug)] +pub struct PerCpuVmsas { + vmsas: RWLock>, +} + +impl PerCpuVmsas { + const fn new() -> Self { + Self { + vmsas: RWLock::new(Vec::new()), + } + } + + pub fn exists(&self, paddr: PhysAddr) -> bool { + self.vmsas + .lock_read() + .iter() + .any(|vmsa| vmsa.paddr == paddr) + } + + pub fn register( + &self, + paddr: PhysAddr, + apic_id: u32, + guest_owned: bool, + ) -> Result<(), SvsmError> { + let mut guard = self.vmsas.lock_write(); + if guard.iter().any(|vmsa| vmsa.paddr == paddr) { + return Err(SvsmError::InvalidAddress); + } + + guard.push(VmsaRegistryEntry::new(paddr, apic_id, guest_owned)); + Ok(()) + } + + pub fn set_used(&self, paddr: PhysAddr) -> Option { + self.vmsas + .lock_write() + .iter_mut() + .find(|vmsa| vmsa.paddr == paddr && !vmsa.in_use) + .map(|vmsa| { + vmsa.in_use = true; + vmsa.apic_id + }) + } + + pub fn unregister(&self, paddr: PhysAddr, in_use: bool) -> Result { + let mut guard = self.vmsas.lock_write(); + let index = guard + .iter() + .position(|vmsa| vmsa.paddr == paddr && vmsa.in_use == in_use) + .ok_or(0u64)?; + + if in_use { + let vmsa = &guard[index]; + + if vmsa.apic_id == 0 { + return Err(0); + } + + let target_cpu = PERCPU_AREAS + .get(vmsa.apic_id) + .expect("Invalid APIC-ID in VMSA registry"); + target_cpu.clear_guest_vmsa_if_match(paddr); + } + + Ok(guard.swap_remove(index)) + } +} + +pub fn wait_for_requests() { + let current_task = current_task(); + this_cpu() + .request_waitqueue + .borrow_mut() + .wait_for_event(current_task); + schedule(); +} + +pub fn process_requests() { + let maybe_task = this_cpu().request_waitqueue.borrow_mut().wakeup(); + if let Some(task) = maybe_task { + schedule_task(task); + } +} + +pub fn current_task() -> TaskPointer { + this_cpu().runqueue.lock_read().current_task() +} diff --git a/stage2/src/cpu/registers.rs b/stage2/src/cpu/registers.rs new file mode 100644 index 000000000..a5aeb3a3e --- /dev/null +++ b/stage2/src/cpu/registers.rs @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +use bitflags::bitflags; + +#[repr(C, packed)] +#[derive(Default, Debug, Clone, Copy)] +pub struct X86GeneralRegs { + pub r15: usize, + pub r14: usize, + pub r13: usize, + pub r12: usize, + pub r11: usize, + pub r10: usize, + pub r9: usize, + pub r8: usize, + pub rbp: usize, + pub rdi: usize, + pub rsi: usize, + pub rdx: usize, + pub rcx: usize, + pub rbx: usize, + pub rax: usize, +} + +#[repr(C, packed)] +#[derive(Default, Debug, Clone, Copy)] +pub struct X86SegmentRegs { + pub cs: usize, + pub ds: usize, + pub es: usize, + pub fs: usize, + pub gs: usize, + pub ss: usize, +} + +#[repr(C, packed)] +#[derive(Default, Debug, Clone, Copy)] +pub struct X86InterruptFrame { + pub rip: usize, + pub cs: usize, + pub flags: usize, + pub rsp: usize, + pub ss: usize, +} + +bitflags! 
{ + #[derive(Copy, Clone, Debug, PartialEq)] + pub struct SegDescAttrFlags: u64 { + const A = 1 << 40; + const R_W = 1 << 41; + const C_E = 1 << 42; + const C_D = 1 << 43; + const S = 1 << 44; + const AVL = 1 << 52; + const L = 1 << 53; + const DB = 1 << 54; + const G = 1 << 55; + } +} + +bitflags! { + #[derive(Clone, Copy, Debug)] + pub struct RFlags: usize { + const CF = 1 << 0; + const FIXED = 1 << 1; + const PF = 1 << 2; + const AF = 1 << 4; + const ZF = 1 << 6; + const SF = 1 << 7; + const TF = 1 << 8; + const IF = 1 << 9; + const DF = 1 << 10; + const OF = 1 << 11; + const IOPL = 3 << 12; + const NT = 1 << 14; + const MD = 1 << 15; + const RF = 1 << 16; + const VM = 1 << 17; + const AC = 1 << 18; + const VIF = 1 << 19; + const VIP = 1 << 20; + const ID = 1 << 21; + } +} diff --git a/stage2/src/cpu/smp.rs b/stage2/src/cpu/smp.rs new file mode 100644 index 000000000..af448aef3 --- /dev/null +++ b/stage2/src/cpu/smp.rs @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::acpi::tables::ACPICPUInfo; +use crate::cpu::percpu::{this_cpu, this_cpu_shared, PerCpu}; +use crate::cpu::sse::sse_init; +use crate::error::SvsmError; +use crate::platform::SvsmPlatform; +use crate::platform::SVSM_PLATFORM; +use crate::requests::{request_loop, request_processing_main}; +use crate::task::{create_kernel_task, schedule_init}; +use crate::utils::immut_after_init::immut_after_init_set_multithreaded; + +fn start_cpu(platform: &dyn SvsmPlatform, apic_id: u32) -> Result<(), SvsmError> { + let start_rip: u64 = (start_ap as *const u8) as u64; + let percpu = PerCpu::alloc(apic_id)?; + let pgtable = this_cpu().get_pgtable().clone_shared()?; + percpu.setup(platform, pgtable)?; + platform.start_cpu(percpu, start_rip)?; + + let percpu_shared = percpu.shared(); + while !percpu_shared.is_online() {} + Ok(()) +} + +pub fn start_secondary_cpus(platform: &dyn SvsmPlatform, cpus: &[ACPICPUInfo]) { + immut_after_init_set_multithreaded(); + let mut count: usize = 0; + for c in cpus.iter().filter(|c| c.apic_id != 0 && c.enabled) { + log::info!("Launching AP with APIC-ID {}", c.apic_id); + start_cpu(platform, c.apic_id).expect("Failed to bring CPU online"); + count += 1; + } + log::info!("Brought {} AP(s) online", count); +} + +#[no_mangle] +fn start_ap() { + this_cpu() + .setup_on_cpu(&**SVSM_PLATFORM) + .expect("setup_on_cpu() failed"); + + this_cpu() + .setup_idle_task(ap_request_loop) + .expect("Failed to allocated idle task for AP"); + + // Send a life-sign + log::info!("AP with APIC-ID {} is online", this_cpu().get_apic_id()); + + // Set CPU online so that BSP can proceed + this_cpu_shared().set_online(); + + sse_init(); + schedule_init(); +} + +#[no_mangle] +pub extern "C" fn ap_request_loop() { + create_kernel_task(request_processing_main).expect("Failed to launch request processing task"); + request_loop(); + panic!("Returned from request_loop!"); +} diff --git a/stage2/src/cpu/sse.rs b/stage2/src/cpu/sse.rs new file mode 100644 index 000000000..5a70cac04 --- /dev/null +++ b/stage2/src/cpu/sse.rs @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Vasant Karasulli + +use crate::cpu::control_regs::{cr0_sse_enable, cr4_osfxsr_enable, cr4_xsave_enable}; +use crate::cpu::cpuid::CpuidResult; +use core::arch::asm; +use core::arch::x86_64::{_xgetbv, _xsetbv}; + +const CPUID_EDX_SSE1: u32 = 25; +const CPUID_ECX_XSAVE: u32 = 26; +const CPUID_EAX_XSAVEOPT: u32 = 0; +const 
XCR0_X87_ENABLE: u64 = 0x1; +const XCR0_SSE_ENABLE: u64 = 0x2; +const XCR0_YMM_ENABLE: u64 = 0x4; + +fn legacy_sse_supported() -> bool { + let res = CpuidResult::get(1, 0); + (res.edx & (1 << CPUID_EDX_SSE1)) != 0 +} + +fn legacy_sse_enable() { + if legacy_sse_supported() { + cr4_osfxsr_enable(); + cr0_sse_enable(); + } else { + panic!("Legacy SSE unsupported"); + } +} + +fn extended_sse_supported() -> bool { + let res = CpuidResult::get(0xD, 1); + (res.eax & 0x7) == 0x7 +} + +fn xsave_supported() -> bool { + let res = CpuidResult::get(1, 0); + (res.ecx & (1 << CPUID_ECX_XSAVE)) != 0 +} + +fn xcr0_set() { + unsafe { + // set bits [0-2] in XCR0 to enable extended SSE + let xr0 = _xgetbv(0) | XCR0_X87_ENABLE | XCR0_SSE_ENABLE | XCR0_YMM_ENABLE; + _xsetbv(0, xr0); + } +} + +pub fn get_xsave_area_size() -> u32 { + let res = CpuidResult::get(0xD, 0); + if (res.eax & (1 << CPUID_EAX_XSAVEOPT)) == 0 { + panic!("XSAVEOPT unsupported"); + } + res.ecx +} + +fn extended_sse_enable() { + if extended_sse_supported() && xsave_supported() { + cr4_xsave_enable(); + xcr0_set(); + } else { + panic!("extended SSE unsupported"); + } +} + +// Enable media and x87 instructions +pub fn sse_init() { + legacy_sse_enable(); + extended_sse_enable(); +} + +/// # Safety +/// inline assembly here is used to save the SSE/FPU +/// context. This context store is specific to a task and +/// no other part of the code is accessing this memory at the same time. +pub unsafe fn sse_save_context(addr: u64) { + let save_bits = XCR0_X87_ENABLE | XCR0_SSE_ENABLE | XCR0_YMM_ENABLE; + asm!( + r#" + xsaveopt (%rsi) + "#, + in("rsi") addr, + in("rax") save_bits, + in("rdx") 0, + options(att_syntax)); +} + +/// # Safety +/// inline assembly here is used to restore the SSE/FPU +/// context. This context store is specific to a task and +/// no other part of the code is accessing this memory at the same time. 
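+///
+/// The memory at `addr` is assumed to hold an XSAVE area previously written
+/// by [`sse_save_context`] for the same state components (x87/SSE/AVX) and to
+/// be 64-byte aligned as required by `xrstor`.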
+pub unsafe fn sse_restore_context(addr: u64) { + let save_bits = XCR0_X87_ENABLE | XCR0_SSE_ENABLE | XCR0_YMM_ENABLE; + asm!( + r#" + xrstor (%rsi) + "#, + in("rsi") addr, + in("rax") save_bits, + in("rdx") 0, + options(att_syntax)); +} diff --git a/stage2/src/cpu/tlb.rs b/stage2/src/cpu/tlb.rs new file mode 100644 index 000000000..2b30e9fe4 --- /dev/null +++ b/stage2/src/cpu/tlb.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, VirtAddr}; +use crate::cpu::control_regs::{read_cr4, write_cr4, CR4Flags}; + +use core::arch::asm; + +const INVLPGB_VALID_VA: u64 = 1u64 << 0; +//const INVLPGB_VALID_PCID: u64 = 1u64 << 1; +const INVLPGB_VALID_ASID: u64 = 1u64 << 2; +const INVLPGB_VALID_GLOBAL: u64 = 1u64 << 3; + +#[inline] +fn do_invlpgb(rax: u64, rcx: u64, rdx: u64) { + unsafe { + asm!("invlpgb", + in("rax") rax, + in("rcx") rcx, + in("rdx") rdx, + options(att_syntax)); + } +} + +#[inline] +fn do_tlbsync() { + unsafe { + asm!("tlbsync", options(att_syntax)); + } +} + +pub fn flush_tlb() { + let rax: u64 = INVLPGB_VALID_ASID; + do_invlpgb(rax, 0, 0); +} + +pub fn flush_tlb_sync() { + flush_tlb(); + do_tlbsync(); +} + +pub fn flush_tlb_global() { + let rax: u64 = INVLPGB_VALID_ASID | INVLPGB_VALID_GLOBAL; + do_invlpgb(rax, 0, 0); +} + +pub fn flush_tlb_global_sync() { + flush_tlb_global(); + do_tlbsync(); +} + +pub fn flush_tlb_global_percpu() { + let cr4 = read_cr4(); + write_cr4(cr4 ^ CR4Flags::PGE); + write_cr4(cr4); +} + +pub fn flush_address_percpu(va: VirtAddr) { + let va: u64 = va.page_align().bits() as u64; + unsafe { + asm!("invlpg (%rax)", + in("rax") va, + options(att_syntax)); + } +} + +pub fn flush_address(va: VirtAddr) { + let rax: u64 = (va.page_align().bits() as u64) + | INVLPGB_VALID_VA + | INVLPGB_VALID_ASID + | INVLPGB_VALID_GLOBAL; + do_invlpgb(rax, 0, 0); +} + +pub fn flush_address_sync(va: VirtAddr) { + flush_address(va); + do_tlbsync(); +} diff --git a/stage2/src/cpu/tss.rs b/stage2/src/cpu/tss.rs new file mode 100644 index 000000000..b1f95ca51 --- /dev/null +++ b/stage2/src/cpu/tss.rs @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::gdt::GDTEntry; +use crate::address::VirtAddr; +use core::num::NonZeroU8; + +// IST offsets +pub const IST_DF: NonZeroU8 = unsafe { NonZeroU8::new_unchecked(1) }; + +#[derive(Debug, Default, Clone, Copy)] +#[repr(C, packed(4))] +pub struct X86Tss { + reserved0: u32, + pub stacks: [VirtAddr; 3], + reserved1: u64, + ist_stacks: [VirtAddr; 7], + reserved2: u64, + reserved3: u16, + io_bmp_base: u16, +} + +pub const TSS_LIMIT: u64 = core::mem::size_of::() as u64; + +impl X86Tss { + pub const fn new() -> Self { + X86Tss { + reserved0: 0, + stacks: [VirtAddr::null(); 3], + reserved1: 0, + ist_stacks: [VirtAddr::null(); 7], + reserved2: 0, + reserved3: 0, + io_bmp_base: (TSS_LIMIT + 1) as u16, + } + } + + pub fn set_ist_stack(&mut self, index: NonZeroU8, addr: VirtAddr) { + // IST entries start at index 1 + let index = usize::from(index.get() - 1); + self.ist_stacks[index] = addr; + } + + pub fn to_gdt_entry(&self) -> (GDTEntry, GDTEntry) { + let addr = (self as *const X86Tss) as u64; + + let mut desc0: u64 = 0; + let mut desc1: u64 = 0; + + // Limit + desc0 |= TSS_LIMIT & 0xffffu64; + desc0 |= ((TSS_LIMIT >> 16) & 0xfu64) << 48; + + // Address + desc0 |= (addr & 0x00ff_ffffu64) << 16; + desc0 |= (addr & 0xff00_0000u64) << 32; + desc1 |= 
addr >> 32; + + // Present + desc0 |= 1u64 << 47; + + // Type + desc0 |= 0x9u64 << 40; + + (GDTEntry::from_raw(desc0), GDTEntry::from_raw(desc1)) + } +} + +#[cfg(test)] +mod test { + use super::*; + use core::mem::offset_of; + + #[test] + fn test_tss_offsets() { + assert_eq!(offset_of!(X86Tss, reserved0), 0x0); + assert_eq!(offset_of!(X86Tss, stacks), 0x4); + assert_eq!(offset_of!(X86Tss, reserved1), 0x1c); + assert_eq!(offset_of!(X86Tss, ist_stacks), 0x24); + assert_eq!(offset_of!(X86Tss, reserved2), 0x5c); + assert_eq!(offset_of!(X86Tss, reserved3), 0x64); + assert_eq!(offset_of!(X86Tss, io_bmp_base), 0x66); + } +} diff --git a/stage2/src/cpu/vc.rs b/stage2/src/cpu/vc.rs new file mode 100644 index 000000000..c15afe4fd --- /dev/null +++ b/stage2/src/cpu/vc.rs @@ -0,0 +1,647 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::idt::common::X86ExceptionContext; +use crate::address::Address; +use crate::address::VirtAddr; +use crate::cpu::cpuid::{cpuid_table_raw, CpuidLeaf}; +use crate::cpu::percpu::current_ghcb; +use crate::cpu::percpu::this_cpu; +use crate::cpu::X86GeneralRegs; +use crate::debug::gdbstub::svsm_gdbstub::handle_debug_exception; +use crate::error::SvsmError; +use crate::insn_decode::{ + DecodedInsn, DecodedInsnCtx, Immediate, Instruction, Operand, Register, MAX_INSN_SIZE, +}; +use crate::mm::GuestPtr; +use crate::sev::ghcb::GHCB; +use core::fmt; + +pub const SVM_EXIT_EXCP_BASE: usize = 0x40; +pub const SVM_EXIT_LAST_EXCP: usize = 0x5f; +pub const SVM_EXIT_RDTSC: usize = 0x6e; +pub const SVM_EXIT_CPUID: usize = 0x72; +pub const SVM_EXIT_IOIO: usize = 0x7b; +pub const SVM_EXIT_MSR: usize = 0x7c; +pub const SVM_EXIT_RDTSCP: usize = 0x87; +pub const X86_TRAP_DB: usize = 0x01; +pub const X86_TRAP: usize = SVM_EXIT_EXCP_BASE + X86_TRAP_DB; + +const MSR_SVSM_CAA: u64 = 0xc001f000; + +#[derive(Clone, Copy, Debug)] +pub struct VcError { + pub rip: usize, + pub code: usize, + pub error_type: VcErrorType, +} + +impl VcError { + const fn new(ctx: &X86ExceptionContext, error_type: VcErrorType) -> Self { + Self { + rip: ctx.frame.rip, + code: ctx.error_code, + error_type, + } + } +} + +#[derive(Clone, Copy, Debug)] +pub enum VcErrorType { + Unsupported, + DecodeFailed, + UnknownCpuidLeaf, +} + +impl From for SvsmError { + fn from(e: VcError) -> Self { + Self::Vc(e) + } +} + +impl fmt::Display for VcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Unhandled #VC exception ")?; + match self.error_type { + VcErrorType::Unsupported => { + write!(f, "unsupported #VC exception")?; + } + VcErrorType::DecodeFailed => { + write!(f, "invalid instruction")?; + } + VcErrorType::UnknownCpuidLeaf => { + write!(f, "unknown CPUID leaf")?; + } + } + write!( + f, + " RIP: {:#018x}: error code: {:#018x}", + self.rip, self.code + ) + } +} + +pub fn stage2_handle_vc_exception_no_ghcb(ctx: &mut X86ExceptionContext) -> Result<(), SvsmError> { + let err = ctx.error_code; + let insn_ctx = vc_decode_insn(ctx)?; + + match err { + SVM_EXIT_CPUID => handle_cpuid(ctx), + _ => Err(VcError::new(ctx, VcErrorType::Unsupported).into()), + }?; + + vc_finish_insn(ctx, &insn_ctx); + Ok(()) +} + +pub fn stage2_handle_vc_exception(ctx: &mut X86ExceptionContext) -> Result<(), SvsmError> { + let err = ctx.error_code; + + // To handle NAE events, we're supposed to reset the VALID_BITMAP field of + // the GHCB. This is currently only relevant for IOIO, RDTSC and RDTSCP + // handling. 
This field is currently reset in the relevant GHCB methods + // but it would be better to move the reset out of the different + // handlers. + let ghcb = current_ghcb(); + + let insn_ctx = vc_decode_insn(ctx)?; + + match (err, insn_ctx.and_then(|d| d.insn())) { + (SVM_EXIT_CPUID, Some(DecodedInsn::Cpuid)) => handle_cpuid(ctx), + (SVM_EXIT_IOIO, Some(ins)) => handle_ioio(ctx, ghcb, ins), + (SVM_EXIT_MSR, Some(ins)) => handle_msr(ctx, ghcb, ins), + (SVM_EXIT_RDTSC, Some(DecodedInsn::Rdtsc)) => ghcb.rdtsc_regs(&mut ctx.regs), + (SVM_EXIT_RDTSCP, Some(DecodedInsn::Rdtsc)) => ghcb.rdtscp_regs(&mut ctx.regs), + _ => Err(VcError::new(ctx, VcErrorType::Unsupported).into()), + }?; + + vc_finish_insn(ctx, &insn_ctx); + Ok(()) +} + +pub fn handle_vc_exception(ctx: &mut X86ExceptionContext, vector: usize) -> Result<(), SvsmError> { + let error_code = ctx.error_code; + + // To handle NAE events, we're supposed to reset the VALID_BITMAP field of + // the GHCB. This is currently only relevant for IOIO, RDTSC and RDTSCP + // handling. This field is currently reset in the relevant GHCB methods + // but it would be better to move the reset out of the different + // handlers. + let ghcb = current_ghcb(); + + let insn_ctx = vc_decode_insn(ctx)?; + + match (error_code, insn_ctx.as_ref().and_then(|d| d.insn())) { + // If the gdb stub is enabled then debugging operations such as single stepping + // will cause either an exception via DB_VECTOR if the DEBUG_SWAP sev_feature is + // clear, or a VC exception with an error code of X86_TRAP if set. + (X86_TRAP, _) => { + handle_debug_exception(ctx, vector); + Ok(()) + } + (SVM_EXIT_CPUID, Some(DecodedInsn::Cpuid)) => handle_cpuid(ctx), + (SVM_EXIT_IOIO, Some(_)) => insn_ctx + .as_ref() + .unwrap() + .emulate(ctx) + .map_err(SvsmError::from), + (SVM_EXIT_MSR, Some(ins)) => handle_msr(ctx, ghcb, ins), + (SVM_EXIT_RDTSC, Some(DecodedInsn::Rdtsc)) => ghcb.rdtsc_regs(&mut ctx.regs), + (SVM_EXIT_RDTSCP, Some(DecodedInsn::Rdtsc)) => ghcb.rdtscp_regs(&mut ctx.regs), + _ => Err(VcError::new(ctx, VcErrorType::Unsupported).into()), + }?; + + vc_finish_insn(ctx, &insn_ctx); + Ok(()) +} + +#[inline] +const fn get_msr(regs: &X86GeneralRegs) -> u64 { + ((regs.rdx as u64) << 32) | regs.rax as u64 & u32::MAX as u64 +} + +/// Handles a read from the SVSM-specific MSR defined the in SVSM spec. +fn handle_svsm_caa_rdmsr(ctx: &mut X86ExceptionContext) -> Result<(), SvsmError> { + let caa = this_cpu() + .guest_vmsa_ref() + .caa_phys() + .ok_or(SvsmError::MissingCAA)? + .bits(); + ctx.regs.rdx = (caa >> 32) & 0xffffffff; + ctx.regs.rax = caa & 0xffffffff; + Ok(()) +} + +fn handle_msr( + ctx: &mut X86ExceptionContext, + ghcb: &GHCB, + ins: DecodedInsn, +) -> Result<(), SvsmError> { + match ins { + DecodedInsn::Wrmsr => { + if get_msr(&ctx.regs) == MSR_SVSM_CAA { + return Ok(()); + } + ghcb.wrmsr_regs(&ctx.regs) + } + DecodedInsn::Rdmsr => { + if get_msr(&ctx.regs) == MSR_SVSM_CAA { + return handle_svsm_caa_rdmsr(ctx); + } + ghcb.rdmsr_regs(&mut ctx.regs) + } + _ => Err(VcError::new(ctx, VcErrorType::DecodeFailed).into()), + } +} + +fn handle_cpuid(ctx: &mut X86ExceptionContext) -> Result<(), SvsmError> { + // Section 2.3.1 GHCB MSR Protocol in SEV-ES Guest-Hypervisor Communication Block + // Standardization Rev. 2.02. + // For SEV-ES/SEV-SNP, we can use the CPUID table already defined and populated with + // firmware information. + // We choose for now not to call the hypervisor to perform CPUID, since it's no trusted. 
+ // Since GHCB is not needed to handle CPUID with the firmware table, we can call the handler + // very soon in stage 2. + snp_cpuid(ctx) +} + +fn snp_cpuid(ctx: &mut X86ExceptionContext) -> Result<(), SvsmError> { + let mut leaf = CpuidLeaf::new(ctx.regs.rax as u32, ctx.regs.rcx as u32); + let xcr0_in = if leaf.cpuid_fn == 0xD && (leaf.cpuid_subfn == 1 || leaf.cpuid_subfn == 0) { + 1 + } else { + 0 + }; + + let Some(ret) = cpuid_table_raw(leaf.cpuid_fn, leaf.cpuid_subfn, xcr0_in, 0) else { + return Err(VcError::new(ctx, VcErrorType::UnknownCpuidLeaf).into()); + }; + + leaf.eax = ret.eax; + leaf.ebx = ret.ebx; + leaf.ecx = ret.ecx; + leaf.edx = ret.edx; + + ctx.regs.rax = leaf.eax as usize; + ctx.regs.rbx = leaf.ebx as usize; + ctx.regs.rcx = leaf.ecx as usize; + ctx.regs.rdx = leaf.edx as usize; + + Ok(()) +} + +fn vc_finish_insn(ctx: &mut X86ExceptionContext, insn_ctx: &Option) { + ctx.frame.rip += insn_ctx.as_ref().map_or(0, |d| d.size()) +} + +fn ioio_get_port(source: Operand, ctx: &X86ExceptionContext) -> u16 { + match source { + Operand::Reg(Register::Rdx) => ctx.regs.rdx as u16, + Operand::Reg(..) => unreachable!("Port value is always in DX"), + Operand::Imm(imm) => match imm { + Immediate::U8(val) => val as u16, + _ => unreachable!("Port value in immediate is always 1 byte"), + }, + } +} + +fn handle_ioio( + ctx: &mut X86ExceptionContext, + ghcb: &GHCB, + insn: DecodedInsn, +) -> Result<(), SvsmError> { + match insn { + DecodedInsn::In(source, in_len) => { + let port = ioio_get_port(source, ctx); + ctx.regs.rax = (ghcb.ioio_in(port, in_len.try_into()?)? & in_len.mask()) as usize; + Ok(()) + } + DecodedInsn::Out(source, out_len) => { + let out_value = ctx.regs.rax as u64; + let port = ioio_get_port(source, ctx); + ghcb.ioio_out(port, out_len.try_into()?, out_value) + } + _ => Err(VcError::new(ctx, VcErrorType::DecodeFailed).into()), + } +} + +fn vc_decode_insn(ctx: &X86ExceptionContext) -> Result, SvsmError> { + if !vc_decoding_needed(ctx.error_code) { + return Ok(None); + } + + // TODO: the instruction fetch will likely to be handled differently when + // #VC exception will be raised from CPL > 0. + + let rip: GuestPtr<[u8; MAX_INSN_SIZE]> = GuestPtr::new(VirtAddr::from(ctx.frame.rip)); + + // rip and rip+15 addresses should belong to a mapped page. + // To ensure this, we rely on GuestPtr::read() that uses the exception table + // to handle faults while fetching. + // SAFETY: we trust the CPU-provided register state to be valid. Thus, RIP + // will point to the instruction that caused #VC to be raised, so it can + // safely be read. + let insn_raw = unsafe { rip.read()? 
}; + + let insn = Instruction::new(insn_raw); + Ok(Some(insn.decode(ctx)?)) +} + +fn vc_decoding_needed(error_code: usize) -> bool { + !(SVM_EXIT_EXCP_BASE..=SVM_EXIT_LAST_EXCP).contains(&error_code) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cpu::msr::{rdtsc, rdtscp, read_msr, write_msr, RdtscpOut}; + use crate::sev::ghcb::GHCB; + use crate::sev::utils::{get_dr7, raw_vmmcall, set_dr7}; + use core::arch::asm; + use core::arch::x86_64::__cpuid_count; + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_has_memory_encryption_info_cpuid() { + const CPUID_EXTENDED_FUNCTION_INFO: u32 = 0x8000_0000; + const CPUID_MEMORY_ENCRYPTION_INFO: u32 = 0x8000_001F; + let extended_info = unsafe { __cpuid_count(CPUID_EXTENDED_FUNCTION_INFO, 0) }; + assert!(extended_info.eax >= CPUID_MEMORY_ENCRYPTION_INFO); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_has_amd_cpuid() { + const CPUID_VENDOR_INFO: u32 = 0; + + let vendor_info = unsafe { __cpuid_count(CPUID_VENDOR_INFO, 0) }; + + let vendor_name_bytes = [vendor_info.ebx, vendor_info.edx, vendor_info.ecx] + .map(|v| v.to_le_bytes()) + .concat(); + + assert_eq!(core::str::from_utf8(&vendor_name_bytes), Ok("AuthenticAMD")); + } + + const GHCB_FILL_TEST_VALUE: u8 = b'1'; + + fn fill_ghcb_with_test_data() { + current_ghcb().fill(GHCB_FILL_TEST_VALUE); + } + + fn verify_ghcb_was_altered() { + let ghcb = current_ghcb(); + let ptr: *const GHCB = core::ptr::from_ref(ghcb); + let ghcb_bytes = + unsafe { core::slice::from_raw_parts(ptr.cast::(), core::mem::size_of::()) }; + assert!(ghcb_bytes.iter().any(|v| *v != GHCB_FILL_TEST_VALUE)); + } + + // Calls `f` with an assertion that it ended up altering the ghcb. + fn verify_ghcb_gets_altered(f: F) -> R + where + F: FnOnce() -> R, + { + fill_ghcb_with_test_data(); + let result = f(); + verify_ghcb_was_altered(); + result + } + + const TESTDEV_ECHO_LAST_PORT: u16 = 0xe0; + + fn inb(port: u16) -> u8 { + unsafe { + let ret: u8; + asm!("inb %dx, %al", in("dx") port, out("al") ret, options(att_syntax)); + ret + } + } + fn inb_from_testdev_echo() -> u8 { + unsafe { + let ret: u8; + asm!("inb $0xe0, %al", out("al") ret, options(att_syntax)); + ret + } + } + + fn outb(port: u16, value: u8) { + unsafe { asm!("outb %al, %dx", in("al") value, in("dx") port, options(att_syntax)) } + } + + fn outb_to_testdev_echo(value: u8) { + unsafe { asm!("outb %al, $0xe0", in("al") value, options(att_syntax)) } + } + + fn inw(port: u16) -> u16 { + unsafe { + let ret: u16; + asm!("inw %dx, %ax", in("dx") port, out("ax") ret, options(att_syntax)); + ret + } + } + fn inw_from_testdev_echo() -> u16 { + unsafe { + let ret: u16; + asm!("inw $0xe0, %ax", out("ax") ret, options(att_syntax)); + ret + } + } + + fn outw(port: u16, value: u16) { + unsafe { asm!("outw %ax, %dx", in("ax") value, in("dx") port, options(att_syntax)) } + } + + fn outw_to_testdev_echo(value: u16) { + unsafe { asm!("outw %ax, $0xe0", in("ax") value, options(att_syntax)) } + } + + fn inl(port: u16) -> u32 { + unsafe { + let ret: u32; + asm!("inl %dx, %eax", in("dx") port, out("eax") ret, options(att_syntax)); + ret + } + } + fn inl_from_testdev_echo() -> u32 { + unsafe { + let ret: u32; + asm!("inl $0xe0, %eax", out("eax") ret, options(att_syntax)); + ret + } + } + + fn outl(port: u16, value: u32) { + unsafe { asm!("outl %eax, %dx", in("eax") value, in("dx") port, options(att_syntax)) } + } + + fn outl_to_testdev_echo(value: u32) { + unsafe { asm!("outl %eax, $0xe0", 
in("eax") value, options(att_syntax)) } + } + + fn rep_outsw(port: u16, data: &[u16]) { + unsafe { + asm!("rep outsw", in("dx") port, in("rsi") data.as_ptr(), inout("rcx") data.len() => _, options(att_syntax)) + } + } + + fn rep_insw(port: u16, data: &mut [u16]) { + unsafe { + asm!("rep insw", in("dx") port, in("rdi") data.as_ptr(), inout("rcx") data.len() => _, options(att_syntax)) + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_port_io_8() { + const TEST_VAL: u8 = 0x12; + verify_ghcb_gets_altered(|| outb(TESTDEV_ECHO_LAST_PORT, TEST_VAL)); + assert_eq!( + TEST_VAL, + verify_ghcb_gets_altered(|| inb(TESTDEV_ECHO_LAST_PORT)) + ); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_port_io_16() { + const TEST_VAL: u16 = 0x4321; + verify_ghcb_gets_altered(|| outw(TESTDEV_ECHO_LAST_PORT, TEST_VAL)); + assert_eq!( + TEST_VAL, + verify_ghcb_gets_altered(|| inw(TESTDEV_ECHO_LAST_PORT)) + ); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_port_io_32() { + const TEST_VAL: u32 = 0xabcd1234; + verify_ghcb_gets_altered(|| outl(TESTDEV_ECHO_LAST_PORT, TEST_VAL)); + assert_eq!( + TEST_VAL, + verify_ghcb_gets_altered(|| inl(TESTDEV_ECHO_LAST_PORT)) + ); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_port_io_8_hardcoded() { + const TEST_VAL: u8 = 0x12; + verify_ghcb_gets_altered(|| outb_to_testdev_echo(TEST_VAL)); + assert_eq!(TEST_VAL, verify_ghcb_gets_altered(inb_from_testdev_echo)); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_port_io_16_hardcoded() { + const TEST_VAL: u16 = 0x4321; + verify_ghcb_gets_altered(|| outw_to_testdev_echo(TEST_VAL)); + assert_eq!(TEST_VAL, verify_ghcb_gets_altered(inw_from_testdev_echo)); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_port_io_32_hardcoded() { + const TEST_VAL: u32 = 0xabcd1234; + verify_ghcb_gets_altered(|| outl_to_testdev_echo(TEST_VAL)); + assert_eq!(TEST_VAL, verify_ghcb_gets_altered(inl_from_testdev_echo)); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_port_io_string_16_get_last() { + const TEST_DATA: &[u16] = &[0x1234, 0x5678, 0x9abc, 0xdef0]; + verify_ghcb_gets_altered(|| rep_outsw(TESTDEV_ECHO_LAST_PORT, TEST_DATA)); + assert_eq!( + TEST_DATA.last().unwrap(), + &verify_ghcb_gets_altered(|| inw(TESTDEV_ECHO_LAST_PORT)) + ); + + let mut test_data: [u16; 4] = [0; 4]; + verify_ghcb_gets_altered(|| rep_insw(TESTDEV_ECHO_LAST_PORT, &mut test_data)); + for d in test_data.iter() { + assert_eq!(d, TEST_DATA.last().unwrap()); + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_sev_snp_enablement_msr() { + const MSR_SEV_STATUS: u32 = 0xc0010131; + const MSR_SEV_STATUS_SEV_SNP_ENABLED: u64 = 0b10; + + let sev_status = read_msr(MSR_SEV_STATUS); + assert_ne!(sev_status & MSR_SEV_STATUS_SEV_SNP_ENABLED, 0); + } + + const MSR_APIC_BASE: u32 = 0x1b; + + const APIC_DEFAULT_PHYS_BASE: u64 = 0xfee00000; // KVM's default + const APIC_BASE_PHYS_ADDR_MASK: u64 = 0xffffff000; // bit 12-35 + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_rdmsr_apic() { + let apic_base = verify_ghcb_gets_altered(|| read_msr(MSR_APIC_BASE)); + assert_eq!(apic_base & APIC_BASE_PHYS_ADDR_MASK, APIC_DEFAULT_PHYS_BASE); + } + + #[test] 
+ #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_rdmsr_debug_ctl() { + const MSR_DEBUG_CTL: u32 = 0x1d9; + let apic_base = verify_ghcb_gets_altered(|| read_msr(MSR_DEBUG_CTL)); + assert_eq!(apic_base, 0); + } + + const MSR_TSC_AUX: u32 = 0xc0000103; + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_wrmsr_tsc_aux() { + let test_val = 0x1234; + verify_ghcb_gets_altered(|| write_msr(MSR_TSC_AUX, test_val)); + let readback = verify_ghcb_gets_altered(|| read_msr(MSR_TSC_AUX)); + assert_eq!(test_val, readback); + } + + #[test] + // #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + #[ignore = "Currently unhandled by #VC handler"] + fn test_vmmcall_error() { + let res = verify_ghcb_gets_altered(|| unsafe { raw_vmmcall(1005, 0, 0, 0) }); + assert_eq!(res, -1000); + } + + #[test] + // #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + #[ignore = "Currently unhandled by #VC handler"] + fn test_vmmcall_vapic_poll_irq() { + const VMMCALL_HC_VAPIC_POLL_IRQ: u32 = 1; + + let res = + verify_ghcb_gets_altered(|| unsafe { raw_vmmcall(VMMCALL_HC_VAPIC_POLL_IRQ, 0, 0, 0) }); + assert_eq!(res, 0); + } + + #[test] + // #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + #[ignore = "Currently unhandled by #VC handler"] + fn test_read_write_dr7() { + const DR7_DEFAULT: u64 = 0x400; + const DR7_TEST: u64 = 0x401; + + let old_dr7 = verify_ghcb_gets_altered(get_dr7); + assert_eq!(old_dr7, DR7_DEFAULT); + + verify_ghcb_gets_altered(|| unsafe { set_dr7(DR7_TEST) }); + let new_dr7 = verify_ghcb_gets_altered(get_dr7); + assert_eq!(new_dr7, DR7_TEST); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_rdtsc() { + let mut prev: u64 = rdtsc(); + for _ in 0..50 { + let cur = rdtsc(); + assert!(cur > prev); + prev = cur; + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_rdtscp() { + let expected_pid = u32::try_from(verify_ghcb_gets_altered(|| read_msr(MSR_TSC_AUX))) + .expect("pid should be 32 bits"); + let RdtscpOut { + timestamp: mut prev, + pid, + } = rdtscp(); + assert_eq!(pid, expected_pid); + for _ in 0..50 { + let RdtscpOut { + timestamp: cur, + pid, + } = rdtscp(); + assert_eq!(pid, expected_pid); + assert!(cur > prev); + prev = cur; + } + } + + #[test] + // #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + #[ignore = "Currently unhandled by #VC handler"] + fn test_wbinvd() { + verify_ghcb_gets_altered(|| unsafe { + asm!("wbinvd"); + }); + } + + const APIC_DEFAULT_VERSION_REGISTER_OFFSET: u64 = 0x30; + const EXPECTED_APIC_VERSION_NUMBER: u32 = 0x50014; + + #[test] + // #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + #[ignore = "apic mmio is not supported"] + fn test_mmio_apic_version() { + let mut version: u32 = 0; + let address = u32::try_from(APIC_DEFAULT_PHYS_BASE + APIC_DEFAULT_VERSION_REGISTER_OFFSET) + .expect("APIC address should fit in 32 bits"); + verify_ghcb_gets_altered(|| unsafe { + asm!( + "mov (%edx), %eax", + out("eax") version, + in("edx") address, + options(att_syntax) + ) + }); + assert_eq!(version, EXPECTED_APIC_VERSION_NUMBER); + } +} diff --git a/stage2/src/cpu/vmsa.rs b/stage2/src/cpu/vmsa.rs new file mode 100644 index 000000000..5133c3589 --- /dev/null +++ b/stage2/src/cpu/vmsa.rs @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// 
Author: Joerg Roedel + +use crate::address::{Address, VirtAddr}; +use crate::sev::status::{sev_flags, SEVStatusFlags}; +use crate::types::{GUEST_VMPL, SVSM_CS, SVSM_CS_FLAGS, SVSM_DS, SVSM_DS_FLAGS}; +use cpuarch::vmsa::{VMSASegment, VMSA}; + +use super::control_regs::{read_cr0, read_cr3, read_cr4}; +use super::efer::read_efer; +use super::gdt; +use super::idt::common::idt; + +fn svsm_code_segment() -> VMSASegment { + VMSASegment { + selector: SVSM_CS, + flags: SVSM_CS_FLAGS, + limit: 0xffff_ffff, + base: 0, + } +} + +fn svsm_data_segment() -> VMSASegment { + VMSASegment { + selector: SVSM_DS, + flags: SVSM_DS_FLAGS, + limit: 0xffff_ffff, + base: 0, + } +} + +fn svsm_gdt_segment() -> VMSASegment { + let (base, limit) = gdt().base_limit(); + VMSASegment { + selector: 0, + flags: 0, + limit, + base, + } +} + +fn svsm_idt_segment() -> VMSASegment { + let (base, limit) = idt().base_limit(); + VMSASegment { + selector: 0, + flags: 0, + limit, + base, + } +} + +pub fn init_svsm_vmsa(vmsa: &mut VMSA, vtom: u64) { + vmsa.es = svsm_data_segment(); + vmsa.cs = svsm_code_segment(); + vmsa.ss = svsm_data_segment(); + vmsa.ds = svsm_data_segment(); + vmsa.fs = svsm_data_segment(); + vmsa.gs = svsm_data_segment(); + vmsa.gdt = svsm_gdt_segment(); + vmsa.idt = svsm_idt_segment(); + + vmsa.cr0 = read_cr0().bits(); + vmsa.cr3 = read_cr3().bits() as u64; + vmsa.cr4 = read_cr4().bits(); + vmsa.efer = read_efer().bits(); + + vmsa.rflags = 0x2; + vmsa.dr6 = 0xffff0ff0; + vmsa.dr7 = 0x400; + vmsa.g_pat = 0x0007040600070406u64; + vmsa.xcr0 = 1; + vmsa.mxcsr = 0x1f80; + vmsa.x87_ftw = 0x5555; + vmsa.x87_fcw = 0x0040; + vmsa.vmpl = 0; + vmsa.vtom = vtom; + + vmsa.sev_features = sev_flags().as_sev_features(); +} + +fn real_mode_code_segment(rip: u64) -> VMSASegment { + VMSASegment { + selector: 0xf000, + base: rip & 0xffff_0000u64, + limit: 0xffff, + flags: 0x9b, + } +} +fn real_mode_data_segment() -> VMSASegment { + VMSASegment { + selector: 0, + flags: 0x93, + limit: 0xFFFF, + base: 0, + } +} + +fn real_mode_sys_seg(flags: u16) -> VMSASegment { + VMSASegment { + selector: 0, + base: 0, + limit: 0xffff, + flags, + } +} + +pub fn vmsa_ref_from_vaddr(vaddr: VirtAddr) -> &'static VMSA { + unsafe { vaddr.as_ptr::().as_ref().unwrap() } +} + +pub fn vmsa_mut_ref_from_vaddr(vaddr: VirtAddr) -> &'static mut VMSA { + unsafe { vaddr.as_mut_ptr::().as_mut().unwrap() } +} + +pub fn init_guest_vmsa(v: &mut VMSA, rip: u64, alternate_injection: bool) { + v.cr0 = 0x6000_0010; + v.rflags = 0x2; + v.rip = rip & 0xffff; + v.cs = real_mode_code_segment(rip); + v.ds = real_mode_data_segment(); + v.es = real_mode_data_segment(); + v.fs = real_mode_data_segment(); + v.gs = real_mode_data_segment(); + v.ss = real_mode_data_segment(); + v.gdt = real_mode_sys_seg(0); + v.idt = real_mode_sys_seg(0); + v.ldt = real_mode_sys_seg(0x82); + v.tr = real_mode_sys_seg(0x8b); + v.dr6 = 0xffff_0ff0; + v.dr7 = 0x0400; + v.g_pat = 0x0007040600070406u64; + v.xcr0 = 1; + v.mxcsr = 0x1f80; + v.x87_ftw = 0x5555; + v.x87_fcw = 0x0040; + + v.vmpl = GUEST_VMPL as u8; + + let mut sev_status = sev_flags(); + + // Ensure that guest VMSAs do not enable restricted injection. + sev_status.remove(SEVStatusFlags::REST_INJ); + + // Enable alternate injection if requested. 
+ if alternate_injection { + sev_status.insert(SEVStatusFlags::ALT_INJ); + } + + v.sev_features = sev_status.as_sev_features(); +} diff --git a/stage2/src/cpu/x86/mod.rs b/stage2/src/cpu/x86/mod.rs new file mode 100644 index 000000000..7f7ba665a --- /dev/null +++ b/stage2/src/cpu/x86/mod.rs @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Authors: Thomas Leroy + +pub mod smap; diff --git a/stage2/src/cpu/x86/smap.S b/stage2/src/cpu/x86/smap.S new file mode 100644 index 000000000..b1fa9cf4c --- /dev/null +++ b/stage2/src/cpu/x86/smap.S @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Authors: Thomas Leroy + +.code64 + +.section .text +.macro asm_clac + .if !CFG_NOSMAP + clac + .endif +.endm + +.macro asm_stac + .if !CFG_NOSMAP + stac + .endif +.endm diff --git a/stage2/src/cpu/x86/smap.rs b/stage2/src/cpu/x86/smap.rs new file mode 100644 index 000000000..48bfabb0d --- /dev/null +++ b/stage2/src/cpu/x86/smap.rs @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Authors: Thomas Leroy + +use core::arch::asm; + +/// Clears RFLAGS.AC to enable SMAP. +/// This is currently only used when SMAP is supported and enabled. +/// SMAP protection is effective only if CR4.SMAP is set and if RFLAGS.AC = 0. +#[inline(always)] +pub fn clac() { + if !cfg!(feature = "nosmap") { + unsafe { asm!("clac", options(att_syntax, nomem, nostack, preserves_flags)) } + } +} + +/// Sets RFLAGS.AC to disable SMAP. +/// This is currently only used when SMAP is supported and enabled. +/// SMAP protection is effective only if CR4.SMAP is set and if RFLAGS.AC = 0. +#[inline(always)] +pub fn stac() { + if !cfg!(feature = "nosmap") { + unsafe { asm!("stac", options(att_syntax, nomem, nostack, preserves_flags)) } + } +} diff --git a/stage2/src/crypto/mod.rs b/stage2/src/crypto/mod.rs new file mode 100644 index 000000000..3bc8de783 --- /dev/null +++ b/stage2/src/crypto/mod.rs @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! SVSM kernel crypto API + +pub mod aead { + //! API for authentication encryption with associated data + + use crate::{protocols::errors::SvsmReqError, sev::secrets_page::VMPCK_SIZE}; + + // Message Header Format (AMD SEV-SNP spec. table 98) + + /// Authenticated tag size (128 bits) + pub const AUTHTAG_SIZE: usize = 16; + /// Initialization vector size (96 bits) + pub const IV_SIZE: usize = 12; + /// Key size + pub const KEY_SIZE: usize = VMPCK_SIZE; + + /// AES-256 GCM + pub trait Aes256GcmTrait { + /// Encrypt the provided buffer using AES-256 GCM + /// + /// # Arguments + /// + /// * `iv`: Initialization vector + /// * `key`: 256-bit key + /// * `aad`: Additional authenticated data + /// * `inbuf`: Cleartext buffer to be encrypted + /// * `outbuf`: Buffer to store the encrypted data, it must be large enough to also + /// hold the authenticated tag. 
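    /// For example, a caller encrypting `msg` would typically reserve
    /// `msg.len() + AUTHTAG_SIZE` bytes for `outbuf`, since the 16-byte
    /// authentication tag is appended to the ciphertext.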
+ /// + /// # Returns + /// + /// * Success + /// * `usize`: Number of bytes written to `outbuf` + /// * Error + /// * [SvsmReqError] + fn encrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result; + + /// Decrypt the provided buffer using AES-256 GCM + /// + /// # Returns + /// + /// * `iv`: Initialization vector + /// * `key`: 256-bit key + /// * `aad`: Additional authenticated data + /// * `inbuf`: Cleartext buffer to be decrypted, followed by the authenticated tag + /// * `outbuf`: Buffer to store the decrypted data + /// + /// # Returns + /// + /// * Success + /// * `usize`: Number of bytes written to `outbuf` + /// * Error + /// * [SvsmReqError] + fn decrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result; + } + + /// Aes256Gcm type + #[derive(Copy, Clone, Debug)] + pub struct Aes256Gcm; +} + +// Crypto implementations supported. Only one of them must be compiled-in. + +pub mod rustcrypto; diff --git a/stage2/src/crypto/rustcrypto.rs b/stage2/src/crypto/rustcrypto.rs new file mode 100644 index 000000000..559181432 --- /dev/null +++ b/stage2/src/crypto/rustcrypto.rs @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! RustCrypto implementation + +use aes_gcm::{ + aead::{Aead, Payload}, + Aes256Gcm, Key, KeyInit, Nonce, +}; + +use crate::{ + crypto::aead::{ + Aes256Gcm as CryptoAes256Gcm, Aes256GcmTrait as CryptoAes256GcmTrait, IV_SIZE, KEY_SIZE, + }, + protocols::errors::SvsmReqError, +}; + +#[repr(u64)] +#[derive(Clone, Copy, Debug, PartialEq)] +enum AesGcmOperation { + Encrypt = 0, + Decrypt = 1, +} + +fn aes_gcm_do( + operation: AesGcmOperation, + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], +) -> Result { + let payload = Payload { msg: inbuf, aad }; + + let aes_key = Key::::from_slice(key); + let gcm = Aes256Gcm::new(aes_key); + let nonce = Nonce::from_slice(iv); + + let result = if operation == AesGcmOperation::Encrypt { + gcm.encrypt(nonce, payload) + } else { + gcm.decrypt(nonce, payload) + }; + let buffer = result.map_err(|_| SvsmReqError::invalid_format())?; + + let outbuf = outbuf + .get_mut(..buffer.len()) + .ok_or_else(SvsmReqError::invalid_parameter)?; + outbuf.copy_from_slice(&buffer); + + Ok(buffer.len()) +} + +impl CryptoAes256GcmTrait for CryptoAes256Gcm { + fn encrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result { + aes_gcm_do(AesGcmOperation::Encrypt, iv, key, aad, inbuf, outbuf) + } + + fn decrypt( + iv: &[u8; IV_SIZE], + key: &[u8; KEY_SIZE], + aad: &[u8], + inbuf: &[u8], + outbuf: &mut [u8], + ) -> Result { + aes_gcm_do(AesGcmOperation::Decrypt, iv, key, aad, inbuf, outbuf) + } +} diff --git a/stage2/src/debug/gdbstub.rs b/stage2/src/debug/gdbstub.rs new file mode 100644 index 000000000..7f2bbea42 --- /dev/null +++ b/stage2/src/debug/gdbstub.rs @@ -0,0 +1,753 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +// +// For release builds this module should not be compiled into the +// binary. See the bottom of this file for placeholders that are +// used when the gdb stub is disabled. 
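//
// A minimal sketch of the intended call order (illustrative only; the actual
// call sites are elsewhere in the kernel, and `platform` stands for the
// &'static dyn SvsmPlatform supplied by the caller):
//
//     gdbstub_start(platform).expect("Failed to start GDB stub");
//     debug_break(); // raises int3 so the stub can wait for GDB to attach
//
// When the "enable-gdb" feature is disabled, both functions come from the
// placeholder module at the bottom of this file and do nothing.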
+// +#[cfg(feature = "enable-gdb")] +pub mod svsm_gdbstub { + use crate::address::{Address, VirtAddr}; + use crate::cpu::control_regs::read_cr3; + use crate::cpu::idt::common::{X86ExceptionContext, BP_VECTOR, DB_VECTOR, VC_VECTOR}; + use crate::cpu::percpu::this_cpu; + use crate::cpu::X86GeneralRegs; + use crate::error::SvsmError; + use crate::locking::{LockGuard, SpinLock}; + use crate::mm::guestmem::{read_u8, write_u8}; + use crate::mm::PerCPUPageMappingGuard; + use crate::platform::SvsmPlatform; + use crate::serial::{SerialPort, Terminal}; + use crate::task::{is_current_task, TaskContext, INITIAL_TASK_ID, TASKLIST}; + use core::arch::asm; + use core::fmt; + use core::sync::atomic::{AtomicBool, Ordering}; + use gdbstub::common::{Signal, Tid}; + use gdbstub::conn::Connection; + use gdbstub::stub::state_machine::GdbStubStateMachine; + use gdbstub::stub::{GdbStubBuilder, MultiThreadStopReason}; + use gdbstub::target::ext::base::multithread::{ + MultiThreadBase, MultiThreadResume, MultiThreadResumeOps, MultiThreadSingleStep, + MultiThreadSingleStepOps, + }; + use gdbstub::target::ext::base::BaseOps; + use gdbstub::target::ext::breakpoints::{Breakpoints, SwBreakpoint}; + use gdbstub::target::ext::thread_extra_info::ThreadExtraInfo; + use gdbstub::target::{Target, TargetError}; + use gdbstub_arch::x86::reg::X86_64CoreRegs; + use gdbstub_arch::x86::X86_64_SSE; + + const INT3_INSTR: u8 = 0xcc; + const MAX_BREAKPOINTS: usize = 32; + + pub fn gdbstub_start(platform: &'static dyn SvsmPlatform) -> Result<(), u64> { + unsafe { + let mut target = GdbStubTarget::new(); + #[expect(static_mut_refs)] + let gdb = GdbStubBuilder::new(GdbStubConnection::new(platform)) + .with_packet_buffer(&mut PACKET_BUFFER) + .build() + .expect("Failed to initialise GDB stub") + .run_state_machine(&mut target) + .expect("Failed to start GDB state machine"); + *GDB_STATE.lock() = Some(SvsmGdbStub { gdb, target }); + GDB_STACK_TOP = GDB_STACK.as_mut_ptr().offset(GDB_STACK.len() as isize - 1) as u64; + } + GDB_INITIALISED.store(true, Ordering::Relaxed); + Ok(()) + } + + #[derive(PartialEq, Eq, Debug)] + enum ExceptionType { + Debug, + SwBreakpoint, + PageFault, + } + + impl From for ExceptionType { + fn from(value: usize) -> Self { + match value { + BP_VECTOR => ExceptionType::SwBreakpoint, + DB_VECTOR => ExceptionType::Debug, + VC_VECTOR => ExceptionType::Debug, + _ => ExceptionType::PageFault, + } + } + } + + pub fn handle_debug_exception(ctx: &mut X86ExceptionContext, exception: usize) { + let exception_type = ExceptionType::from(exception); + let id = this_cpu().runqueue().lock_read().current_task_id(); + let mut task_ctx = TaskContext { + regs: X86GeneralRegs { + r15: ctx.regs.r15, + r14: ctx.regs.r14, + r13: ctx.regs.r13, + r12: ctx.regs.r12, + r11: ctx.regs.r11, + r10: ctx.regs.r10, + r9: ctx.regs.r9, + r8: ctx.regs.r8, + rbp: ctx.regs.rbp, + rdi: ctx.regs.rdi, + rsi: ctx.regs.rsi, + rdx: ctx.regs.rdx, + rcx: ctx.regs.rcx, + rbx: ctx.regs.rbx, + rax: ctx.regs.rax, + }, + rsp: ctx.frame.rsp as u64, + flags: ctx.frame.flags as u64, + ret_addr: ctx.frame.rip as u64, + }; + + // Locking the GDB state for the duration of the stop will cause any other + // APs that hit a breakpoint to busy-wait until the current CPU releases + // the GDB state. They will then resume and report the stop state + // to GDB. + // One thing to watch out for - if a breakpoint is inadvertently placed in + // the GDB handling code itself then this will cause a re-entrant state + // within the same CPU causing a deadlock. 
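        // Note that the lock guard is re-acquired on every iteration of the
        // loop below: if a different task is currently being single-stepped,
        // the `continue` drops `gdb_state` and releases the lock, giving the
        // stepping CPU a chance to finish before this CPU retries.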
+ loop { + let mut gdb_state = GDB_STATE.lock(); + if let Some(stub) = gdb_state.as_ref() { + if stub.target.is_single_step != 0 && stub.target.is_single_step != id { + continue; + } + } + + unsafe { + asm!( + r#" + movq %rsp, (%rax) + movq %rax, %rsp + call handle_stop + popq %rax + movq %rax, %rsp + "#, + in("rsi") exception_type as u64, + in("rdi") &mut task_ctx, + in("rdx") &mut gdb_state, + in("rax") GDB_STACK_TOP, + options(att_syntax)); + } + + ctx.frame.rip = task_ctx.ret_addr as usize; + ctx.frame.flags = task_ctx.flags as usize; + ctx.frame.rsp = task_ctx.rsp as usize; + ctx.regs.rax = task_ctx.regs.rax; + ctx.regs.rbx = task_ctx.regs.rbx; + ctx.regs.rcx = task_ctx.regs.rcx; + ctx.regs.rdx = task_ctx.regs.rdx; + ctx.regs.rsi = task_ctx.regs.rsi; + ctx.regs.rdi = task_ctx.regs.rdi; + ctx.regs.rbp = task_ctx.regs.rbp; + ctx.regs.r8 = task_ctx.regs.r8; + ctx.regs.r9 = task_ctx.regs.r9; + ctx.regs.r10 = task_ctx.regs.r10; + ctx.regs.r11 = task_ctx.regs.r11; + ctx.regs.r12 = task_ctx.regs.r12; + ctx.regs.r13 = task_ctx.regs.r13; + ctx.regs.r14 = task_ctx.regs.r14; + ctx.regs.r15 = task_ctx.regs.r15; + + break; + } + } + + pub fn debug_break() { + if GDB_INITIALISED.load(Ordering::Acquire) { + log::info!("***********************************"); + log::info!("* Waiting for connection from GDB *"); + log::info!("***********************************"); + unsafe { + asm!("int3"); + } + } + } + + static GDB_INITIALISED: AtomicBool = AtomicBool::new(false); + static GDB_STATE: SpinLock>> = SpinLock::new(None); + static mut PACKET_BUFFER: [u8; 4096] = [0; 4096]; + // Allocate the GDB stack as an array of u64's to ensure 8 byte alignment of the stack. + static mut GDB_STACK: [u64; 8192] = [0; 8192]; + static mut GDB_STACK_TOP: u64 = 0; + + struct GdbTaskContext { + cr3: usize, + } + + impl GdbTaskContext { + #[must_use = "The task switch will have no effect if the context is dropped"] + fn switch_to_task(id: u32) -> Self { + let cr3 = if is_current_task(id) { + 0 + } else { + let tl = TASKLIST.lock(); + let cr3 = read_cr3(); + let task = tl.get_task(id); + if let Some(task) = task { + task.page_table.lock().load(); + cr3.bits() + } else { + 0 + } + }; + Self { cr3 } + } + } + + impl Drop for GdbTaskContext { + fn drop(&mut self) { + if self.cr3 != 0 { + unsafe { + asm!("mov %rax, %cr3", + in("rax") self.cr3, + options(att_syntax)); + } + } + } + } + + struct SvsmGdbStub<'a> { + gdb: GdbStubStateMachine<'a, GdbStubTarget, GdbStubConnection<'a>>, + target: GdbStubTarget, + } + + impl fmt::Debug for SvsmGdbStub<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SvsmGdbStub") + } + } + + #[no_mangle] + fn handle_stop( + ctx: &mut TaskContext, + exception_type: ExceptionType, + gdb_state: &mut LockGuard<'_, Option>>, + ) { + let SvsmGdbStub { gdb, mut target } = gdb_state.take().expect("Invalid GDB state"); + + target.set_regs(ctx); + + let hardcoded_bp = (exception_type == ExceptionType::SwBreakpoint) + && !target.is_breakpoint(ctx.ret_addr as usize - 1); + + // If the current address is on a breakpoint then we need to + // move the IP back by one byte + if (exception_type == ExceptionType::SwBreakpoint) + && target.is_breakpoint(ctx.ret_addr as usize - 1) + { + ctx.ret_addr -= 1; + } + + let tid = Tid::new(this_cpu().runqueue().lock_read().current_task_id() as usize) + .expect("Current task has invalid ID"); + let mut new_gdb = match gdb { + GdbStubStateMachine::Running(gdb_inner) => { + let reason = if hardcoded_bp { + MultiThreadStopReason::SignalWithThread { + tid, + 
signal: Signal::SIGINT, + } + } else if exception_type == ExceptionType::PageFault { + MultiThreadStopReason::SignalWithThread { + tid, + signal: Signal::SIGSEGV, + } + } else { + MultiThreadStopReason::SwBreak(tid) + }; + gdb_inner + .report_stop(&mut target, reason) + .expect("Failed to handle software breakpoint") + } + _ => gdb, + }; + + loop { + new_gdb = match new_gdb { + // The first entry into the debugger is via a forced breakpoint during + // initialisation. The state at this point will be Idle instead of + // Running. + GdbStubStateMachine::Idle(mut gdb_inner) => { + let byte = gdb_inner + .borrow_conn() + .read() + .expect("Failed to read from GDB port"); + gdb_inner.incoming_data(&mut target, byte) + .expect("Could not open serial port for GDB connection. \ + Please ensure the virtual machine is configured to provide a second serial port.") + } + GdbStubStateMachine::Running(gdb_inner) => { + new_gdb = gdb_inner.into(); + break; + } + _ => { + panic!("Invalid GDB state when handling breakpoint interrupt"); + } + }; + } + if target.is_single_step == tid.get() as u32 { + ctx.flags |= 0x100; + } else { + ctx.flags &= !0x100; + } + **gdb_state = Some(SvsmGdbStub { + gdb: new_gdb, + target, + }); + } + + struct GdbStubConnection<'a> { + serial_port: SerialPort<'a>, + } + + impl<'a> GdbStubConnection<'a> { + fn new(platform: &'a dyn SvsmPlatform) -> Self { + let serial_port = SerialPort::new(platform.get_io_port(), 0x2f8); + serial_port.init(); + Self { serial_port } + } + + fn read(&self) -> Result { + Ok(self.serial_port.get_byte()) + } + } + + impl Connection for GdbStubConnection<'_> { + type Error = usize; + + fn write(&mut self, byte: u8) -> Result<(), Self::Error> { + self.serial_port.put_byte(byte); + Ok(()) + } + + fn flush(&mut self) -> Result<(), Self::Error> { + Ok(()) + } + } + + #[derive(Clone, Copy)] + struct GdbStubBreakpoint { + addr: VirtAddr, + inst: u8, + } + + struct GdbStubTarget { + ctx: *mut TaskContext, + breakpoints: [GdbStubBreakpoint; MAX_BREAKPOINTS], + is_single_step: u32, + } + + // SAFETY: this can only be unsafe via aliasing of the ctx field, + // which is the exception context on the stack and should not be accessed + // from any other thread. + unsafe impl Send for GdbStubTarget {} + // SAFETY: see safety comment above + unsafe impl Sync for GdbStubTarget {} + + impl GdbStubTarget { + const fn new() -> Self { + Self { + ctx: core::ptr::null_mut(), + breakpoints: [GdbStubBreakpoint { + addr: VirtAddr::null(), + inst: 0, + }; MAX_BREAKPOINTS], + is_single_step: 0, + } + } + + fn ctx(&self) -> Option<&TaskContext> { + // SAFETY: this is a pointer to the exception context on the + // stack, so it is not aliased from a different task. We trust + // the debug exception handler to pass a well-aligned pointer + // pointing to valid memory. + unsafe { self.ctx.as_ref() } + } + + fn ctx_mut(&mut self) -> Option<&mut TaskContext> { + // SAFETY: this is a pointer to the exception context on the + // stack, so it is not aliased from a different task. We trust + // the debug exception handler to pass a well-aligned pointer + // pointing to valid memory. + unsafe { self.ctx.as_mut() } + } + + fn set_regs(&mut self, ctx: &mut TaskContext) { + self.ctx = core::ptr::from_mut(ctx) + } + + fn is_breakpoint(&self, rip: usize) -> bool { + self.breakpoints.iter().any(|b| b.addr.bits() == rip) + } + + fn write_bp_address(addr: VirtAddr, value: u8) -> Result<(), SvsmError> { + // Virtual addresses in code are likely to be in read-only memory. 
If we + // can get the physical address for this VA then create a temporary + // mapping + + let Ok(phys) = this_cpu().get_pgtable().phys_addr(addr) else { + // The virtual address is not one that SVSM has mapped. + // Try safely writing it to the original virtual address + // SAFETY: it is up to the user to ensure that the address we + // are writing a breakpoint to is valid. + return unsafe { write_u8(addr, value) }; + }; + + let guard = PerCPUPageMappingGuard::create_4k(phys.page_align())?; + let dst = guard + .virt_addr() + .checked_add(phys.page_offset()) + .ok_or(SvsmError::InvalidAddress)?; + + // SAFETY: guard is a new mapped page, non controllable by user. + // We also checked that the destination address didn't overflow. + unsafe { write_u8(dst, value) } + } + } + + impl Target for GdbStubTarget { + type Arch = X86_64_SSE; + + type Error = usize; + + fn base_ops(&mut self) -> BaseOps<'_, Self::Arch, Self::Error> { + BaseOps::MultiThread(self) + } + + #[inline(always)] + fn support_breakpoints( + &mut self, + ) -> Option> { + Some(self) + } + } + + impl From<&TaskContext> for X86_64CoreRegs { + fn from(value: &TaskContext) -> Self { + let mut regs = X86_64CoreRegs::default(); + regs.rip = value.ret_addr; + regs.regs = [ + value.regs.rax as u64, + value.regs.rbx as u64, + value.regs.rcx as u64, + value.regs.rdx as u64, + value.regs.rsi as u64, + value.regs.rdi as u64, + value.regs.rbp as u64, + value.rsp, + value.regs.r8 as u64, + value.regs.r9 as u64, + value.regs.r10 as u64, + value.regs.r11 as u64, + value.regs.r12 as u64, + value.regs.r13 as u64, + value.regs.r14 as u64, + value.regs.r15 as u64, + ]; + regs.eflags = value.flags as u32; + regs + } + } + + impl MultiThreadBase for GdbStubTarget { + fn read_registers( + &mut self, + regs: &mut ::Registers, + tid: Tid, + ) -> gdbstub::target::TargetResult<(), Self> { + if is_current_task(tid.get() as u32) { + *regs = X86_64CoreRegs::from(self.ctx().unwrap()); + } else { + let task = TASKLIST.lock().get_task(tid.get() as u32); + if let Some(task) = task { + // The registers are stored in the top of the task stack as part of the + // saved context. We need to switch to the task pagetable to access them. 
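                // The guard returned by `switch_to_task` loads the target
                // task's page table and restores the previous CR3 when it is
                // dropped, so the saved context is only dereferenced while
                // `_task_context` is alive.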
+ let _task_context = GdbTaskContext::switch_to_task(tid.get() as u32); + unsafe { + *regs = X86_64CoreRegs::from(&*(task.rsp as *const TaskContext)); + }; + regs.regs[7] = task.rsp; + } else { + *regs = ::Registers::default(); + } + } + Ok(()) + } + + fn write_registers( + &mut self, + regs: &::Registers, + tid: Tid, + ) -> gdbstub::target::TargetResult<(), Self> { + if !is_current_task(tid.get() as u32) { + return Err(TargetError::NonFatal); + } + + let context = self.ctx_mut().unwrap(); + + context.ret_addr = regs.rip; + context.regs.rax = regs.regs[0] as usize; + context.regs.rbx = regs.regs[1] as usize; + context.regs.rcx = regs.regs[2] as usize; + context.regs.rdx = regs.regs[3] as usize; + context.regs.rsi = regs.regs[4] as usize; + context.regs.rdi = regs.regs[5] as usize; + context.regs.rbp = regs.regs[6] as usize; + context.rsp = regs.regs[7]; + context.regs.r8 = regs.regs[8] as usize; + context.regs.r9 = regs.regs[9] as usize; + context.regs.r10 = regs.regs[10] as usize; + context.regs.r11 = regs.regs[11] as usize; + context.regs.r12 = regs.regs[12] as usize; + context.regs.r13 = regs.regs[13] as usize; + context.regs.r14 = regs.regs[14] as usize; + context.regs.r15 = regs.regs[15] as usize; + context.flags = regs.eflags as u64; + Ok(()) + } + + fn read_addrs( + &mut self, + start_addr: ::Usize, + data: &mut [u8], + tid: Tid, + ) -> gdbstub::target::TargetResult<(), Self> { + // Switch to the task pagetable if necessary. The switch back will + // happen automatically when the variable falls out of scope + let _task_context = GdbTaskContext::switch_to_task(tid.get() as u32); + let start_addr = VirtAddr::from(start_addr); + for (off, dst) in data.iter_mut().enumerate() { + let Ok(val) = read_u8(start_addr + off) else { + return Err(TargetError::NonFatal); + }; + *dst = val; + } + Ok(()) + } + + fn write_addrs( + &mut self, + start_addr: ::Usize, + data: &[u8], + _tid: Tid, + ) -> gdbstub::target::TargetResult<(), Self> { + let start_addr = VirtAddr::from(start_addr); + for (off, src) in data.iter().enumerate() { + let dst = start_addr.checked_add(off).ok_or(TargetError::NonFatal)?; + + // SAFETY: We trust the caller of this trait method to provide a valid address. + // We only cheked that start_adddr + off didn't overflow. + unsafe { write_u8(dst, *src).map_err(|_| TargetError::NonFatal)? } + } + Ok(()) + } + + #[inline(always)] + fn support_resume(&mut self) -> Option> { + Some(self) + } + + fn list_active_threads( + &mut self, + thread_is_active: &mut dyn FnMut(Tid), + ) -> Result<(), Self::Error> { + let mut tl = TASKLIST.lock(); + + // Get the current task. If this is the first request after the remote + // GDB has connected then we need to report the current task first. + // There is no harm in doing this every time the thread list is requested. 
+ let current_task = this_cpu().runqueue().lock_read().current_task_id(); + if current_task == INITIAL_TASK_ID { + thread_is_active(Tid::new(INITIAL_TASK_ID as usize).unwrap()); + } else { + thread_is_active(Tid::new(current_task as usize).unwrap()); + + let mut cursor = tl.list().front_mut(); + while cursor.get().is_some() { + let this_task = cursor.get().unwrap().get_task_id(); + if this_task != current_task { + thread_is_active(Tid::new(this_task as usize).unwrap()); + } + cursor.move_next(); + } + } + Ok(()) + } + + fn support_thread_extra_info( + &mut self, + ) -> Option> { + Some(self) + } + } + + impl ThreadExtraInfo for GdbStubTarget { + fn thread_extra_info(&self, tid: Tid, buf: &mut [u8]) -> Result { + // Get the current task from the stopped CPU so we can mark it as stopped + let tl = TASKLIST.lock(); + let str = match tl.get_task(tid.get() as u32) { + Some(task) => { + if task.is_running() { + "Running".as_bytes() + } else if task.is_terminated() { + "Terminated".as_bytes() + } else { + "Blocked".as_bytes() + } + } + None => "Stopped".as_bytes(), + }; + let mut count = 0; + for (dst, src) in buf.iter_mut().zip(str) { + *dst = *src; + count += 1; + } + Ok(count) + } + } + + impl MultiThreadResume for GdbStubTarget { + fn resume(&mut self) -> Result<(), Self::Error> { + Ok(()) + } + + #[inline(always)] + fn support_single_step(&mut self) -> Option> { + Some(self) + } + + fn clear_resume_actions(&mut self) -> Result<(), Self::Error> { + self.is_single_step = 0; + Ok(()) + } + + fn set_resume_action_continue( + &mut self, + _tid: Tid, + _signal: Option, + ) -> Result<(), Self::Error> { + Ok(()) + } + } + + impl MultiThreadSingleStep for GdbStubTarget { + fn set_resume_action_step( + &mut self, + tid: Tid, + _signal: Option, + ) -> Result<(), Self::Error> { + self.is_single_step = tid.get() as u32; + Ok(()) + } + } + + impl Breakpoints for GdbStubTarget { + #[inline(always)] + fn support_sw_breakpoint( + &mut self, + ) -> Option> { + Some(self) + } + + #[inline(always)] + fn support_hw_breakpoint( + &mut self, + ) -> Option> { + None + } + + #[inline(always)] + fn support_hw_watchpoint( + &mut self, + ) -> Option> { + None + } + } + + impl SwBreakpoint for GdbStubTarget { + fn add_sw_breakpoint( + &mut self, + addr: ::Usize, + _kind: ::BreakpointKind, + ) -> gdbstub::target::TargetResult { + // Find a free breakpoint slot + let Some(free_bp) = self.breakpoints.iter_mut().find(|b| b.addr.is_null()) else { + return Ok(false); + }; + // The breakpoint works by taking the opcode at the bp address, storing + // it and replacing it with an INT3 instruction + let vaddr = VirtAddr::from(addr); + let Ok(inst) = read_u8(vaddr) else { + return Ok(false); + }; + let Ok(_) = GdbStubTarget::write_bp_address(vaddr, INT3_INSTR) else { + return Ok(false); + }; + *free_bp = GdbStubBreakpoint { addr: vaddr, inst }; + Ok(true) + } + + fn remove_sw_breakpoint( + &mut self, + addr: ::Usize, + _kind: ::BreakpointKind, + ) -> gdbstub::target::TargetResult { + let vaddr = VirtAddr::from(addr); + let Some(bp) = self.breakpoints.iter_mut().find(|b| b.addr == vaddr) else { + return Ok(false); + }; + let Ok(_) = GdbStubTarget::write_bp_address(vaddr, bp.inst) else { + return Ok(false); + }; + bp.addr = VirtAddr::null(); + Ok(true) + } + } + + #[cfg(test)] + pub mod tests { + extern crate alloc; + + use super::ExceptionType; + use crate::cpu::idt::common::{BP_VECTOR, VC_VECTOR}; + use alloc::vec; + use alloc::vec::Vec; + + #[test] + fn exception_type_from() { + let exceptions: Vec = [VC_VECTOR, BP_VECTOR, 0] + 
.iter() + .map(|e| ExceptionType::from(*e)) + .collect(); + assert_eq!( + exceptions, + vec![ + ExceptionType::Debug, + ExceptionType::SwBreakpoint, + ExceptionType::PageFault + ] + ); + } + } +} + +#[cfg(not(feature = "enable-gdb"))] +pub mod svsm_gdbstub { + use crate::cpu::X86ExceptionContext; + use crate::platform::SvsmPlatform; + + pub fn gdbstub_start(_platform: &'static dyn SvsmPlatform) -> Result<(), u64> { + Ok(()) + } + + pub fn handle_debug_exception(_ctx: &mut X86ExceptionContext, _exception: usize) {} + + pub fn debug_break() {} +} diff --git a/stage2/src/debug/mod.rs b/stage2/src/debug/mod.rs new file mode 100644 index 000000000..06e3958c2 --- /dev/null +++ b/stage2/src/debug/mod.rs @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Nicolai Stange + +pub mod gdbstub; +pub mod stacktrace; diff --git a/stage2/src/debug/stacktrace.rs b/stage2/src/debug/stacktrace.rs new file mode 100644 index 000000000..cdd5ba01a --- /dev/null +++ b/stage2/src/debug/stacktrace.rs @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Nicolai Stange + +use crate::{ + address::VirtAddr, + cpu::idt::common::{is_exception_handler_return_site, X86ExceptionContext}, + cpu::percpu::this_cpu, + mm::address_space::STACK_SIZE, + utils::MemoryRegion, +}; +use core::{arch::asm, mem}; + +#[derive(Clone, Copy, Debug, Default)] +struct StackFrame { + rbp: VirtAddr, + rsp: VirtAddr, + rip: VirtAddr, + is_last: bool, + is_exception_frame: bool, + _stack_depth: usize, // Not needed for frame unwinding, only as diagnostic information. +} + +#[derive(Clone, Copy, Debug)] +enum UnwoundStackFrame { + Valid(StackFrame), + Invalid, +} + +type StacksBounds = [MemoryRegion; 3]; + +#[derive(Debug)] +struct StackUnwinder { + next_frame: Option, + stacks: StacksBounds, +} + +impl StackUnwinder { + pub fn unwind_this_cpu() -> Self { + let mut rbp: usize; + unsafe { + asm!("movq %rbp, {}", out(reg) rbp, + options(att_syntax)); + }; + + let cpu = this_cpu(); + let top_of_init_stack = cpu.get_top_of_stack(); + let top_of_df_stack = cpu.get_top_of_df_stack(); + let current_stack = cpu.get_current_stack(); + + let stacks: StacksBounds = [ + MemoryRegion::from_addresses(top_of_init_stack - STACK_SIZE, top_of_init_stack), + MemoryRegion::from_addresses(top_of_df_stack - STACK_SIZE, top_of_df_stack), + current_stack, + ]; + + Self::new(VirtAddr::from(rbp), stacks) + } + + fn new(rbp: VirtAddr, stacks: StacksBounds) -> Self { + let first_frame = Self::unwind_framepointer_frame(rbp, &stacks); + Self { + next_frame: Some(first_frame), + stacks, + } + } + + fn check_unwound_frame( + rbp: VirtAddr, + rsp: VirtAddr, + rip: VirtAddr, + stacks: &StacksBounds, + ) -> UnwoundStackFrame { + // The next frame's rsp or rbp should live on some valid stack, + // otherwise mark the unwound frame as invalid. + let Some(stack) = stacks + .iter() + .find(|stack| stack.contains_inclusive(rsp) || stack.contains_inclusive(rbp)) + else { + log::info!("check_unwound_frame: rsp {rsp:#018x} and rbp {rbp:#018x} does not match any known stack"); + return UnwoundStackFrame::Invalid; + }; + + let is_last = Self::frame_is_last(rbp); + let is_exception_frame = is_exception_handler_return_site(rip); + + if !is_last && !is_exception_frame { + // Consistency check to ensure forward-progress: never unwind downwards. 
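            // On x86-64 the stack grows towards lower addresses, so a valid
            // parent frame must live at a higher address than the current
            // RSP; an unwound RBP below RSP cannot belong to a caller and
            // would let the unwinder run backwards.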
+ if rbp < rsp { + return UnwoundStackFrame::Invalid; + } + } + + let _stack_depth = stack.end() - rsp; + + UnwoundStackFrame::Valid(StackFrame { + rbp, + rip, + rsp, + is_last, + is_exception_frame, + _stack_depth, + }) + } + + fn unwind_framepointer_frame(rbp: VirtAddr, stacks: &StacksBounds) -> UnwoundStackFrame { + let rsp = rbp; + + let Some(range) = MemoryRegion::checked_new(rsp, 2 * mem::size_of::()) else { + return UnwoundStackFrame::Invalid; + }; + + if !stacks.iter().any(|stack| stack.contains_region(&range)) { + return UnwoundStackFrame::Invalid; + } + + let rbp = unsafe { rsp.as_ptr::().read_unaligned() }; + let rsp = rsp + mem::size_of::(); + let rip = unsafe { rsp.as_ptr::().read_unaligned() }; + let rsp = rsp + mem::size_of::(); + + Self::check_unwound_frame(rbp, rsp, rip, stacks) + } + + fn unwind_exception_frame(rsp: VirtAddr, stacks: &StacksBounds) -> UnwoundStackFrame { + let Some(range) = MemoryRegion::checked_new(rsp, mem::size_of::()) + else { + return UnwoundStackFrame::Invalid; + }; + + if !stacks.iter().any(|stack| stack.contains_region(&range)) { + return UnwoundStackFrame::Invalid; + } + + let ctx = unsafe { &*rsp.as_ptr::() }; + let rbp = VirtAddr::from(ctx.regs.rbp); + let rip = VirtAddr::from(ctx.frame.rip); + let rsp = VirtAddr::from(ctx.frame.rsp); + + Self::check_unwound_frame(rbp, rsp, rip, stacks) + } + + fn frame_is_last(rbp: VirtAddr) -> bool { + // A new task is launched with RBP = 0, which is pushed onto the stack + // immediatly and can serve as a marker when the end of the stack has + // been reached. + rbp == VirtAddr::new(0) + } +} + +impl Iterator for StackUnwinder { + type Item = UnwoundStackFrame; + + fn next(&mut self) -> Option { + let cur = self.next_frame; + match cur { + Some(cur) => { + match &cur { + UnwoundStackFrame::Invalid => { + self.next_frame = None; + } + UnwoundStackFrame::Valid(cur_frame) => { + if cur_frame.is_last { + self.next_frame = None + } else if cur_frame.is_exception_frame { + self.next_frame = + Some(Self::unwind_exception_frame(cur_frame.rsp, &self.stacks)); + } else { + self.next_frame = + Some(Self::unwind_framepointer_frame(cur_frame.rbp, &self.stacks)); + } + } + }; + + Some(cur) + } + None => None, + } + } +} + +pub fn print_stack(skip: usize) { + let unwinder = StackUnwinder::unwind_this_cpu(); + log::info!("---BACKTRACE---:"); + for frame in unwinder.skip(skip) { + match frame { + UnwoundStackFrame::Valid(item) => log::info!(" [{:#018x}]", item.rip), + UnwoundStackFrame::Invalid => log::info!(" Invalid frame"), + } + } + log::info!("---END---"); +} diff --git a/stage2/src/error.rs b/stage2/src/error.rs new file mode 100644 index 000000000..ddd7631af --- /dev/null +++ b/stage2/src/error.rs @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 SUSE LLC +// +// Author: Carlos López + +//! High level error typing for the public SVSM APIs. +//! +//! This module contains the generic [`SvsmError`] type, which may be returned +//! from any public API in this codebase to signal an error during SVSM +//! operation. Each variant of the type may give more specific information +//! about the source of the error. +//! +//! As a general rule, functions private to a given module may directly return +//! leaf error types, which are contained in [`SvsmError`] variants. Public +//! functions should return an [`SvsmError`] containing a leaf error type, +//! usually the one corresponding to that module. Each module should provide +//! 
a way to convert a leaf error into a SvsmError via the [`From`] trait. + +use crate::cpu::vc::VcError; +use crate::fs::FsError; +use crate::fw_cfg::FwCfgError; +use crate::insn_decode::InsnError; +use crate::mm::alloc::AllocError; +use crate::sev::ghcb::GhcbError; +use crate::sev::msr_protocol::GhcbMsrError; +use crate::sev::SevSnpError; +use crate::syscall::ObjError; +use crate::task::TaskError; +use elf::ElfError; +use syscall::SysCallError; + +/// Errors related to APIC handling. These may originate from multiple +/// layers in the system. +#[derive(Clone, Copy, Debug)] +pub enum ApicError { + /// An error arising because APIC emulation is disabled. + Disabled, + + /// An error related to APIC emulation. + Emulation, + + /// An error related to APIC registration. + Registration, +} + +/// A generic error during SVSM operation. +#[derive(Clone, Copy, Debug)] +pub enum SvsmError { + /// Errors related to platform initialization. + PlatformInit, + /// Errors during ELF parsing and loading. + Elf(ElfError), + /// Errors related to GHCB + Ghcb(GhcbError), + /// Errors related to MSR protocol + GhcbMsr(GhcbMsrError), + /// Errors related to SEV-SNP operations, like PVALIDATE or RMPUPDATE + SevSnp(SevSnpError), + /// Errors related to TDX operations + Tdx, + /// Generic errors related to memory management + Mem, + /// Errors related to the memory allocator + Alloc(AllocError), + /// Error reported when there is no VMSA set up. + MissingVMSA, + /// Error reported when there is no CAA (Calling Area Address) set up. + MissingCAA, + /// Error reported when there is no secrets page set up. + MissingSecrets, + /// Instruction decode related errors + Insn(InsnError), + /// Invalid address, usually provided by the guest + InvalidAddress, + /// Error reported when convert a usize to Bytes + InvalidBytes, + /// Error reported when converting to UTF-8 + InvalidUtf8, + /// Errors related to firmware parsing + Firmware, + /// Errors related to console operation + Console, + /// Errors related to firmware configuration contents + FwCfg(FwCfgError), + /// Errors related to ACPI parsing. + Acpi, + /// Errors from the filesystem. + FileSystem(FsError), + /// Obj related error + Obj(ObjError), + /// Task management errors, + Task(TaskError), + /// Errors from #VC handler + Vc(VcError), + /// The operation is not supported. + NotSupported, + /// Generic errors related to APIC emulation. 
+ Apic(ApicError), +} + +impl From for SvsmError { + fn from(err: ElfError) -> Self { + Self::Elf(err) + } +} + +impl From for SvsmError { + fn from(err: ApicError) -> Self { + Self::Apic(err) + } +} + +impl From for SvsmError { + fn from(err: ObjError) -> Self { + Self::Obj(err) + } +} + +impl From for SysCallError { + fn from(err: SvsmError) -> Self { + match err { + SvsmError::Alloc(AllocError::OutOfMemory) => SysCallError::ENOMEM, + SvsmError::FileSystem(FsError::FileExists) => SysCallError::EEXIST, + + SvsmError::FileSystem(FsError::FileNotFound) | SvsmError::Obj(ObjError::NotFound) => { + SysCallError::ENOTFOUND + } + + SvsmError::NotSupported => SysCallError::ENOTSUPP, + + SvsmError::FileSystem(FsError::Inval) + | SvsmError::Obj(ObjError::InvalidHandle) + | SvsmError::Mem + | SvsmError::InvalidAddress + | SvsmError::InvalidBytes + | SvsmError::InvalidUtf8 => SysCallError::EINVAL, + + _ => SysCallError::UNKNOWN, + } + } +} diff --git a/stage2/src/fs/api.rs b/stage2/src/fs/api.rs new file mode 100644 index 000000000..07b696c3a --- /dev/null +++ b/stage2/src/fs/api.rs @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; +use alloc::sync::Arc; +use alloc::vec::Vec; + +use core::fmt::Debug; + +use crate::error::SvsmError; +use crate::mm::PageRef; +use crate::string::FixedString; +use packit::PackItError; + +/// Maximum supported length for a single filename +const MAX_FILENAME_LENGTH: usize = 64; +pub type FileName = FixedString; + +/// Represents the type of error occured +/// while doing SVSM filesystem operations. +#[derive(Copy, Clone, Debug, Default)] +pub enum FsError { + #[default] + Inval, + FileExists, + FileNotFound, + PackIt(PackItError), +} + +impl From for SvsmError { + fn from(e: FsError) -> Self { + Self::FileSystem(e) + } +} + +impl From for FsError { + fn from(e: PackItError) -> Self { + Self::PackIt(e) + } +} + +impl From for SvsmError { + fn from(e: PackItError) -> Self { + Self::from(FsError::from(e)) + } +} + +/// Used to define methods of [`FsError`]. +macro_rules! impl_fs_err { + ($name:ident, $v:ident) => { + pub fn $name() -> Self { + Self::$v + } + }; +} + +impl FsError { + impl_fs_err!(inval, Inval); + impl_fs_err!(file_exists, FileExists); + impl_fs_err!(file_not_found, FileNotFound); +} + +/// Represents file operations +pub trait File: Debug + Send + Sync { + /// Used to read contents of a file + /// + /// # Arguments + /// + /// - `buf`: buffer to read the file contents into. + /// - `offset`: file offset to read from. + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the number of + /// bytes read if successful, or an [`SvsmError`] if there was a problem + /// during the read operation. + fn read(&self, buf: &mut [u8], offset: usize) -> Result; + + /// Used to write contents to a file + /// + /// # Arguments + /// + /// - `buf`: buffer which holds the contents to be written to the file. + /// - `offset`: file offset to write to. + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the number of + /// bytes written if successful, or an [`SvsmError`] if there was a problem + /// during the write operation. + fn write(&self, buf: &[u8], offset: usize) -> Result; + + /// Used to truncate the file to the specified size. + /// + /// # Arguments + /// + /// - `size`: specifies the size in bytes to which the file + /// is to be truncated. 
+ /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the size of the + /// file after truncation if successful, or an [`SvsmError`] if there was + /// a problem during the truncate operation. + fn truncate(&self, size: usize) -> Result; + + /// Used to get the size of the file. + /// + /// # Returns + /// + /// size of the file in bytes. + fn size(&self) -> usize; + + /// Get reference to backing pages of the file + /// + /// # Arguments + /// + /// - `offset`: offset to the requested page in bytes + /// + /// # Returns + /// + /// [`Option`]: An [`Option`] with the requested page reference. + /// `None` if the offset is not backed by a page. + fn mapping(&self, _offset: usize) -> Option { + None + } +} + +/// Represents directory operations +pub trait Directory: Debug + Send + Sync { + /// Used to get the list of entries in the directory. + /// + /// # Returns + /// + /// A [`Vec`] containing all the entries in the directory. + fn list(&self) -> Vec; + + /// Used to lookup for an entry in the directory. + /// + /// # Arguments + /// + /// - `name`: name of the entry to be looked up in the directory. + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the [`DirEntry`] + /// corresponding to the entry being looked up in the directory if present, or + /// an [`SvsmError`] if not present. + fn lookup_entry(&self, name: FileName) -> Result; + + /// Used to create a new file in the directory. + /// + /// # Arguments + /// + /// - `name`: name of the file to be created. + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the [`DirEntry`] + /// of the new file created on success, or an [`SvsmError`] on failure + fn create_file(&self, name: FileName) -> Result, SvsmError>; + + /// Used to create a subdirectory in the directory. + /// + /// # Arguments + /// + /// - `name`: name of the subdirectory to be created. + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the [`DirEntry`] + /// of the subdirectory created on success, or an [`SvsmError`] on failure + fn create_directory(&self, name: FileName) -> Result, SvsmError>; + + /// Used to remove an entry from the directory. + /// + /// # Arguments + /// + /// - `name`: name of the entry to be removed from the directory. + /// + /// # Returns + /// + /// [`Result<(), SvsmError>`]: A [`Result`] containing the empty + /// value on success, or an [`SvsmError`] on failure + fn unlink(&self, name: FileName) -> Result<(), SvsmError>; +} + +/// Represents a directory entry which could +/// either be a file or a subdirectory. +#[derive(Debug)] +pub enum DirEntry { + File(Arc), + Directory(Arc), +} + +impl DirEntry { + /// Used to check if a [`DirEntry`] variable is a file. + /// + /// # Returns + /// + /// ['true'] if [`DirEntry`] is a file, ['false'] otherwise. + pub fn is_file(&self) -> bool { + matches!(self, Self::File(_)) + } + + /// Used to check if a [`DirEntry`] variable is a directory. + /// + /// # Returns + /// + /// ['true'] if [`DirEntry`] is a directory, ['false'] otherwise. + pub fn is_dir(&self) -> bool { + matches!(self, Self::Directory(_)) + } +} + +impl Clone for DirEntry { + fn clone(&self) -> Self { + match self { + DirEntry::File(f) => DirEntry::File(f.clone()), + DirEntry::Directory(d) => DirEntry::Directory(d.clone()), + } + } +} + +/// Directory entries including their names. +#[derive(Debug)] +pub struct DirectoryEntry { + pub name: FileName, + pub entry: DirEntry, +} + +impl DirectoryEntry { + /// Create a new [`DirectoryEntry`] instance. 
+ /// + /// # Arguments + /// + /// - `name`: name for the entry to be created. + /// - `entry`: [`DirEntry`] containing the file or directory details. + /// + /// # Returns + /// + /// A new [`DirectoryEntry`] instance. + pub fn new(name: FileName, entry: DirEntry) -> Self { + DirectoryEntry { name, entry } + } +} diff --git a/stage2/src/fs/filesystem.rs b/stage2/src/fs/filesystem.rs new file mode 100644 index 000000000..17864d621 --- /dev/null +++ b/stage2/src/fs/filesystem.rs @@ -0,0 +1,721 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::ramfs::RamDirectory; +use super::*; + +use crate::error::SvsmError; +use crate::locking::{RWLock, SpinLock}; +use crate::mm::PageRef; + +use core::cmp::min; + +extern crate alloc; +use alloc::sync::Arc; +use alloc::vec::Vec; + +/// Represents a raw file handle. +#[derive(Debug)] +struct RawFileHandle { + file: Arc, + /// current file offset for the read/write operation + current: usize, +} + +impl RawFileHandle { + fn new(file: &Arc) -> Self { + RawFileHandle { + file: file.clone(), + current: 0, + } + } + + fn read(&mut self, buf: &mut [u8]) -> Result { + let result = self.file.read(buf, self.current); + if let Ok(v) = result { + self.current += v; + } + result + } + + fn write(&mut self, buf: &[u8]) -> Result { + let result = self.file.write(buf, self.current); + if let Ok(num) = result { + self.current += num; + } + result + } + + fn truncate(&self, offset: usize) -> Result { + self.file.truncate(offset) + } + + fn seek(&mut self, pos: usize) { + self.current = min(pos, self.file.size()); + } + + fn size(&self) -> usize { + self.file.size() + } + + fn mapping(&self, offset: usize) -> Option { + self.file.mapping(offset) + } +} + +/// Represents a handle used for file operations in a thread-safe manner. +#[derive(Debug)] +pub struct FileHandle { + // Use a SpinLock here because the read operation also needs to be mutable + // (changes file pointer). Parallel reads are still possible with multiple + // file handles + handle: SpinLock, +} + +impl FileHandle { + /// Create a new file handle instance. + pub fn new(file: &Arc) -> Self { + FileHandle { + handle: SpinLock::new(RawFileHandle::new(file)), + } + } + + /// Used to read contents from the file handle. + /// + /// # Arguments + /// + /// - `buf`: buffer to read the file contents to + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the number of + /// bytes read if successful, or an [`SvsmError`] if there was a problem + /// during the read operation. + pub fn read(&self, buf: &mut [u8]) -> Result { + self.handle.lock().read(buf) + } + + /// Used to write contents to the file handle + /// + /// # Arguments + /// + /// - `buf`: buffer which holds the contents to be written to the file. + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the number of + /// bytes written if successful, or an [`SvsmError`] if there was a problem + /// during the write operation. + pub fn write(&self, buf: &[u8]) -> Result { + self.handle.lock().write(buf) + } + + /// Used to truncate the file to the specified size. + /// + /// # Arguments + /// + /// - `offset`: specifies the size in bytes to which the file + /// git is to be truncated. + /// + /// # Returns + /// + /// [`Result`]: A [`Result`] containing the size of the + /// file after truncation if successful, or an [`SvsmError`] if there was + /// a problem during the truncate operation. 
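+ ///
+ /// For example, `fh.truncate(0)` discards the current contents and returns
+ /// `Ok(0)`; this is what `populate_ram_fs()` does before writing each
+ /// unpacked file.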
+ pub fn truncate(&self, offset: usize) -> Result { + self.handle.lock().truncate(offset) + } + + /// Used to change the current file offset. + /// + /// # Arguments + /// + /// - `pos`: intended new file offset value. + pub fn seek(&self, pos: usize) { + self.handle.lock().seek(pos); + } + + /// Used to get the size of the file. + /// + /// # Returns + /// + /// Size of the file in bytes. + pub fn size(&self) -> usize { + self.handle.lock().size() + } + + pub fn position(&self) -> usize { + self.handle.lock().current + } + + pub fn mapping(&self, offset: usize) -> Option { + self.handle.lock().mapping(offset) + } +} + +/// Represents SVSM filesystem +#[derive(Debug)] +struct SvsmFs { + root: Option>, +} + +impl SvsmFs { + const fn new() -> Self { + SvsmFs { root: None } + } + + /// Used to set the root directory of the SVSM filesystem. + /// + /// # Arguments + /// + /// - `root`: represents directory which is to be set + /// as the root of the filesystem. + fn initialize(&mut self, root: &Arc) { + assert!(!self.initialized()); + self.root = Some(root.clone()); + } + + #[cfg(all(any(test, fuzzing), not(test_in_svsm)))] + fn uninitialize(&mut self) { + self.root = None; + } + + /// Used to check if the filesystem is initialized. + /// + /// # Returns + /// + /// [`bool`]: If the filesystem is initialized. + fn initialized(&self) -> bool { + self.root.is_some() + } + + /// Used to get the root directory of the filesystem. + /// + /// # Returns + /// + /// [`Arc`]: root directory of the filesystem. + fn root_dir(&self) -> Arc { + assert!(self.initialized()); + self.root.as_ref().unwrap().clone() + } +} + +static FS_ROOT: RWLock = RWLock::new(SvsmFs::new()); + +/// Used to initialize the filesystem with an empty root directory. +pub fn initialize_fs() { + let root_dir = Arc::new(RamDirectory::new()); + + FS_ROOT.lock_write().initialize(&root_dir); +} + +#[cfg(any(test, fuzzing))] +#[cfg_attr(test_in_svsm, derive(Clone, Copy))] +#[derive(Debug)] +pub struct TestFileSystemGuard; + +#[cfg(any(test, fuzzing))] +impl TestFileSystemGuard { + /// Create a test filesystem. + /// + /// When running as a regular test in userspace: + /// + /// * Creating the struct via `setup()` will initialize an empty + /// filesystem. + /// * Dropping the struct will cause the filesystem to + /// uninitialize. + /// + /// When running inside the SVSM, creating or dropping the struct + /// is a no-op, as the filesystem is managed by the SVSM kernel. + #[must_use = "filesystem guard must be held for the whole test"] + pub fn setup() -> Self { + #[cfg(not(test_in_svsm))] + initialize_fs(); + Self + } +} + +#[cfg(all(any(test, fuzzing), not(test_in_svsm)))] +impl Drop for TestFileSystemGuard { + fn drop(&mut self) { + // Uninitialize the filesystem only if running in userspace. + FS_ROOT.lock_write().uninitialize(); + } +} + +/// Used to get an iterator over all the directory and file names contained in a path. +/// Directory name or file name in the path can be an empty value. +/// +/// # Argument +/// +/// `path`: path to be split. +/// +/// # Returns +/// +/// [`impl Iterator + DoubleEndedIterator`]: iterator over all the +/// directory and file names in the path. +fn split_path_allow_empty(path: &str) -> impl DoubleEndedIterator { + path.split('/').filter(|x| !x.is_empty()) +} + +/// Used to get an iterator over all the directory and file names contained in a path. +/// This function performs error checking. +/// +/// # Argument +/// +/// `path`: path to be split. 
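+///
+/// For example, `"test1//test2/file1"` yields the items `"test1"`, `"test2"`
+/// and `"file1"`; empty components are skipped, and a path with no components
+/// at all (such as `"/"`) is rejected with `FsError::inval()`.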
+/// +/// # Returns +/// +/// [`impl Iterator + DoubleEndedIterator`]: iterator over all the +/// directory and file names in the path. +fn split_path(path: &str) -> Result, SvsmError> { + let mut path_items = split_path_allow_empty(path).peekable(); + path_items + .peek() + .ok_or(SvsmError::FileSystem(FsError::inval()))?; + Ok(path_items) +} + +/// Used to perform a walk over the items in a path while checking +/// each item is a directory. +/// +/// # Argument +/// +/// `path_items`: contains items in a path. +/// +/// # Returns +/// +/// [`Result, SvsmError>`]: [`Result`] containing the +/// directory corresponding to the path if successful, or [`SvsmError`] +/// if there is an error. +fn walk_path<'a, I>(path_items: I) -> Result, SvsmError> +where + I: Iterator, +{ + let fs_root = FS_ROOT.lock_read(); + let mut current_dir = fs_root.root_dir(); + drop(fs_root); + + for item in path_items { + let dir_name = FileName::from(item); + let dir_entry = current_dir.lookup_entry(dir_name)?; + current_dir = match dir_entry { + DirEntry::File(_) => return Err(SvsmError::FileSystem(FsError::file_not_found())), + DirEntry::Directory(dir) => dir, + }; + } + + Ok(current_dir) +} + +/// Used to perform a walk over the items in a path while checking +/// each existing item is a directory, while creating a directory +/// for each non-existing item. +/// +/// # Argument +/// +/// `path_items`: contains items in a path. +/// +/// # Returns +/// +/// [`Result, SvsmError>`]: [`Result`] containing the +/// directory corresponding to the path if successful, or [`SvsmError`] +/// if there is an error. +fn walk_path_create<'a, I>(path_items: I) -> Result, SvsmError> +where + I: Iterator, +{ + let fs_root = FS_ROOT.lock_read(); + let mut current_dir = fs_root.root_dir(); + drop(fs_root); + + for item in path_items { + let dir_name = FileName::from(item); + let lookup = current_dir.lookup_entry(dir_name); + let dir_entry = match lookup { + Ok(entry) => entry, + Err(_) => DirEntry::Directory(current_dir.create_directory(dir_name)?), + }; + current_dir = match dir_entry { + DirEntry::File(_) => return Err(SvsmError::FileSystem(FsError::file_not_found())), + DirEntry::Directory(dir) => dir, + }; + } + + Ok(current_dir) +} + +/// Used to open a file to get the file handle for further file operations. +/// +/// # Argument +/// +/// `path`: path of the file to be opened. +/// +/// # Returns +/// +/// [`Result`]: [`Result`] containing the [`FileHandle`] +/// of the opened file if the file exists, [`SvsmError`] otherwise. +pub fn open(path: &str) -> Result { + let mut path_items = split_path(path)?; + let file_name = FileName::from(path_items.next_back().unwrap()); + let current_dir = walk_path(path_items)?; + + let dir_entry = current_dir.lookup_entry(file_name)?; + + match dir_entry { + DirEntry::Directory(_) => Err(SvsmError::FileSystem(FsError::file_not_found())), + DirEntry::File(f) => Ok(FileHandle::new(&f)), + } +} + +/// Used to create a file with the given path. +/// +/// # Argument +/// +/// `path`: path of the file to be created. +/// +/// # Returns +/// +/// [`Result`]: [`Result`] containing the [`FileHandle`] +/// for the opened file if successful, [`SvsmError`] otherwise. 
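+///
+/// Unlike [`create_all`], this does not create missing parent directories:
+/// every directory in `path` must already exist. A short sketch (illustrative
+/// paths only; assumes the filesystem was set up with `initialize_fs()`):
+///
+/// ```ignore
+/// mkdir("dir").unwrap();
+/// let fh = create("dir/file").unwrap(); // would fail without the mkdir above
+/// fh.write(b"hello").unwrap();
+/// ```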
+pub fn create(path: &str) -> Result { + let mut path_items = split_path(path)?; + let file_name = FileName::from(path_items.next_back().unwrap()); + let current_dir = walk_path(path_items)?; + let file = current_dir.create_file(file_name)?; + + Ok(FileHandle::new(&file)) +} + +/// Used to create a file and the missing subdirectories in the given path. +/// +/// # Argument +/// +/// `path`: path of the file to be created. +/// +/// # Returns +/// +/// [`Result`]: [`Result`] containing the [`FileHandle`] +/// for the opened file if successful, [`SvsmError`] otherwise. +pub fn create_all(path: &str) -> Result { + let mut path_items = split_path(path)?; + let file_name = FileName::from(path_items.next_back().unwrap()); + let current_dir = walk_path_create(path_items)?; + + if file_name.length() == 0 { + return Err(SvsmError::FileSystem(FsError::inval())); + } + + let file = current_dir.create_file(file_name)?; + + Ok(FileHandle::new(&file)) +} + +/// Used to create a directory with the given path. +/// +/// # Argument +/// +/// `path`: path of the directory to be created. +/// +/// # Returns +/// +/// [`Result<(), SvsmError>`]: [`Result`] containing the unit +/// value if successful, [`SvsmError`] otherwise. +pub fn mkdir(path: &str) -> Result<(), SvsmError> { + let mut path_items = split_path(path)?; + let dir_name = FileName::from(path_items.next_back().unwrap()); + let current_dir = walk_path(path_items)?; + + current_dir.create_directory(dir_name)?; + + Ok(()) +} + +/// Used to delete a file or a directory. +/// +/// # Argument +/// +/// `path`: path of the file or directory to be created. +/// +/// # Returns +/// +/// [`Result<(), SvsmError>`]: [`Result`] containing the unit +/// value if successful, [`SvsmError`] otherwise. +pub fn unlink(path: &str) -> Result<(), SvsmError> { + let mut path_items = split_path(path)?; + let entry_name = FileName::from(path_items.next_back().unwrap()); + let dir = walk_path(path_items)?; + + dir.unlink(entry_name) +} + +/// Used to list the contents of a directory. +/// +/// # Argument +/// +/// `path`: path of the directory to be listed. +/// # Returns +/// +/// [`Result<(), SvsmError>`]: [`Result`] containing the [`Vec`] +/// of directory entries if successful, [`SvsmError`] otherwise. +pub fn list_dir(path: &str) -> Result, SvsmError> { + let items = split_path_allow_empty(path); + let dir = walk_path(items)?; + Ok(dir.list()) +} + +/// Used to read from a file handle. +/// +/// # Arguments +/// +/// - `fh`: Filehandle to be read. +/// - `buf`: buffer to read the file contents into. +/// +/// # Returns +/// +/// [`Result`]: [`Result`] containing the number +/// of bytes read if successful, [`SvsmError`] otherwise. +pub fn read(fh: &FileHandle, buf: &mut [u8]) -> Result { + fh.read(buf) +} + +/// Used to write into file handle. +/// +/// # Arguments +/// +/// - `fh`: Filehandle to be written. +/// - `buf`: buffer containing the data to be written. +/// +/// # Returns +/// +/// [`Result`]: [`Result`] containing the number +/// of bytes written if successful, [`SvsmError`] otherwise. +pub fn write(fh: &FileHandle, buf: &[u8]) -> Result { + fh.write(buf) +} + +/// Used to set the file offset +/// +/// # Arguements +/// +/// - `fh`: Filehandle for the seek operation. +/// - `pos`: new file offset value to be set. 
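+///
+/// The offset is clamped to the current file size (see `RawFileHandle::seek`),
+/// so seeking past the end positions the handle at end-of-file.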
+pub fn seek(fh: &FileHandle, pos: usize) { + fh.seek(pos) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mm::alloc::{TestRootMem, DEFAULT_TEST_MEMORY_SIZE}; + + #[test] + fn create_dir() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + // Create file - should fail as directory does not exist yet + create("test1/file1").unwrap_err(); + + // Create directory + mkdir("test1").unwrap(); + + // Check double-create + mkdir("test1").unwrap_err(); + + // Check if it appears in the listing + let root_list = list_dir("").unwrap(); + assert_eq!(root_list, [FileName::from("test1")]); + + // Try again - should succeed now + create("test1/file1").unwrap(); + + // Cleanup + unlink("test1/file1").unwrap(); + unlink("test1").unwrap(); + } + + #[test] + fn create_and_unlink_file() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + create("test1").unwrap(); + + // Check if it appears in the listing + let root_list = list_dir("").unwrap(); + assert_eq!(root_list, [FileName::from("test1")]); + + // Try creating again as file - should fail + create("test1").unwrap_err(); + + // Try creating again as directory - should fail + mkdir("test1").unwrap_err(); + + // Try creating a different dir + mkdir("test2").unwrap(); + + // Unlink file + unlink("test1").unwrap(); + + // Check if it is removed from the listing + let root_list = list_dir("").unwrap(); + assert_eq!(root_list, [FileName::from("test2")]); + + // Cleanup + unlink("test2").unwrap(); + } + + #[test] + fn create_sub_dir() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + // Create file - should fail as directory does not exist yet + create("test1/test2/file1").unwrap_err(); + + // Create directory + mkdir("test1").unwrap(); + + // Create sub-directory + mkdir("test1/test2").unwrap(); + + // Check if it appears in the listing + let list = list_dir("test1/").unwrap(); + assert_eq!(list, [FileName::from("test2")]); + + // Try again - should succeed now + create("test1/test2/file1").unwrap(); + + // Check if it appears in the listing + let list = list_dir("test1/test2/").unwrap(); + assert_eq!(list, [FileName::from("file1")]); + + // Cleanup + unlink("test1/test2/file1").unwrap(); + unlink("test1/test2").unwrap(); + unlink("test1/").unwrap(); + } + + #[test] + fn test_unlink() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + // Create directory + mkdir("test1").unwrap(); + + // Creating files + create("test1/file1").unwrap(); + create("test1/file2").unwrap(); + + // Check if they appears in the listing + let list = list_dir("test1").unwrap(); + assert_eq!(list, [FileName::from("file1"), FileName::from("file2")]); + + // Unlink non-existent file + unlink("test2").unwrap_err(); + + // Unlink existing file + unlink("test1/file1").unwrap(); + + // Check if it is removed from the listing + let list = list_dir("test1").unwrap(); + assert_eq!(list, [FileName::from("file2")]); + + // Cleanup + unlink("test1/file2").unwrap(); + unlink("test1").unwrap(); + } + + #[test] + fn test_open_read_write_seek() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + // Create directory + mkdir("test1").unwrap(); + + // Try again - should succeed now + create("test1/file1").unwrap(); + + // Try to open non-existent file + 
open("test1/file2").unwrap_err(); + + let fh = open("test1/file1").unwrap(); + + assert!(fh.size() == 0); + + let buf: [u8; 512] = [0xff; 512]; + let result = write(&fh, &buf).unwrap(); + assert_eq!(result, 512); + + assert_eq!(fh.size(), 512); + + fh.seek(256); + let buf2: [u8; 512] = [0xcc; 512]; + let result = write(&fh, &buf2).unwrap(); + assert_eq!(result, 512); + + assert_eq!(fh.size(), 768); + + let mut buf3: [u8; 1024] = [0; 1024]; + fh.seek(0); + let result = read(&fh, &mut buf3).unwrap(); + assert_eq!(result, 768); + + for (i, elem) in buf3.iter().enumerate() { + let expected: u8 = if i < 256 { + 0xff + } else if i < 768 { + 0xcc + } else { + 0x0 + }; + assert!(*elem == expected); + } + + // Cleanup + unlink("test1/file1").unwrap(); + unlink("test1").unwrap(); + } + + #[test] + fn test_multiple_file_handles() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + // Create file + let fh1 = create("file").unwrap(); + assert_eq!(fh1.size(), 0); + + let buf1: [u8; 6144] = [0xff; 6144]; + let result = fh1.write(&buf1).unwrap(); + assert_eq!(result, 6144); + assert_eq!(fh1.size(), 6144); + + // Another handle to the same file + let fh2 = open("file").unwrap(); + assert_eq!(fh2.size(), 6144); + + let mut buf2: [u8; 4096] = [0; 4096]; + let result = fh2.read(&mut buf2).unwrap(); + assert_eq!(result, 4096); + + for elem in &buf2 { + assert_eq!(*elem, 0xff); + } + + fh1.truncate(2048).unwrap(); + + let result = fh2.read(&mut buf2).unwrap(); + assert_eq!(result, 0); + + // Cleanup + unlink("file").unwrap(); + } +} diff --git a/stage2/src/fs/init.rs b/stage2/src/fs/init.rs new file mode 100644 index 000000000..9b93f578a --- /dev/null +++ b/stage2/src/fs/init.rs @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr}; +use crate::error::SvsmError; +use crate::mm::ptguards::PerCPUPageMappingGuard; +use packit::PackItArchiveDecoder; + +use super::*; + +extern crate alloc; +use alloc::slice; + +/// Used to create a SVSM RAM filesystem from a filesystem archive. +/// +/// # Arguments +/// +/// - `kernel_fs_start`: denotes the physical address at which the archive starts. +/// - `kernel_fs_end`: denotes the physical address at which the archive ends. +/// +/// # Assertion +/// +/// asserts if `kernel_fs_end` is greater than or equal to `kernel_fs_start`. +/// +/// # Returns +/// [`Result<(), SvsmError>`]: A [`Result`] containing the unit value if successful, +/// [`SvsmError`] otherwise. 
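+///
+/// The region is expected to contain a `packit` archive; each member is
+/// unpacked into the RAM filesystem via [`create_all`], creating any missing
+/// parent directories along the way.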
+pub fn populate_ram_fs(kernel_fs_start: u64, kernel_fs_end: u64) -> Result<(), SvsmError> { + assert!(kernel_fs_end >= kernel_fs_start); + + let pstart = PhysAddr::from(kernel_fs_start); + let pend = PhysAddr::from(kernel_fs_end); + let size = pend - pstart; + + if size == 0 { + return Ok(()); + } + + log::info!("Unpacking FS archive..."); + + let guard = PerCPUPageMappingGuard::create(pstart.page_align(), pend.page_align_up(), 0)?; + let vstart = guard.virt_addr() + pstart.page_offset(); + + let data: &[u8] = unsafe { slice::from_raw_parts(vstart.as_ptr(), size) }; + let archive = PackItArchiveDecoder::load(data)?; + + for file in archive { + let file = file?; + let handle = create_all(file.name())?; + handle.truncate(0)?; + let written = handle.write(file.data())?; + if written != file.data().len() { + log::error!("Incomplete data write to {}", file.name()); + return Err(SvsmError::FileSystem(FsError::inval())); + } + + log::info!(" Unpacked {}", file.name()); + } + + log::info!("Unpacking done"); + + Ok(()) +} diff --git a/stage2/src/fs/mod.rs b/stage2/src/fs/mod.rs new file mode 100644 index 000000000..337859fce --- /dev/null +++ b/stage2/src/fs/mod.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 SUSE LLC +// +// Author: Joerg Roedel + +mod api; +mod filesystem; +mod init; +mod ramfs; + +pub use api::*; +pub use filesystem::*; +pub use init::populate_ram_fs; diff --git a/stage2/src/fs/ramfs.rs b/stage2/src/fs/ramfs.rs new file mode 100644 index 000000000..1eae38d94 --- /dev/null +++ b/stage2/src/fs/ramfs.rs @@ -0,0 +1,547 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::*; + +use crate::error::SvsmError; +use crate::locking::RWLock; +use crate::mm::PageRef; +use crate::types::{PAGE_SHIFT, PAGE_SIZE}; +use crate::utils::{page_align_up, page_offset}; + +extern crate alloc; +use alloc::sync::Arc; +use alloc::vec::Vec; + +use core::cmp::{max, min}; + +/// Represents an SVSM Ramfile +#[derive(Debug, Default)] +struct RawRamFile { + /// Maximum size of the file without allocating new pages + capacity: usize, + /// Current size of the file + size: usize, + /// Vector of pages allocated for the file + pages: Vec, +} + +impl RawRamFile { + /// Used to get new instance of [`RawRamFile`]. + pub fn new() -> Self { + RawRamFile { + capacity: 0, + size: 0, + pages: Vec::new(), + } + } + + /// Used to increase the capacity of the file by allocating a + /// new page. + /// + /// # Returns + /// + /// [`Result<(), SvsmError>`]: A [`Result`] containing empty + /// value if successful, SvsvError otherwise. + fn increase_capacity(&mut self) -> Result<(), SvsmError> { + let page_ref = PageRef::new()?; + self.pages.push(page_ref); + self.capacity += PAGE_SIZE; + Ok(()) + } + + /// Used to set the capacity of the file. + /// + /// # Argument + /// + /// `capacity`: intended new capacity of the file. + /// + /// # Returns + /// + /// [`Result<(), SvsmError>`]: A [Result] containing empty + /// value if successful, SvsmError otherwise. + fn set_capacity(&mut self, capacity: usize) -> Result<(), SvsmError> { + let cap = page_align_up(capacity); + + while cap > self.capacity { + self.increase_capacity()?; + } + + Ok(()) + } + + /// Used to read a page corresponding to the file from + /// a particular offset. + /// + /// # Arguements + /// + /// - `buf`: buffer to read the file contents into. + /// - `offset`: offset to read the file from. 
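+ ///
+ /// Assuming 4 KiB pages (`PAGE_SIZE`), an `offset` of `0x1234` selects the
+ /// page at index `1` and reads starting at byte `0x234` within that page.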
+ /// + /// # Assert + /// + /// Assert that read operation doesn't extend beyond a page. + fn read_from_page(&self, buf: &mut [u8], offset: usize) { + let page_index = page_offset(offset); + let index = offset / PAGE_SIZE; + self.pages[index].read(page_index, buf); + } + + /// Used to write contents to a page corresponding to + /// the file at a particular offset. + /// + /// # Arguments + /// + /// - `buf`: buffer that contains the data to write to the file. + /// - `offset`: file offset to write the data. + /// # Assert + /// + /// Assert that write operation doesn't extend beyond a page. + fn write_to_page(&self, buf: &[u8], offset: usize) { + let page_index = page_offset(offset); + let index = offset / PAGE_SIZE; + self.pages[index].write(page_index, buf); + } + + /// Used to read the file from a particular offset. + /// + /// # Arguments + /// + /// - `buf`: buffer to read the contents of the file into. + /// - `offset`: file offset to read from. + /// + /// # Returns + /// + /// [`Result<(), SvsmError>`]: A [Result] containing empty + /// value if successful, SvsmError otherwise. + fn read(&self, buf: &mut [u8], offset: usize) -> Result { + let mut current = min(offset, self.size); + let mut len = buf.len(); + let mut bytes: usize = 0; + let mut buf_offset = 0; + + while len > 0 { + let page_end = min(page_align_up(current + 1), self.size); + let page_len = min(page_end - current, len); + let buf_end = buf_offset + page_len; + + if page_len == 0 { + break; + } + + self.read_from_page(&mut buf[buf_offset..buf_end], current); + + buf_offset = buf_end; + current += page_len; + len -= page_len; + bytes += page_len; + } + + Ok(bytes) + } + + /// Used to write to the file at a particular offset. + /// + /// # Arguments + /// + /// - `buf`: buffer that contains the data to write into the file. + /// - `offset`: file offset to read from. + /// + /// # Returns + /// + /// [`Result<(), SvsmError>`]: A [Result] containing empty + /// value if successful, SvsmError otherwise. + fn write(&mut self, buf: &[u8], offset: usize) -> Result { + let mut current = offset; + let mut bytes: usize = 0; + let mut len = buf.len(); + let mut buf_offset: usize = 0; + let capacity = offset + .checked_add(len) + .ok_or(SvsmError::FileSystem(FsError::inval()))?; + + self.set_capacity(capacity)?; + + while len > 0 { + let page_len = min(PAGE_SIZE - page_offset(current), len); + let buf_end = buf_offset + page_len; + + self.write_to_page(&buf[buf_offset..buf_end], current); + self.size = max(self.size, current + page_len); + + current += page_len; + buf_offset += page_len; + len -= page_len; + bytes += page_len; + } + + Ok(bytes) + } + + /// Used to truncate the file to a given size. + /// + /// # Argument + /// + /// - `size` : intended file size after truncation + /// + /// # Returns + /// + /// [`Result`]: a [`Result`] containing the + /// number of bytes file truncated to if successful, SvsmError + /// otherwise. + fn truncate(&mut self, size: usize) -> Result { + if size > self.size { + return Err(SvsmError::FileSystem(FsError::inval())); + } + + let offset = page_offset(size); + let base_pages = size / PAGE_SIZE; + let new_pages = if offset > 0 { + base_pages + 1 + } else { + base_pages + }; + + // Clear pages and remove them from the file + for page_ref in self.pages.drain(new_pages..) 
{ + page_ref.fill(0, 0); + } + + self.capacity = new_pages * PAGE_SIZE; + self.size = size; + + if offset > 0 { + // Clear the last page after new EOF + let page_ref = self.pages.last().unwrap(); + page_ref.fill(offset, 0); + } + + Ok(size) + } + + /// Used to get the size of the file in bytes. + /// + /// # Returns + /// the size of the file in bytes. + fn size(&self) -> usize { + self.size + } + + fn mapping(&self, offset: usize) -> Option { + if offset > self.size() { + return None; + } + self.pages.get(offset >> PAGE_SHIFT).cloned() + } +} + +/// Represents a SVSM file with synchronized access +#[derive(Debug)] +pub struct RamFile { + rawfile: RWLock, +} + +impl RamFile { + /// Used to get a new instance of [`RamFile`]. + pub fn new() -> Self { + RamFile { + rawfile: RWLock::new(RawRamFile::new()), + } + } +} + +impl File for RamFile { + fn read(&self, buf: &mut [u8], offset: usize) -> Result { + self.rawfile.lock_read().read(buf, offset) + } + + fn write(&self, buf: &[u8], offset: usize) -> Result { + self.rawfile.lock_write().write(buf, offset) + } + + fn truncate(&self, size: usize) -> Result { + self.rawfile.lock_write().truncate(size) + } + + fn size(&self) -> usize { + self.rawfile.lock_read().size() + } + + fn mapping(&self, offset: usize) -> Option { + self.rawfile.lock_read().mapping(offset) + } +} + +/// Represents a SVSM directory with synchronized access +#[derive(Debug)] +pub struct RamDirectory { + entries: RWLock>, +} + +impl RamDirectory { + /// Used to get a new instance of [`RamDirectory`] + pub fn new() -> Self { + RamDirectory { + entries: RWLock::new(Vec::new()), + } + } + + /// Used to check if an entry is present in the directory. + /// + /// # Argument + /// + /// `name`: name of the entry to be looked up. + /// + /// # Returns + /// [`true`] if the entry is present, [`false`] otherwise. 
+ fn has_entry(&self, name: &FileName) -> bool { + self.entries + .lock_read() + .iter() + .any(|entry| entry.name == *name) + } +} + +impl Directory for RamDirectory { + fn list(&self) -> Vec { + self.entries + .lock_read() + .iter() + .map(|e| e.name) + .collect::>() + } + + fn lookup_entry(&self, name: FileName) -> Result { + for e in self.entries.lock_read().iter() { + if e.name == name { + return Ok(e.entry.clone()); + } + } + + Err(SvsmError::FileSystem(FsError::file_not_found())) + } + + fn create_file(&self, name: FileName) -> Result, SvsmError> { + if self.has_entry(&name) { + return Err(SvsmError::FileSystem(FsError::file_exists())); + } + + let new_file = Arc::new(RamFile::new()); + self.entries + .lock_write() + .push(DirectoryEntry::new(name, DirEntry::File(new_file.clone()))); + + Ok(new_file) + } + + fn create_directory(&self, name: FileName) -> Result, SvsmError> { + if self.has_entry(&name) { + return Err(SvsmError::FileSystem(FsError::file_exists())); + } + + let new_dir = Arc::new(RamDirectory::new()); + self.entries.lock_write().push(DirectoryEntry::new( + name, + DirEntry::Directory(new_dir.clone()), + )); + + Ok(new_dir) + } + + fn unlink(&self, name: FileName) -> Result<(), SvsmError> { + let mut vec = self.entries.lock_write(); + let pos = vec.iter().position(|e| e.name == name); + + match pos { + Some(idx) => { + vec.swap_remove(idx); + Ok(()) + } + None => Err(SvsmError::FileSystem(FsError::file_not_found())), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mm::alloc::{TestRootMem, DEFAULT_TEST_MEMORY_SIZE}; + + #[test] + fn test_ramfs_file_read_write() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + let file = RamFile::new(); + let mut buf1 = [0xffu8; 512]; + + // Write first buffer at offset 0 + file.write(&buf1, 0).expect("Failed to write file data"); + assert!(file.size() == 512); + + // Write second buffer at offset 4096 - 256 - cross-page write + let mut buf2 = [0xaau8; 512]; + file.write(&buf2, PAGE_SIZE - 256) + .expect("Failed to write file cross-page"); + assert!(file.size() == PAGE_SIZE + 256); + + // Clear buffer before reading into it + buf1 = [0u8; 512]; + + // Read back and check first buffer + let size = file + .read(&mut buf1, 0) + .expect("Failed to read from offset 0"); + assert!(size == 512); + + for byte in buf1.iter() { + assert!(*byte == 0xff); + } + + // Clear buffer before reading into it + buf2 = [0u8; 512]; + + // Read back and check second buffer + let size = file + .read(&mut buf2, PAGE_SIZE - 256) + .expect("Failed to read from offset PAGE_SIZE - 256"); + assert!(size == 512); + + for byte in buf2.iter() { + assert!(*byte == 0xaa); + } + + // Check complete file + let mut buf3: [u8; 8192] = [0xcc; 8192]; + let size = file.read(&mut buf3, 0).expect("Failed to read whole file"); + assert!(size == PAGE_SIZE + 256); + + for (i, elem) in buf3.iter().enumerate() { + let expected: u8 = if i < 512 { + 0xff + } else if i < PAGE_SIZE - 256 { + 0 + } else if i < PAGE_SIZE + 256 { + 0xaa + } else { + 0xcc + }; + assert!(*elem == expected); + } + + assert_eq!(file.truncate(1024).unwrap(), 1024); + assert_eq!(file.size(), 1024); + + // Clear buffer before reading again into it + buf3 = [0u8; 8192]; + + // read file again + let size = file.read(&mut buf3, 0).expect("Failed to read whole file"); + assert!(size == 1024); + + for (i, elem) in buf3.iter().enumerate().take(1024) { + let expected: u8 = if i < 512 { 0xff } else { 0 }; + assert!(*elem == expected); + } + } + + #[test] + fn test_ram_directory() { + 
let f_name = FileName::from("file1"); + let d_name = FileName::from("dir1"); + + let ram_dir = RamDirectory::new(); + + ram_dir.create_file(f_name).expect("Failed to create file"); + ram_dir + .create_directory(d_name) + .expect("Failed to create directory"); + + let list = ram_dir.list(); + assert_eq!(list, [f_name, d_name]); + + let entry = ram_dir.lookup_entry(f_name).expect("Failed to lookup file"); + assert!(entry.is_file()); + + let entry = ram_dir + .lookup_entry(d_name) + .expect("Failed to lookup directory"); + assert!(entry.is_dir()); + + ram_dir.unlink(d_name).expect("Failed to unlink directory"); + + let list = ram_dir.list(); + assert_eq!(list, [f_name]); + } + + #[test] + fn test_ramfs_single_page_mapping() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + let file = RamFile::new(); + let buf = [0xffu8; 512]; + + file.write(&buf, 0).expect("Failed to write file data"); + + let res = file + .mapping(0) + .expect("Failed to get mapping for ramfs page"); + assert_eq!( + res.phys_addr(), + file.rawfile.lock_read().pages[0].phys_addr() + ); + } + + #[test] + fn test_ramfs_multi_page_mapping() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + let file = RamFile::new(); + let buf = [0xffu8; 4 * PAGE_SIZE]; + + file.write(&buf, 0).expect("Failed to write file data"); + + for i in 0..4 { + let res = file + .mapping(i * PAGE_SIZE) + .expect("Failed to get mapping for ramfs page"); + assert_eq!( + res.phys_addr(), + file.rawfile.lock_read().pages[i].phys_addr() + ); + } + } + + #[test] + fn test_ramfs_mapping_unaligned_offset() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + let file = RamFile::new(); + let buf = [0xffu8; 4 * PAGE_SIZE]; + + file.write(&buf, 0).expect("Failed to write file data"); + + let res = file + .mapping(PAGE_SIZE + 0x123) + .expect("Failed to get mapping for ramfs page"); + assert_eq!( + res.phys_addr(), + file.rawfile.lock_read().pages[1].phys_addr() + ); + } + + #[test] + fn test_ramfs_mapping_out_of_range() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + let file = RamFile::new(); + let buf = [0xffu8; 4 * PAGE_SIZE]; + + file.write(&buf, 0).expect("Failed to write file data"); + + let res = file.mapping(4 * PAGE_SIZE); + assert!(res.is_none()); + } +} diff --git a/stage2/src/fw_cfg.rs b/stage2/src/fw_cfg.rs new file mode 100644 index 000000000..ad2297839 --- /dev/null +++ b/stage2/src/fw_cfg.rs @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; + +use crate::address::{Address, PhysAddr}; +use crate::error::SvsmError; +use crate::mm::pagetable::max_phys_addr; +use crate::utils::MemoryRegion; + +use super::io::IOPort; +use super::string::FixedString; +use alloc::vec::Vec; +use core::mem::size_of; + +const FW_CFG_CTL: u16 = 0x510; +const FW_CFG_DATA: u16 = 0x511; + +const _FW_CFG_ID: u16 = 0x01; +const FW_CFG_FILE_DIR: u16 = 0x19; + +// Must be a power-of-2 +const KERNEL_REGION_SIZE: u64 = 16 * 1024 * 1024; +const KERNEL_REGION_SIZE_MASK: u64 = !(KERNEL_REGION_SIZE - 1); + +const MAX_FW_CFG_FILES: u32 = 0x1000; + +//use crate::println; + +#[non_exhaustive] +#[derive(Debug)] +pub struct FwCfg<'a> { + driver: &'a dyn IOPort, +} + +#[derive(Clone, Copy, Debug)] +pub enum FwCfgError { + // Could not find the appropriate file selector. + FileNotFound, + // Unexpected file size. + FileSize(u32), + // Could not find an appropriate kernel region for the SVSM. 
+ KernelRegion, + /// The firmware provided too many files to the guest + TooManyFiles, +} + +impl From for SvsmError { + fn from(err: FwCfgError) -> Self { + Self::FwCfg(err) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct FwCfgFile { + size: u32, + selector: u16, +} + +impl FwCfgFile { + pub fn size(&self) -> u32 { + self.size + } + pub fn selector(&self) -> u16 { + self.selector + } +} + +impl<'a> FwCfg<'a> { + pub fn new(driver: &'a dyn IOPort) -> Self { + FwCfg { driver } + } + + pub fn select(&self, cfg: u16) { + self.driver.outw(FW_CFG_CTL, cfg); + } + + pub fn read_bytes(&self, out: &mut [u8]) { + for byte in out.iter_mut() { + *byte = self.driver.inb(FW_CFG_DATA); + } + } + + pub fn read_le(&self) -> T + where + T: core::ops::Shl + core::ops::BitOr + From, + { + let mut val = T::from(0u8); + let io = &self.driver; + + for i in 0..size_of::() { + val = (T::from(io.inb(FW_CFG_DATA)) << (i * 8)) | val; + } + val + } + + pub fn read_be(&self) -> T + where + T: core::ops::Shl + core::ops::BitOr + From, + { + let mut val = T::from(0u8); + let io = &self.driver; + + for _ in 0..size_of::() { + val = (val << 8) | T::from(io.inb(FW_CFG_DATA)); + } + val + } + + pub fn read_char(&self) -> char { + self.driver.inb(FW_CFG_DATA) as char + } + + pub fn file_selector(&self, name: &str) -> Result { + self.select(FW_CFG_FILE_DIR); + let n: u32 = self.read_be(); + + if n > MAX_FW_CFG_FILES { + return Err(SvsmError::FwCfg(FwCfgError::TooManyFiles)); + } + + for _ in 0..n { + let size: u32 = self.read_be(); + let selector: u16 = self.read_be(); + let _unused: u16 = self.read_be(); + let mut fs = FixedString::<56>::new(); + for _ in 0..56 { + let c = self.read_char(); + fs.push(c); + } + + if fs == name { + return Ok(FwCfgFile { size, selector }); + } + } + + Err(SvsmError::FwCfg(FwCfgError::FileNotFound)) + } + + fn find_svsm_region(&self) -> Result, SvsmError> { + let file = self.file_selector("etc/sev/svsm")?; + + if file.size != 16 { + return Err(SvsmError::FwCfg(FwCfgError::FileSize(file.size))); + } + + self.select(file.selector); + Ok(self.read_memory_region()) + } + + fn read_memory_region(&self) -> MemoryRegion { + let start = PhysAddr::from(self.read_le::()); + let size = self.read_le::(); + let end = start.saturating_add(size as usize); + + assert!(start <= max_phys_addr(), "{start:#018x} is out of range"); + assert!(end <= max_phys_addr(), "{end:#018x} is out of range"); + + MemoryRegion::from_addresses(start, end) + } + + pub fn get_memory_regions(&self) -> Result>, SvsmError> { + let mut regions = Vec::new(); + let file = self.file_selector("etc/e820")?; + let entries = file.size / 20; + + self.select(file.selector); + + for _ in 0..entries { + let region = self.read_memory_region(); + let t: u32 = self.read_le(); + + if t == 1 { + regions.push(region); + } + } + + Ok(regions) + } + + fn find_kernel_region_e820(&self) -> Result, SvsmError> { + let regions = self.get_memory_regions()?; + let kernel_region = regions + .iter() + .max_by_key(|region| region.start()) + .ok_or(SvsmError::FwCfg(FwCfgError::KernelRegion))?; + + let start = PhysAddr::from( + kernel_region + .end() + .bits() + .saturating_sub(KERNEL_REGION_SIZE as usize) + & KERNEL_REGION_SIZE_MASK as usize, + ); + + if start < kernel_region.start() { + return Err(SvsmError::FwCfg(FwCfgError::KernelRegion)); + } + + Ok(MemoryRegion::new(start, kernel_region.len())) + } + + pub fn find_kernel_region(&self) -> Result, SvsmError> { + let kernel_region = self + .find_svsm_region() + .or_else(|_| self.find_kernel_region_e820())?; + 
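+ // The region either comes from the dedicated "etc/sev/svsm" fw_cfg file or,
+ // as a fallback, is carved out of the highest RAM region reported by the
+ // e820 map, starting at a `KERNEL_REGION_SIZE`-aligned address below that
+ // region's end.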
+ // Make sure that the kernel region doesn't overlap with the loader. + if kernel_region.start() < PhysAddr::from(640 * 1024u64) { + return Err(SvsmError::FwCfg(FwCfgError::KernelRegion)); + } + + Ok(kernel_region) + } + + // This needs to be &mut self to prevent iterator invalidation, where the caller + // could do fw_cfg.select() while iterating. Having a mutable reference prevents + // other references. + pub fn iter_flash_regions(&self) -> impl Iterator> + '_ { + let num = match self.file_selector("etc/flash") { + Ok(file) => { + self.select(file.selector); + file.size as usize / 16 + } + Err(_) => 0, + }; + + (0..num).map(|_| self.read_memory_region()) + } +} diff --git a/stage2/src/fw_meta.rs b/stage2/src/fw_meta.rs new file mode 100644 index 000000000..001de192c --- /dev/null +++ b/stage2/src/fw_meta.rs @@ -0,0 +1,436 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; + +use crate::address::PhysAddr; +use crate::config::SvsmConfig; +use crate::cpu::percpu::current_ghcb; +use crate::error::SvsmError; +use crate::kernel_region::new_kernel_region; +use crate::mm::PerCPUPageMappingGuard; +use crate::platform::PageStateChangeOp; +use crate::sev::{pvalidate, rmp_adjust, PvalidateOp, RMPFlags}; +use crate::types::{PageSize, PAGE_SIZE}; +use crate::utils::{zero_mem_region, MemoryRegion}; +use alloc::vec::Vec; +use bootlib::kernel_launch::KernelLaunchInfo; +use zerocopy::{FromBytes, Immutable, KnownLayout}; + +use core::fmt; +use core::mem::{size_of, size_of_val}; +use core::str::FromStr; + +#[derive(Clone, Debug, Default)] +pub struct SevFWMetaData { + pub cpuid_page: Option, + pub secrets_page: Option, + pub caa_page: Option, + pub valid_mem: Vec>, +} + +impl SevFWMetaData { + pub const fn new() -> Self { + Self { + cpuid_page: None, + secrets_page: None, + caa_page: None, + valid_mem: Vec::new(), + } + } + + pub fn add_valid_mem(&mut self, base: PhysAddr, len: usize) { + self.valid_mem.push(MemoryRegion::new(base, len)); + } +} + +fn from_hex(c: char) -> Result { + match c.to_digit(16) { + Some(d) => Ok(d as u8), + None => Err(SvsmError::Firmware), + } +} + +#[derive(Copy, Clone, Debug)] +struct Uuid { + data: [u8; 16], +} + +impl Uuid { + pub const fn new() -> Self { + Uuid { data: [0; 16] } + } +} + +impl TryFrom<&[u8]> for Uuid { + type Error = (); + fn try_from(mem: &[u8]) -> Result { + let arr: &[u8; 16] = mem.try_into().map_err(|_| ())?; + Ok(Self::from(arr)) + } +} + +impl From<&[u8; 16]> for Uuid { + fn from(mem: &[u8; 16]) -> Self { + Self { + data: [ + mem[3], mem[2], mem[1], mem[0], mem[5], mem[4], mem[7], mem[6], mem[8], mem[9], + mem[10], mem[11], mem[12], mem[13], mem[14], mem[15], + ], + } + } +} + +impl FromStr for Uuid { + type Err = SvsmError; + fn from_str(s: &str) -> Result { + let mut uuid = Uuid::new(); + let mut buf: u8 = 0; + let mut index = 0; + + for c in s.chars() { + if !c.is_ascii_hexdigit() { + continue; + } + + if (index % 2) == 0 { + buf = from_hex(c)? 
<< 4; + } else { + buf |= from_hex(c)?; + let i = index / 2; + if i >= 16 { + break; + } + uuid.data[i] = buf; + } + + index += 1; + } + + Ok(uuid) + } +} + +impl PartialEq for Uuid { + fn eq(&self, other: &Self) -> bool { + self.data.iter().zip(&other.data).all(|(a, b)| a == b) + } +} + +impl fmt::Display for Uuid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for i in 0..16 { + write!(f, "{:02x}", self.data[i])?; + if i == 3 || i == 5 || i == 7 || i == 9 { + write!(f, "-")?; + } + } + Ok(()) + } +} + +const OVMF_TABLE_FOOTER_GUID: &str = "96b582de-1fb2-45f7-baea-a366c55a082d"; +const OVMF_SEV_META_DATA_GUID: &str = "dc886566-984a-4798-a75e-5585a7bf67cc"; +const SVSM_INFO_GUID: &str = "a789a612-0597-4c4b-a49f-cbb1fe9d1ddd"; + +#[derive(Clone, Copy, Debug, FromBytes, KnownLayout, Immutable)] +#[repr(C, packed)] +struct SevMetaDataHeader { + signature: [u8; 4], + len: u32, + version: u32, + num_desc: u32, +} + +#[derive(Clone, Copy, Debug, FromBytes, KnownLayout, Immutable)] +#[repr(C, packed)] +struct SevMetaDataDesc { + base: u32, + len: u32, + t: u32, +} + +const SEV_META_DESC_TYPE_MEM: u32 = 1; +const SEV_META_DESC_TYPE_SECRETS: u32 = 2; +const SEV_META_DESC_TYPE_CPUID: u32 = 3; +const SEV_META_DESC_TYPE_CAA: u32 = 4; + +#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)] +#[repr(C, packed)] +struct RawMetaHeader { + len: u16, + uuid: [u8; size_of::()], +} + +impl RawMetaHeader { + fn data_len(&self) -> Option { + let full_len = self.len as usize; + full_len.checked_sub(size_of::()) + } +} + +#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)] +#[repr(C, packed)] +struct RawMetaBuffer { + data: [u8; PAGE_SIZE - size_of::() - 32], + header: RawMetaHeader, + _pad: [u8; 32], +} + +// Compile-time size checks +const _: () = assert!(size_of::() == PAGE_SIZE); +const _: () = assert!(size_of::() == size_of::() + size_of::()); + +/// Find a table with the given UUID in the given memory slice, and return a +/// subslice into its data +fn find_table<'a>(uuid: &Uuid, mem: &'a [u8]) -> Option<&'a [u8]> { + let mut idx = mem.len(); + + while idx != 0 { + let hdr_start = idx.checked_sub(size_of::())?; + let hdr = RawMetaHeader::ref_from_bytes(&mem[hdr_start..idx]).unwrap(); + + let data_len = hdr.data_len()?; + idx = hdr_start.checked_sub(data_len)?; + + let raw_uuid = hdr.uuid; + let curr_uuid = Uuid::from(&raw_uuid); + if *uuid == curr_uuid { + return Some(&mem[idx..idx + data_len]); + } + } + + None +} + +/// Parse the firmware metadata from the given slice. 
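+///
+/// `mem` must be exactly one page long (see the compile-time size checks on
+/// [`RawMetaBuffer`]); the page ends with a [`RawMetaHeader`] carrying the
+/// OVMF table footer GUID, followed by 32 bytes of padding.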
+pub fn parse_fw_meta_data(mem: &[u8]) -> Result { + let mut meta_data = SevFWMetaData::new(); + + let raw_meta = RawMetaBuffer::ref_from_bytes(mem).map_err(|_| SvsmError::Firmware)?; + + // Check the UUID + let raw_uuid = raw_meta.header.uuid; + let uuid = Uuid::from(&raw_uuid); + let meta_uuid = Uuid::from_str(OVMF_TABLE_FOOTER_GUID)?; + if uuid != meta_uuid { + return Err(SvsmError::Firmware); + } + + // Get the tables and their length + let data_len = raw_meta.header.data_len().ok_or(SvsmError::Firmware)?; + let data_start = size_of_val(&raw_meta.data) + .checked_sub(data_len) + .ok_or(SvsmError::Firmware)?; + let raw_data = raw_meta.data.get(data_start..).ok_or(SvsmError::Firmware)?; + + // First check if this is the SVSM itself instead of OVMF + let svsm_info_uuid = Uuid::from_str(SVSM_INFO_GUID)?; + if find_table(&svsm_info_uuid, raw_data).is_some() { + return Err(SvsmError::Firmware); + } + + // Search and parse SEV metadata + parse_sev_meta(&mut meta_data, raw_meta, raw_data)?; + + // Verify that the required elements are present. + if meta_data.cpuid_page.is_none() { + log::error!("FW does not specify CPUID_PAGE location"); + return Err(SvsmError::Firmware); + } + + Ok(meta_data) +} + +fn parse_sev_meta( + meta: &mut SevFWMetaData, + raw_meta: &RawMetaBuffer, + raw_data: &[u8], +) -> Result<(), SvsmError> { + // Find SEV metadata table + let sev_meta_uuid = Uuid::from_str(OVMF_SEV_META_DATA_GUID)?; + let Some(tbl) = find_table(&sev_meta_uuid, raw_data) else { + log::warn!("Could not find SEV metadata in firmware"); + return Ok(()); + }; + + // Find the location of the metadata header. We need to adjust the offset + // since it is computed by taking into account the trailing header and + // padding, and it is computed backwards. + let bytes: [u8; 4] = tbl.try_into().map_err(|_| SvsmError::Firmware)?; + let sev_meta_offset = (u32::from_le_bytes(bytes) as usize) + .checked_sub(size_of_val(&raw_meta.header) + size_of_val(&raw_meta._pad)) + .ok_or(SvsmError::Firmware)?; + // Now compute the start and end of the SEV metadata header + let sev_meta_start = size_of_val(&raw_meta.data) + .checked_sub(sev_meta_offset) + .ok_or(SvsmError::Firmware)?; + let sev_meta_end = sev_meta_start + size_of::(); + // Bounds check the header and get a pointer to it + let bytes = raw_meta + .data + .get(sev_meta_start..sev_meta_end) + .ok_or(SvsmError::Firmware)?; + let sev_meta_hdr = SevMetaDataHeader::ref_from_bytes(bytes).map_err(|_| SvsmError::Firmware)?; + + // Now find the descriptors + let bytes = &raw_meta.data[sev_meta_end..]; + let num_desc = sev_meta_hdr.num_desc as usize; + let (descs, _) = <[SevMetaDataDesc]>::ref_from_prefix_with_elems(bytes, num_desc) + .map_err(|_| SvsmError::Firmware)?; + + for desc in descs { + let t = desc.t; + let base = PhysAddr::from(desc.base as usize); + let len = desc.len as usize; + + match t { + SEV_META_DESC_TYPE_MEM => meta.add_valid_mem(base, len), + SEV_META_DESC_TYPE_SECRETS => { + if len != PAGE_SIZE { + return Err(SvsmError::Firmware); + } + meta.secrets_page = Some(base); + } + SEV_META_DESC_TYPE_CPUID => { + if len != PAGE_SIZE { + return Err(SvsmError::Firmware); + } + meta.cpuid_page = Some(base); + } + SEV_META_DESC_TYPE_CAA => { + if len != PAGE_SIZE { + return Err(SvsmError::Firmware); + } + meta.caa_page = Some(base); + } + _ => log::info!("Unknown metadata item type: {}", t), + } + } + + Ok(()) +} + +fn validate_fw_mem_region( + config: &SvsmConfig<'_>, + region: MemoryRegion, +) -> Result<(), SvsmError> { + let pstart = region.start(); + let 
pend = region.end(); + + log::info!("Validating {:#018x}-{:#018x}", pstart, pend); + + if config.page_state_change_required() { + current_ghcb() + .page_state_change(region, PageSize::Regular, PageStateChangeOp::Private) + .expect("GHCB PSC call failed to validate firmware memory"); + } + + for paddr in region.iter_pages(PageSize::Regular) { + let guard = PerCPUPageMappingGuard::create_4k(paddr)?; + let vaddr = guard.virt_addr(); + + pvalidate(vaddr, PageSize::Regular, PvalidateOp::Valid)?; + + // Make page accessible to guest VMPL + rmp_adjust( + vaddr, + RMPFlags::GUEST_VMPL | RMPFlags::RWX, + PageSize::Regular, + )?; + + zero_mem_region(vaddr, vaddr + PAGE_SIZE); + } + + Ok(()) +} + +fn validate_fw_memory_vec( + config: &SvsmConfig<'_>, + regions: Vec>, +) -> Result<(), SvsmError> { + if regions.is_empty() { + return Ok(()); + } + + let mut next_vec = Vec::new(); + let mut region = regions[0]; + + for next in regions.into_iter().skip(1) { + if region.contiguous(&next) { + region = region.merge(&next); + } else { + next_vec.push(next); + } + } + + validate_fw_mem_region(config, region)?; + validate_fw_memory_vec(config, next_vec) +} + +pub fn validate_fw_memory( + config: &SvsmConfig<'_>, + fw_meta: &SevFWMetaData, + launch_info: &KernelLaunchInfo, +) -> Result<(), SvsmError> { + // Initalize vector with regions from the FW + let mut regions = fw_meta.valid_mem.clone(); + + // Add region for CPUID page if present + if let Some(cpuid_paddr) = fw_meta.cpuid_page { + regions.push(MemoryRegion::new(cpuid_paddr, PAGE_SIZE)); + } + + // Add region for Secrets page if present + if let Some(secrets_paddr) = fw_meta.secrets_page { + regions.push(MemoryRegion::new(secrets_paddr, PAGE_SIZE)); + } + + // Add region for CAA page if present + if let Some(caa_paddr) = fw_meta.caa_page { + regions.push(MemoryRegion::new(caa_paddr, PAGE_SIZE)); + } + + // Sort regions by base address + regions.sort_unstable_by_key(|a| a.start()); + + let kernel_region = new_kernel_region(launch_info); + for region in regions.iter() { + if region.overlap(&kernel_region) { + log::error!("FwMeta region ovelaps with kernel"); + return Err(SvsmError::Firmware); + } + } + + validate_fw_memory_vec(config, regions) +} + +pub fn print_fw_meta(fw_meta: &SevFWMetaData) { + log::info!("FW Meta Data"); + + match fw_meta.cpuid_page { + Some(addr) => log::info!(" CPUID Page : {:#010x}", addr), + None => log::info!(" CPUID Page : None"), + }; + + match fw_meta.secrets_page { + Some(addr) => log::info!(" Secrets Page : {:#010x}", addr), + None => log::info!(" Secrets Page : None"), + }; + + match fw_meta.caa_page { + Some(addr) => log::info!(" CAA Page : {:#010x}", addr), + None => log::info!(" CAA Page : None"), + }; + + for region in &fw_meta.valid_mem { + log::info!( + " Pre-Validated Region {:#018x}-{:#018x}", + region.start(), + region.end() + ); + } +} diff --git a/stage2/src/greq/driver.rs b/stage2/src/greq/driver.rs new file mode 100644 index 000000000..e58d5638e --- /dev/null +++ b/stage2/src/greq/driver.rs @@ -0,0 +1,365 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! Driver to send `SNP_GUEST_REQUEST` commands to the PSP. It can be any of the +//! request or response command types defined in the SEV-SNP spec, regardless if it's +//! a regular or an extended command. 
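+//!
+//! The driver is created once at boot via [`guest_request_driver_init()`] and
+//! then used through [`send_regular_guest_request()`] and
+//! [`send_extended_guest_request()`], which serialize access with a global
+//! [`SpinLock`] and fail with an invalid-request error if the driver was never
+//! initialized.
+//!
+//! A rough usage sketch (illustrative only: the message type, buffer size and
+//! import paths are placeholders; the real callers live in `greq::services`):
+//!
+//! ```ignore
+//! use crate::greq::driver::{guest_request_driver_init, send_regular_guest_request};
+//! use crate::greq::msg::SnpGuestRequestMsgType;
+//! use crate::protocols::errors::SvsmReqError;
+//!
+//! fn send_one(msg_type: SnpGuestRequestMsgType, req_len: usize) -> Result<usize, SvsmReqError> {
+//!     guest_request_driver_init();
+//!     let mut buf = [0u8; 4096];
+//!     // ... build the command payload into buf[..req_len] here ...
+//!     // The decrypted response is written back into the same buffer.
+//!     send_regular_guest_request(msg_type, &mut buf, req_len)
+//! }
+//! ```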
+ +extern crate alloc; + +use alloc::boxed::Box; +use core::{cell::OnceCell, mem::size_of}; +use zerocopy::FromZeros; + +use crate::mm::alloc::AllocError; +use crate::mm::page_visibility::SharedBox; +use crate::{ + cpu::percpu::current_ghcb, + error::SvsmError, + greq::msg::{SnpGuestRequestExtData, SnpGuestRequestMsg, SnpGuestRequestMsgType}, + locking::SpinLock, + protocols::errors::{SvsmReqError, SvsmResultCode}, + sev::{ghcb::GhcbError, secrets_page, secrets_page_mut, VMPCK_SIZE}, + types::PAGE_SHIFT, + BIT, +}; + +/// Global `SNP_GUEST_REQUEST` driver instance +static GREQ_DRIVER: SpinLock> = SpinLock::new(OnceCell::new()); + +// Hypervisor error codes + +/// Buffer provided is too small +const SNP_GUEST_REQ_INVALID_LEN: u64 = BIT!(32); +/// Hypervisor busy, try again +const SNP_GUEST_REQ_ERR_BUSY: u64 = BIT!(33); + +/// Class of the `SNP_GUEST_REQUEST` command: Regular or Extended +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +enum SnpGuestRequestClass { + Regular = 0, + Extended = 1, +} + +/// `SNP_GUEST_REQUEST` driver +#[derive(Debug)] +struct SnpGuestRequestDriver { + /// Shared page used for the `SNP_GUEST_REQUEST` request + request: SharedBox, + /// Shared page used for the `SNP_GUEST_REQUEST` response + response: SharedBox, + /// Encrypted page where we perform crypto operations + staging: Box, + /// Extended data buffer that will be provided to the hypervisor + /// to store the SEV-SNP certificates + ext_data: SharedBox, + /// Extended data size (`certs` size) provided by the user in [`super::services::get_extended_report`]. + /// It will be provided to the hypervisor. + user_extdata_size: usize, + /// Each `SNP_GUEST_REQUEST` message contains a sequence number per VMPCK. + /// The sequence number is incremented with each message sent. Messages + /// sent by the guest to the PSP and by the PSP to the guest must be + /// delivered in order. If not, the PSP will reject subsequent messages + /// by the guest when it detects that the sequence numbers are out of sync. + /// + /// NOTE: If the vmpl field of a `SNP_GUEST_REQUEST` message is set to VMPL0, + /// then it must contain the VMPL0 sequence number and be protected (encrypted) + /// with the VMPCK0 key; additionally, if this message fails, the VMPCK0 key + /// must be disabled. The same idea applies to the other VMPL levels. + /// + /// The SVSM needs to support only VMPL0 `SNP_GUEST_REQUEST` commands because + /// other layers in the software stack (e.g. OVMF and guest kernel) can send + /// non-VMPL0 commands directly to PSP. Therefore, the SVSM needs to maintain + /// the sequence number and the VMPCK only for VMPL0. + vmpck0_seqno: u64, +} + +impl SnpGuestRequestDriver { + /// Create a new [`SnpGuestRequestDriver`] + pub fn new() -> Result { + let request = SharedBox::try_new_zeroed()?; + let response = SharedBox::try_new_zeroed()?; + let staging = SnpGuestRequestMsg::new_box_zeroed() + .map_err(|_| SvsmError::Alloc(AllocError::OutOfMemory))?; + let ext_data = SharedBox::try_new_zeroed()?; + + Ok(Self { + request, + response, + staging, + ext_data, + user_extdata_size: size_of::(), + vmpck0_seqno: 0, + }) + } + + /// Get the last VMPCK0 sequence number accounted + fn seqno_last_used(&self) -> u64 { + self.vmpck0_seqno + } + + /// Increase the VMPCK0 sequence number by two. In order to keep the + /// sequence number in-sync with the PSP, this is called only when the + /// `SNP_GUEST_REQUEST` response is received. 
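+ ///
+ /// (The request consumes one sequence number and the response the next,
+ /// which is why a successful exchange advances the counter by two.)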
+ fn seqno_add_two(&mut self) { + self.vmpck0_seqno += 2; + } + + /// Set the user_extdata_size to `n` and clear the first `n` bytes from `ext_data` + pub fn set_user_extdata_size(&mut self, n: usize) -> Result<(), SvsmReqError> { + // At least one page + if (n >> PAGE_SHIFT) == 0 { + return Err(SvsmReqError::invalid_parameter()); + } + self.ext_data.nclear(n)?; + self.user_extdata_size = n; + + Ok(()) + } + + /// Call the GHCB layer to send the encrypted SNP_GUEST_REQUEST message + /// to the PSP. + fn send(&mut self, req_class: SnpGuestRequestClass) -> Result<(), SvsmReqError> { + let req_page = self.request.addr(); + let resp_page = self.response.addr(); + let data_pages = self.ext_data.addr(); + let ghcb = current_ghcb(); + + if req_class == SnpGuestRequestClass::Extended { + let num_user_pages = (self.user_extdata_size >> PAGE_SHIFT) as u64; + ghcb.guest_ext_request(req_page, resp_page, data_pages, num_user_pages)?; + } else { + ghcb.guest_request(req_page, resp_page)?; + } + + self.seqno_add_two(); + + Ok(()) + } + + // Encrypt the request message from encrypted memory + fn encrypt_request( + &mut self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + buffer: &[u8], + command_len: usize, + ) -> Result<(), SvsmReqError> { + // VMPL0 `SNP_GUEST_REQUEST` commands are encrypted with the VMPCK0 key + let vmpck0: [u8; VMPCK_SIZE] = secrets_page().get_vmpck(0); + + let inbuf = buffer + .get(..command_len) + .ok_or_else(SvsmReqError::invalid_parameter)?; + + // For security reasons, encrypt the message in protected memory (staging) + // and then copy the result to shared memory (request) + self.staging + .encrypt_set(msg_type, msg_seqno, &vmpck0, inbuf)?; + self.request.write_from(&self.staging); + Ok(()) + } + + // Decrypt the response message from encrypted memory + fn decrypt_response( + &mut self, + msg_seqno: u64, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + ) -> Result { + let vmpck0: [u8; VMPCK_SIZE] = secrets_page().get_vmpck(0); + + // For security reasons, decrypt the message in protected memory (staging) + self.response.read_into(&mut self.staging); + let result = self + .staging + .decrypt_get(msg_type, msg_seqno, &vmpck0, buffer); + + if let Err(e) = result { + match e { + // The buffer provided is too small to store the unwrapped response. + // There is no need to clear the VMPCK0, just report it as invalid parameter. + SvsmReqError::RequestError(SvsmResultCode::INVALID_PARAMETER) => (), + _ => secrets_page_mut().clear_vmpck(0), + } + } + + result + } + + /// Send the provided VMPL0 `SNP_GUEST_REQUEST` command to the PSP. + /// + /// The command will be encrypted using AES-256 GCM. + /// + /// # Arguments + /// + /// * `req_class`: whether this is a regular or extended `SNP_GUEST_REQUEST` command + /// * `msg_type`: type of the command stored in `buffer`, e.g. SNP_MSG_REPORT_REQ + /// * `buffer`: buffer with the `SNP_GUEST_REQUEST` command to be sent. + /// The same buffer will also be used to store the response. 
+ /// * `command_len`: Size (in bytes) of the command stored in `buffer` + /// + /// # Returns + /// + /// * Success: + /// * `usize`: Size (in bytes) of the response stored in `buffer` + /// * Error: + /// * [`SvsmReqError`] + fn send_request( + &mut self, + req_class: SnpGuestRequestClass, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + command_len: usize, + ) -> Result { + if secrets_page().is_vmpck_clear(0) { + return Err(SvsmReqError::invalid_request()); + } + + // Message sequence number overflow, the driver will not able + // to send subsequent `SNP_GUEST_REQUEST` messages to the PSP. + // The sequence number is restored only when the guest is rebooted. + let Some(msg_seqno) = self.seqno_last_used().checked_add(1) else { + log::error!("SNP_GUEST_REQUEST: sequence number overflow"); + secrets_page_mut().clear_vmpck(0); + return Err(SvsmReqError::invalid_request()); + }; + + self.encrypt_request(msg_type, msg_seqno, buffer, command_len)?; + + if let Err(e) = self.send(req_class) { + if let SvsmReqError::FatalError(SvsmError::Ghcb(GhcbError::VmgexitError(_rbx, info2))) = + e + { + // For some reason the hypervisor did not forward the request to the PSP. + // + // Because the message sequence number is used as part of the AES-GCM IV, it is important that the + // guest retry the request before allowing another request to be performed so that the IV cannot be + // reused on a new message payload. + match info2 & 0xffff_ffff_0000_0000u64 { + // The certificate buffer provided is too small. + SNP_GUEST_REQ_INVALID_LEN => { + if req_class == SnpGuestRequestClass::Extended { + if let Err(e1) = self.send(SnpGuestRequestClass::Regular) { + log::error!( + "SNP_GUEST_REQ_INVALID_LEN. Aborting, request resend failed" + ); + secrets_page_mut().clear_vmpck(0); + return Err(e1); + } + return Err(e); + } else { + // We sent a regular SNP_GUEST_REQUEST, but the hypervisor returned + // an error code that is exclusive for extended SNP_GUEST_REQUEST + secrets_page_mut().clear_vmpck(0); + return Err(SvsmReqError::invalid_request()); + } + } + // The hypervisor is busy. + SNP_GUEST_REQ_ERR_BUSY => { + if let Err(e2) = self.send(req_class) { + log::error!("SNP_GUEST_REQ_ERR_BUSY. Aborting, request resend failed"); + secrets_page_mut().clear_vmpck(0); + return Err(e2); + } + // ... request resend worked, continue normally. + } + // Failed for unknown reason. 
Status codes can be found in + // the AMD SEV-SNP spec or in the linux kernel include/uapi/linux/psp-sev.h + _ => { + log::error!("SNP_GUEST_REQUEST failed, unknown error code={}\n", info2); + secrets_page_mut().clear_vmpck(0); + return Err(e); + } + } + } + } + + let msg_seqno = self.seqno_last_used(); + let resp_msg_type = SnpGuestRequestMsgType::try_from(msg_type as u8 + 1)?; + + self.decrypt_response(msg_seqno, resp_msg_type, buffer) + } + + /// Send the provided regular `SNP_GUEST_REQUEST` command to the PSP + pub fn send_regular_guest_request( + &mut self, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + command_len: usize, + ) -> Result { + self.send_request(SnpGuestRequestClass::Regular, msg_type, buffer, command_len) + } + + /// Send the provided extended `SNP_GUEST_REQUEST` command to the PSP + pub fn send_extended_guest_request( + &mut self, + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + command_len: usize, + certs: &mut [u8], + ) -> Result { + self.set_user_extdata_size(certs.len())?; + + let outbuf_len: usize = self.send_request( + SnpGuestRequestClass::Extended, + msg_type, + buffer, + command_len, + )?; + + // The SEV-SNP certificates can be used to verify the attestation report. + self.ext_data.copy_to_slice(certs)?; + // At this point, a zeroed ext_data buffer indicates that the + // certificates were not imported. The VM owner can import them from the + // host using the virtee/snphost project + if certs[..24] == [0; 24] { + log::warn!("SEV-SNP certificates not found. Make sure they were loaded from the host."); + } + + Ok(outbuf_len) + } +} + +/// Initialize the global `SnpGuestRequestDriver` +/// +/// # Panics +/// +/// This function panics if we fail to initialize any of the `SnpGuestRequestDriver` fields. +pub fn guest_request_driver_init() { + let cell = GREQ_DRIVER.lock(); + let _ = cell.get_or_init(|| { + SnpGuestRequestDriver::new().expect("SnpGuestRequestDriver failed to initialize") + }); +} + +/// Send the provided regular `SNP_GUEST_REQUEST` command to the PSP. +/// Further details can be found in the `SnpGuestRequestDriver.send_request()` documentation. +pub fn send_regular_guest_request( + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + request_len: usize, +) -> Result { + let mut cell = GREQ_DRIVER.lock(); + let driver: &mut SnpGuestRequestDriver = + cell.get_mut().ok_or_else(SvsmReqError::invalid_request)?; + driver.send_regular_guest_request(msg_type, buffer, request_len) +} + +/// Send the provided extended `SNP_GUEST_REQUEST` command to the PSP +/// Further details can be found in the `SnpGuestRequestDriver.send_request()` documentation. +pub fn send_extended_guest_request( + msg_type: SnpGuestRequestMsgType, + buffer: &mut [u8], + request_len: usize, + certs: &mut [u8], +) -> Result { + let mut cell = GREQ_DRIVER.lock(); + let driver: &mut SnpGuestRequestDriver = + cell.get_mut().ok_or_else(SvsmReqError::invalid_request)?; + driver.send_extended_guest_request(msg_type, buffer, request_len, certs) +} diff --git a/stage2/src/greq/mod.rs b/stage2/src/greq/mod.rs new file mode 100644 index 000000000..56aca7547 --- /dev/null +++ b/stage2/src/greq/mod.rs @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! 
`SNP_GUEST_REQUEST` mechanism to communicate with the PSP + +pub mod driver; +pub mod msg; +pub mod pld_report; +pub mod services; diff --git a/stage2/src/greq/msg.rs b/stage2/src/greq/msg.rs new file mode 100644 index 000000000..a82c2c4c1 --- /dev/null +++ b/stage2/src/greq/msg.rs @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! Message that carries an encrypted `SNP_GUEST_REQUEST` command in the payload + +use core::mem::{offset_of, size_of}; + +use crate::{ + crypto::aead::{Aes256Gcm, Aes256GcmTrait, AUTHTAG_SIZE, IV_SIZE}, + protocols::errors::SvsmReqError, + sev::secrets_page::VMPCK_SIZE, + types::PAGE_SIZE, +}; + +use zerocopy::{FromBytes, Immutable, IntoBytes}; + +/// Version of the message header +const HDR_VERSION: u8 = 1; +/// Version of the message payload +const MSG_VERSION: u8 = 1; + +/// AEAD Algorithm Encodings (AMD SEV-SNP spec. table 99) +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum SnpGuestRequestAead { + Invalid = 0, + Aes256Gcm = 1, +} + +/// Message Type Encodings (AMD SEV-SNP spec. table 100) +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum SnpGuestRequestMsgType { + Invalid = 0, + ReportRequest = 5, + ReportResponse = 6, +} + +impl TryFrom for SnpGuestRequestMsgType { + type Error = SvsmReqError; + + fn try_from(v: u8) -> Result { + match v { + x if x == Self::Invalid as u8 => Ok(Self::Invalid), + x if x == Self::ReportRequest as u8 => Ok(Self::ReportRequest), + x if x == Self::ReportResponse as u8 => Ok(Self::ReportResponse), + _ => Err(SvsmReqError::invalid_parameter()), + } + } +} + +/// Message header size +const MSG_HDR_SIZE: usize = size_of::(); +/// Message payload size +const MSG_PAYLOAD_SIZE: usize = PAGE_SIZE - MSG_HDR_SIZE; + +/// Maximum buffer size that the hypervisor takes to store the +/// SEV-SNP certificates +pub const SNP_GUEST_REQ_MAX_DATA_SIZE: usize = 4 * PAGE_SIZE; + +/// `SNP_GUEST_REQUEST` message header format (AMD SEV-SNP spec. table 98) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, FromBytes, IntoBytes, Immutable)] +pub struct SnpGuestRequestMsgHdr { + /// Message authentication tag + authtag: [u8; 32], + /// The sequence number for this message + msg_seqno: u64, + /// Reserve. Must be zero. + rsvd1: [u8; 8], + /// The AEAD used to encrypt this message + algo: u8, + /// The version of the message header + hdr_version: u8, + /// The size of the message header in bytes + hdr_sz: u16, + /// The type of the payload + msg_type: u8, + /// The version of the payload + msg_version: u8, + /// The size of the payload in bytes + msg_sz: u16, + /// Reserved. Must be zero. + rsvd2: u32, + /// The ID of the VMPCK used to protect this message + msg_vmpck: u8, + /// Reserved. Must be zero. + rsvd3: [u8; 35], +} + +const _: () = assert!(size_of::() <= u16::MAX as usize); + +impl SnpGuestRequestMsgHdr { + /// Allocate a new [`SnpGuestRequestMsgHdr`] and initialize it + pub fn new(msg_sz: u16, msg_type: SnpGuestRequestMsgType, msg_seqno: u64) -> Self { + Self { + msg_seqno, + algo: SnpGuestRequestAead::Aes256Gcm as u8, + hdr_version: HDR_VERSION, + hdr_sz: MSG_HDR_SIZE as u16, + msg_type: msg_type as u8, + msg_version: MSG_VERSION, + msg_sz, + msg_vmpck: 0, + ..Default::default() + } + } + + /// Set the authenticated tag + fn set_authtag(&mut self, new_tag: &[u8]) -> Result<(), SvsmReqError> { + self.authtag + .get_mut(..new_tag.len()) + .ok_or_else(SvsmReqError::invalid_parameter)? 
+ .copy_from_slice(new_tag); + Ok(()) + } + + /// Validate the [`SnpGuestRequestMsgHdr`] fields + fn validate( + &self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + ) -> Result<(), SvsmReqError> { + if self.hdr_version != HDR_VERSION + || self.hdr_sz != MSG_HDR_SIZE as u16 + || self.algo != SnpGuestRequestAead::Aes256Gcm as u8 + || self.msg_type != msg_type as u8 + || self.msg_vmpck != 0 + || self.msg_seqno != msg_seqno + { + return Err(SvsmReqError::invalid_format()); + } + Ok(()) + } + + /// Get a slice of the header fields used as additional authenticated data (AAD) + fn get_aad_slice(&self) -> &[u8] { + let algo_offset = offset_of!(Self, algo); + &self.as_bytes()[algo_offset..] + } +} + +impl Default for SnpGuestRequestMsgHdr { + /// default() method implementation. We can't derive Default because + /// the field "rsvd3: [u8; 35]" conflicts with the Default trait, which + /// supports up to [T; 32]. + fn default() -> Self { + Self { + authtag: [0; 32], + msg_seqno: 0, + rsvd1: [0; 8], + algo: 0, + hdr_version: 0, + hdr_sz: 0, + msg_type: 0, + msg_version: 0, + msg_sz: 0, + rsvd2: 0, + msg_vmpck: 0, + rsvd3: [0; 35], + } + } +} + +/// `SNP_GUEST_REQUEST` message format +#[repr(C, align(4096))] +#[derive(Clone, Copy, Debug, FromBytes)] +pub struct SnpGuestRequestMsg { + hdr: SnpGuestRequestMsgHdr, + pld: [u8; MSG_PAYLOAD_SIZE], +} + +// The GHCB spec says it has to fit in one page and be page aligned +const _: () = assert!(size_of::() <= PAGE_SIZE); + +impl SnpGuestRequestMsg { + /// Encrypt the provided `SNP_GUEST_REQUEST` command and store the result in the actual message payload + /// + /// The command will be encrypted using AES-256 GCM and part of the message header will be + /// used as additional authenticated data (AAD). + /// + /// # Arguments + /// + /// * `msg_type`: Type of the command stored in the `command` buffer. + /// * `msg_seqno`: VMPL0 sequence number to be used in the message. The PSP will reject + /// subsequent messages when it detects that the sequence numbers are + /// out of sync. The sequence number is also used as initialization + /// vector (IV) in encryption. + /// * `vmpck0`: VMPCK0 key that will be used to encrypt the command. + /// * `command`: command slice to be encrypted. + /// + /// # Returns + /// + /// () on success and [`SvsmReqError`] on error. 
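+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (marked `ignore`, so rustdoc does not compile it); the
+    /// all-5s key and sequence number 1 are placeholder values borrowed from
+    /// the `encrypt_decrypt_payload` unit test below:
+    ///
+    /// ```ignore
+    /// let mut msg = SnpGuestRequestMsg::new_box_zeroed().unwrap();
+    /// let vmpck0 = [5u8; VMPCK_SIZE];
+    /// msg.encrypt_set(SnpGuestRequestMsgType::ReportRequest, 1, &vmpck0, b"request")?;
+    /// // The payload now holds the AES-256 GCM ciphertext and the header holds
+    /// // the authtag, ready to be copied to the shared request page.
+    /// ```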
+ /// + /// # Panic + /// + /// * The command length does not fit in a u16 + /// * The encrypted and the original command don't have the same size + pub fn encrypt_set( + &mut self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + vmpck0: &[u8; VMPCK_SIZE], + command: &[u8], + ) -> Result<(), SvsmReqError> { + let payload_size_u16 = + u16::try_from(command.len()).map_err(|_| SvsmReqError::invalid_parameter())?; + + let mut msg_hdr = SnpGuestRequestMsgHdr::new(payload_size_u16, msg_type, msg_seqno); + let aad: &[u8] = msg_hdr.get_aad_slice(); + let iv: [u8; IV_SIZE] = build_iv(msg_seqno); + + self.pld.fill(0); + + // Encrypt the provided command and store the result in the message payload + let authtag_end: usize = Aes256Gcm::encrypt(&iv, vmpck0, aad, command, &mut self.pld)?; + + // In the Aes256Gcm encrypt API, the authtag is postfixed (comes after the encrypted payload) + let ciphertext_end: usize = authtag_end - AUTHTAG_SIZE; + let authtag = self + .pld + .get_mut(ciphertext_end..authtag_end) + .ok_or_else(SvsmReqError::invalid_request)?; + + // The command should have the same size when encrypted and decrypted + assert_eq!(command.len(), ciphertext_end); + + // Move the authtag to the message header + msg_hdr.set_authtag(authtag)?; + authtag.fill(0); + + self.hdr = msg_hdr; + + Ok(()) + } + + /// Decrypt the `SNP_GUEST_REQUEST` command stored in the message and store the decrypted command in + /// the provided `outbuf`. + /// + /// The command stored in the message payload is usually a response command received from the PSP. + /// It will be decrypted using AES-256 GCM and part of the message header will be used as + /// additional authenticated data (AAD). + /// + /// # Arguments + /// + /// * `msg_type`: Type of the command stored in the message payload + /// * `msg_seqno`: VMPL0 sequence number that was used in the message. + /// * `vmpck0`: VMPCK0 key, it will be used to decrypt the message + /// * `outbuf`: buffer that will be used to store the decrypted message payload + /// + /// # Returns + /// + /// * Success + /// * usize: Number of bytes written to `outbuf` + /// * Error + /// * [`SvsmReqError`] + pub fn decrypt_get( + &mut self, + msg_type: SnpGuestRequestMsgType, + msg_seqno: u64, + vmpck0: &[u8; VMPCK_SIZE], + outbuf: &mut [u8], + ) -> Result { + self.hdr.validate(msg_type, msg_seqno)?; + + let iv: [u8; IV_SIZE] = build_iv(msg_seqno); + let aad: &[u8] = self.hdr.get_aad_slice(); + + // In the Aes256Gcm decrypt API, the authtag must be provided postfix in the inbuf + let ciphertext_end = usize::from(self.hdr.msg_sz); + let tag_end: usize = ciphertext_end + AUTHTAG_SIZE; + + // The message payload must be large enough to hold the ciphertext and + // the authentication tag. 
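+        // After the copy below, the slice handed to Aes256Gcm::decrypt() is
+        // laid out as:
+        //
+        //   pld[..ciphertext_end]          ciphertext received from the PSP
+        //   pld[ciphertext_end..tag_end]   authtag copied from hdr.authtag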
+ let hdr_tag = self + .hdr + .authtag + .get(..AUTHTAG_SIZE) + .ok_or_else(SvsmReqError::invalid_request)?; + let pld_tag = self + .pld + .get_mut(ciphertext_end..tag_end) + .ok_or_else(SvsmReqError::invalid_request)?; + pld_tag.copy_from_slice(hdr_tag); + + // Payload with postfixed authtag + let inbuf = self + .pld + .get(..tag_end) + .ok_or_else(SvsmReqError::invalid_request)?; + + let outbuf_len: usize = Aes256Gcm::decrypt(&iv, vmpck0, aad, inbuf, outbuf)?; + + Ok(outbuf_len) + } +} + +/// Build the initialization vector for AES-256 GCM +fn build_iv(msg_seqno: u64) -> [u8; IV_SIZE] { + const U64_SIZE: usize = size_of::(); + let mut iv = [0u8; IV_SIZE]; + + iv[..U64_SIZE].copy_from_slice(&msg_seqno.to_ne_bytes()); + iv +} + +/// Data page(s) the hypervisor will use to store certificate data in +/// an extended `SNP_GUEST_REQUEST` +pub type SnpGuestRequestExtData = [u8; SNP_GUEST_REQ_MAX_DATA_SIZE]; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_snp_guest_request_hdr_offsets() { + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, authtag), 0); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, msg_seqno), 0x20); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, rsvd1), 0x28); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, algo), 0x30); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, hdr_version), 0x31); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, hdr_sz), 0x32); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, msg_type), 0x34); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, msg_version), 0x35); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, msg_sz), 0x36); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, rsvd2), 0x38); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, msg_vmpck), 0x3c); + assert_eq!(offset_of!(SnpGuestRequestMsgHdr, rsvd3), 0x3d); + } + + #[test] + fn test_snp_guest_request_msg_offsets() { + assert_eq!(offset_of!(SnpGuestRequestMsg, hdr), 0); + assert_eq!(offset_of!(SnpGuestRequestMsg, pld), 0x60); + } + + #[test] + fn aad_size() { + let hdr = SnpGuestRequestMsgHdr::default(); + let aad = hdr.get_aad_slice(); + + const HDR_ALGO_OFFSET: usize = 48; + + assert_eq!(aad.len(), MSG_HDR_SIZE - HDR_ALGO_OFFSET); + } + + #[test] + fn encrypt_decrypt_payload() { + let mut msg = SnpGuestRequestMsg { + hdr: SnpGuestRequestMsgHdr::default(), + pld: [0; MSG_PAYLOAD_SIZE], + }; + + const PLAINTEXT: &[u8] = b"request-to-be-encrypted"; + let vmpck0 = [5u8; VMPCK_SIZE]; + let vmpck0_seqno: u64 = 1; + + msg.encrypt_set( + SnpGuestRequestMsgType::ReportRequest, + vmpck0_seqno, + &vmpck0, + PLAINTEXT, + ) + .unwrap(); + + let mut outbuf = [0u8; PLAINTEXT.len()]; + + let outbuf_len = msg + .decrypt_get( + SnpGuestRequestMsgType::ReportRequest, + vmpck0_seqno, + &vmpck0, + &mut outbuf, + ) + .unwrap(); + + assert_eq!(outbuf_len, PLAINTEXT.len()); + + assert_eq!(outbuf, PLAINTEXT); + } +} diff --git a/stage2/src/greq/pld_report.rs b/stage2/src/greq/pld_report.rs new file mode 100644 index 000000000..8599918fd --- /dev/null +++ b/stage2/src/greq/pld_report.rs @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! `SNP_GUEST_REQUEST` command to request an attestation report. + +use core::mem::size_of; + +use zerocopy::{FromBytes, Immutable, KnownLayout}; + +use crate::protocols::errors::SvsmReqError; + +/// Size of the `SnpReportRequest.user_data` +pub const USER_DATA_SIZE: usize = 64; + +/// MSG_REPORT_REQ payload format (AMD SEV-SNP spec. 
table 20) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, FromBytes, KnownLayout, Immutable)] +pub struct SnpReportRequest { + /// Guest-provided data to be included in the attestation report + /// REPORT_DATA (512 bits) + user_data: [u8; USER_DATA_SIZE], + /// The VMPL to put in the attestation report + vmpl: u32, + /// 31:2 - Reserved + /// 1:0 - KEY_SEL. Selects which key to use for derivation + /// 0: If VLEK is installed, sign with VLEK. Otherwise, sign with VCEK + /// 1: Sign with VCEK + /// 2: Sign with VLEK + /// 3: Reserved + flags: u32, + /// Reserved, must be zero + rsvd: [u8; 24], +} + +impl SnpReportRequest { + /// Take a slice and return a reference for Self + pub fn try_from_as_ref(buffer: &[u8]) -> Result<&Self, SvsmReqError> { + Self::ref_from_prefix(buffer) + .ok() + .map(|(request, _rest)| request) + .filter(|request| request.is_reserved_clear()) + .ok_or_else(SvsmReqError::invalid_parameter) + } + + pub fn is_vmpl0(&self) -> bool { + self.vmpl == 0 + } + + /// Check if the reserved field is clear + fn is_reserved_clear(&self) -> bool { + self.rsvd.into_iter().all(|e| e == 0) + } +} + +/// MSG_REPORT_RSP payload format (AMD SEV-SNP spec. table 23) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, FromBytes, KnownLayout, Immutable)] +pub struct SnpReportResponse { + /// The status of the key derivation operation, see [SnpReportResponseStatus] + status: u32, + /// Size in bytes of the report + report_size: u32, + /// Reserved + _reserved: [u8; 24], + /// The attestation report generated by firmware + report: AttestationReport, +} + +/// Supported values for SnpReportResponse.status +#[repr(u32)] +#[derive(Clone, Copy, Debug)] +pub enum SnpReportResponseStatus { + Success = 0, + InvalidParameters = 0x16, + InvalidKeySelection = 0x27, +} + +impl SnpReportResponse { + /// Validate the [SnpReportResponse] fields + pub fn validate(&self) -> Result<(), SvsmReqError> { + if self.status != SnpReportResponseStatus::Success as u32 { + return Err(SvsmReqError::invalid_request()); + } + + if self.report_size != size_of::() as u32 { + return Err(SvsmReqError::invalid_format()); + } + + Ok(()) + } + + pub fn measurement(&self) -> &[u8; 48] { + &self.report.measurement + } +} + +/// The `TCB_VERSION` contains the security version numbers of each +/// component in the trusted computing base (TCB) of the SNP firmware. +/// (AMD SEV-SNP spec. table 3) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, FromBytes, Immutable)] +struct TcbVersion { + /// Version of the Microcode, SNP firmware, PSP and boot loader + raw: u64, +} + +/// Format for an ECDSA P-384 with SHA-384 signature (AMD SEV-SNP spec. table 115) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, FromBytes, KnownLayout, Immutable)] +struct Signature { + /// R component of this signature + r: [u8; 72], + /// S component of this signature + s: [u8; 72], + /// Reserved + reserved: [u8; 368], +} + +/// ATTESTATION_REPORT format (AMD SEV-SNP spec. 
table 21) +#[repr(C, packed)] +#[derive(Clone, Copy, Debug, FromBytes, KnownLayout, Immutable)] +pub struct AttestationReport { + /// Version number of this attestation report + version: u32, + /// The guest SVN + guest_svn: u32, + /// The guest policy + policy: u64, + /// The family ID provided at launch + family_id: [u8; 16], + /// The image ID provided at launch + image_id: [u8; 16], + /// The request VMPL for the attestation report + vmpl: u32, + /// The signature algorithm used to sign this report + signature_algo: u32, + /// CurrentTcb + platform_version: TcbVersion, + /// Information about the platform + platform_info: u64, + /// Flags + flags: u32, + /// Reserved, must be zero + reserved0: u32, + /// Guest-provided data + report_data: [u8; 64], + /// The measurement calculated at launch + measurement: [u8; 48], + /// Data provided by the hypervisor at launch + host_data: [u8; 32], + /// SHA-384 digest of the ID public key that signed the ID block + /// provided in `SNP_LAUNCH_FINISH` + id_key_digest: [u8; 48], + /// SHA-384 digest of the Author public key that certified the ID key, + /// if provided in `SNP_LAUNCH_FINISH`. Zeroes if `AUTHOR_KEY_EN` is 1 + author_key_digest: [u8; 48], + /// Report ID of this guest + report_id: [u8; 32], + /// Report ID of this guest's migration agent + report_id_ma: [u8; 32], + /// Report TCB version used to derive the VCEK that signed this report + reported_tcb: TcbVersion, + /// Reserved + reserved1: [u8; 24], + /// If `MaskChipId` is set to 0, Identifier unique to the chip as + /// output by `GET_ID`. Otherwise, set to 0h + chip_id: [u8; 64], + /// Reserved and some more flags + reserved2: [u8; 192], + /// Signature of bytes 0h to 29Fh inclusive of this report + signature: Signature, + /// `zerocopy` needs to expose the type of the last field when + /// generating the implementation for `KnownLayout`, but this + /// causes problems because `AttestationReport` is public and + /// `Signature` is private. Instead add an empty field with a type + /// that's not private. 
+ _empty: (), +} + +const _: () = assert!(size_of::() <= u32::MAX as usize); + +#[cfg(test)] +mod tests { + use super::*; + use core::mem::offset_of; + + #[test] + fn test_snp_report_request_offsets() { + assert_eq!(offset_of!(SnpReportRequest, user_data), 0x0); + assert_eq!(offset_of!(SnpReportRequest, vmpl), 0x40); + assert_eq!(offset_of!(SnpReportRequest, flags), 0x44); + assert_eq!(offset_of!(SnpReportRequest, rsvd), 0x48); + } + + #[test] + fn test_snp_report_response_offsets() { + assert_eq!(offset_of!(SnpReportResponse, status), 0x0); + assert_eq!(offset_of!(SnpReportResponse, report_size), 0x4); + assert_eq!(offset_of!(SnpReportResponse, _reserved), 0x8); + assert_eq!(offset_of!(SnpReportResponse, report), 0x20); + } + + #[test] + fn test_ecdsa_p384_sha384_signature_offsets() { + assert_eq!(offset_of!(Signature, r), 0x0); + assert_eq!(offset_of!(Signature, s), 0x48); + assert_eq!(offset_of!(Signature, reserved), 0x90); + } + + #[test] + fn test_attestation_report_offsets() { + assert_eq!(offset_of!(AttestationReport, version), 0x0); + assert_eq!(offset_of!(AttestationReport, guest_svn), 0x4); + assert_eq!(offset_of!(AttestationReport, policy), 0x8); + assert_eq!(offset_of!(AttestationReport, family_id), 0x10); + assert_eq!(offset_of!(AttestationReport, image_id), 0x20); + assert_eq!(offset_of!(AttestationReport, vmpl), 0x30); + assert_eq!(offset_of!(AttestationReport, signature_algo), 0x34); + assert_eq!(offset_of!(AttestationReport, platform_version), 0x38); + assert_eq!(offset_of!(AttestationReport, platform_info), 0x40); + assert_eq!(offset_of!(AttestationReport, flags), 0x48); + assert_eq!(offset_of!(AttestationReport, reserved0), 0x4c); + assert_eq!(offset_of!(AttestationReport, report_data), 0x50); + assert_eq!(offset_of!(AttestationReport, measurement), 0x90); + assert_eq!(offset_of!(AttestationReport, host_data), 0xc0); + assert_eq!(offset_of!(AttestationReport, id_key_digest), 0xe0); + assert_eq!(offset_of!(AttestationReport, author_key_digest), 0x110); + assert_eq!(offset_of!(AttestationReport, report_id), 0x140); + assert_eq!(offset_of!(AttestationReport, report_id_ma), 0x160); + assert_eq!(offset_of!(AttestationReport, reported_tcb), 0x180); + assert_eq!(offset_of!(AttestationReport, reserved1), 0x188); + assert_eq!(offset_of!(AttestationReport, chip_id), 0x1a0); + assert_eq!(offset_of!(AttestationReport, reserved2), 0x1e0); + assert_eq!(offset_of!(AttestationReport, signature), 0x2a0); + } +} diff --git a/stage2/src/greq/services.rs b/stage2/src/greq/services.rs new file mode 100644 index 000000000..d810fb175 --- /dev/null +++ b/stage2/src/greq/services.rs @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! API to send `SNP_GUEST_REQUEST` commands to the PSP + +use zerocopy::FromBytes; + +use crate::{ + greq::{ + driver::{send_extended_guest_request, send_regular_guest_request}, + msg::SnpGuestRequestMsgType, + pld_report::{SnpReportRequest, SnpReportResponse}, + }, + protocols::errors::SvsmReqError, +}; +use core::mem::size_of; + +const REPORT_REQUEST_SIZE: usize = size_of::(); +const REPORT_RESPONSE_SIZE: usize = size_of::(); + +fn get_report(buffer: &mut [u8], certs: Option<&mut [u8]>) -> Result { + let request: &SnpReportRequest = SnpReportRequest::try_from_as_ref(buffer)?; + // Non-VMPL0 attestation reports can be requested by the guest kernel + // directly to the PSP. 
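+    // The SVSM therefore only serves VMPL0 report requests and rejects
+    // everything else. Note that `buffer` is reused in place: it carries the
+    // MSG_REPORT_REQ on entry and the decrypted MSG_REPORT_RESP on return, so
+    // it must be large enough for the response (REPORT_RESPONSE_SIZE bytes).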
+ if !request.is_vmpl0() { + return Err(SvsmReqError::invalid_parameter()); + } + let response_len = if certs.is_none() { + send_regular_guest_request( + SnpGuestRequestMsgType::ReportRequest, + buffer, + REPORT_REQUEST_SIZE, + )? + } else { + send_extended_guest_request( + SnpGuestRequestMsgType::ReportRequest, + buffer, + REPORT_REQUEST_SIZE, + certs.unwrap(), + )? + }; + if REPORT_RESPONSE_SIZE > response_len { + return Err(SvsmReqError::invalid_request()); + } + let (response, _rest) = SnpReportResponse::ref_from_prefix(buffer) + .map_err(|_| SvsmReqError::invalid_parameter())?; + response.validate()?; + + Ok(response_len) +} + +/// Request a regular VMPL0 attestation report to the PSP. +/// +/// Use the `SNP_GUEST_REQUEST` driver to send the provided `MSG_REPORT_REQ` command to +/// the PSP. The VPML field of the command must be set to zero. +/// +/// The VMPCK0 is disabled for subsequent calls if this function fails in a way that +/// the VM state can be compromised. +/// +/// # Arguments +/// +/// * `buffer`: Buffer with the [`MSG_REPORT_REQ`](SnpReportRequest) command that will be +/// sent to the PSP. It must be large enough to hold the +/// [`MSG_REPORT_RESP`](SnpReportResponse) received from the PSP. +/// +/// # Returns +/// +/// * Success +/// * `usize`: Number of bytes written to `buffer`. It should match the +/// [`MSG_REPORT_RESP`](SnpReportResponse) size. +/// * Error +/// * [`SvsmReqError`] +pub fn get_regular_report(buffer: &mut [u8]) -> Result { + get_report(buffer, None) +} + +/// Request an extended VMPL0 attestation report to the PSP. +/// +/// We say that it is extended because it requests a VMPL0 attestation report +/// to the PSP (as in [`get_regular_report()`]) and also requests to the hypervisor +/// the certificates required to verify the attestation report. +/// +/// The VMPCK0 is disabled for subsequent calls if this function fails in a way that +/// the VM state can be compromised. +/// +/// # Arguments +/// +/// * `buffer`: Buffer with the [`MSG_REPORT_REQ`](SnpReportRequest) command that will be +/// sent to the PSP. It must be large enough to hold the +/// [`MSG_REPORT_RESP`](SnpReportResponse) received from the PSP. +/// * `certs`: Buffer to store the SEV-SNP certificates received from the hypervisor. +/// +/// # Return codes +/// +/// * Success +/// * `usize`: Number of bytes written to `buffer`. It should match +/// the [`MSG_REPORT_RESP`](SnpReportResponse) size. +/// * Error +/// * [`SvsmReqError`] +/// * `SvsmReqError::FatalError(SvsmError::Ghcb(GhcbError::VmgexitError(certs_buffer_size, psp_rc)))`: +/// * `certs` is not large enough to hold the certificates. +/// * `certs_buffer_size`: number of bytes required. 
+/// * `psp_rc`: PSP return code +pub fn get_extended_report(buffer: &mut [u8], certs: &mut [u8]) -> Result { + get_report(buffer, Some(certs)) +} + +#[cfg(test)] +mod tests { + #[allow(unused)] + use super::*; + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + #[cfg(test_in_svsm)] + fn test_snp_launch_measurement() { + extern crate alloc; + + use crate::serial::Terminal; + use crate::testing::{assert_eq_warn, svsm_test_io, IORequest}; + + use alloc::vec; + + let sp = svsm_test_io().unwrap(); + + sp.put_byte(IORequest::GetLaunchMeasurement as u8); + + let mut expected_measurement = [0u8; 48]; + for byte in &mut expected_measurement { + *byte = sp.get_byte(); + } + + let mut buf = vec![0; size_of::()]; + let size = get_regular_report(&mut buf).unwrap(); + assert_eq!(size, buf.len()); + + let (response, _rest) = SnpReportResponse::ref_from_prefix(&buf).unwrap(); + response.validate().unwrap(); + // FIXME: we still have some cases where the precalculated value does + // not match, so for now we just issue a warning until we fix the problem. + assert_eq_warn!(expected_measurement, *response.measurement()); + } +} diff --git a/stage2/src/igvm_params.rs b/stage2/src/igvm_params.rs new file mode 100644 index 000000000..b38376694 --- /dev/null +++ b/stage2/src/igvm_params.rs @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange (jlange@microsoft.com) + +extern crate alloc; + +use crate::acpi::tables::ACPICPUInfo; +use crate::address::{PhysAddr, VirtAddr}; +use crate::cpu::efer::EFERFlags; +use crate::error::SvsmError; +use crate::fw_meta::SevFWMetaData; +use crate::mm::{GuestPtr, PerCPUPageMappingGuard, PAGE_SIZE}; +use crate::platform::{PageStateChangeOp, PageValidateOp, SVSM_PLATFORM}; +use crate::types::PageSize; +use crate::utils::MemoryRegion; +use alloc::vec::Vec; +use cpuarch::vmsa::VMSA; + +use bootlib::igvm_params::{IgvmGuestContext, IgvmParamBlock, IgvmParamPage}; +use core::mem::size_of; +use igvm_defs::{IgvmEnvironmentInfo, MemoryMapEntryType, IGVM_VHS_MEMORY_MAP_ENTRY}; + +const IGVM_MEMORY_ENTRIES_PER_PAGE: usize = PAGE_SIZE / size_of::(); + +const STAGE2_END_ADDR: usize = 0xA0000; + +#[derive(Clone, Debug)] +#[repr(C, align(64))] +pub struct IgvmMemoryMap { + memory_map: [IGVM_VHS_MEMORY_MAP_ENTRY; IGVM_MEMORY_ENTRIES_PER_PAGE], +} + +#[derive(Clone, Debug)] +pub struct IgvmParams<'a> { + igvm_param_block: &'a IgvmParamBlock, + igvm_param_page: &'a IgvmParamPage, + igvm_memory_map: &'a IgvmMemoryMap, + igvm_guest_context: Option<&'a IgvmGuestContext>, +} + +impl IgvmParams<'_> { + pub fn new(addr: VirtAddr) -> Result { + let param_block = Self::try_aligned_ref::(addr)?; + let param_page_address = addr + param_block.param_page_offset as usize; + let param_page = Self::try_aligned_ref::(param_page_address)?; + let memory_map_address = addr + param_block.memory_map_offset as usize; + let memory_map = Self::try_aligned_ref::(memory_map_address)?; + let guest_context = if param_block.guest_context_offset != 0 { + let offset = usize::try_from(param_block.guest_context_offset).unwrap(); + Some(Self::try_aligned_ref::(addr + offset)?) 
+ } else { + None + }; + + Ok(Self { + igvm_param_block: param_block, + igvm_param_page: param_page, + igvm_memory_map: memory_map, + igvm_guest_context: guest_context, + }) + } + + fn try_aligned_ref<'a, T>(addr: VirtAddr) -> Result<&'a T, SvsmError> { + // SAFETY: we trust the caller to provide an address pointing to valid + // memory which is not mutably aliased. + unsafe { addr.aligned_ref::().ok_or(SvsmError::Firmware) } + } + + pub fn size(&self) -> usize { + // Calculate the total size of the parameter area. The + // parameter area always begins at the kernel base + // address. + self.igvm_param_block.param_area_size.try_into().unwrap() + } + + pub fn find_kernel_region(&self) -> Result, SvsmError> { + let kernel_base = PhysAddr::from(self.igvm_param_block.kernel_base); + let kernel_size: usize = self.igvm_param_block.kernel_size.try_into().unwrap(); + Ok(MemoryRegion::::new(kernel_base, kernel_size)) + } + + pub fn reserved_kernel_area_size(&self) -> usize { + self.igvm_param_block + .kernel_reserved_size + .try_into() + .unwrap() + } + + pub fn page_state_change_required(&self) -> bool { + let environment_info = IgvmEnvironmentInfo::from(self.igvm_param_page.environment_info); + environment_info.memory_is_shared() + } + + pub fn get_memory_regions(&self) -> Result>, SvsmError> { + // Count the number of memory entries present. They must be + // non-overlapping and strictly increasing. + let mut number_of_entries = 0; + let mut next_page_number = 0; + for entry in self.igvm_memory_map.memory_map.iter() { + if entry.number_of_pages == 0 { + break; + } + if entry.starting_gpa_page_number < next_page_number { + return Err(SvsmError::Firmware); + } + let next_supplied_page_number = entry.starting_gpa_page_number + entry.number_of_pages; + if next_supplied_page_number < next_page_number { + return Err(SvsmError::Firmware); + } + next_page_number = next_supplied_page_number; + number_of_entries += 1; + } + + // Now loop over the supplied entires and add a region for each + // known type. + let mut regions: Vec> = Vec::new(); + for entry in self + .igvm_memory_map + .memory_map + .iter() + .take(number_of_entries) + { + if entry.entry_type == MemoryMapEntryType::MEMORY { + let starting_page: usize = entry.starting_gpa_page_number.try_into().unwrap(); + let number_of_pages: usize = entry.number_of_pages.try_into().unwrap(); + regions.push(MemoryRegion::new( + PhysAddr::new(starting_page * PAGE_SIZE), + number_of_pages * PAGE_SIZE, + )); + } + } + + Ok(regions) + } + + pub fn write_guest_memory_map(&self, map: &[MemoryRegion]) -> Result<(), SvsmError> { + // If the parameters do not include a guest memory map area, then no + // work is required. + let fw_info = &self.igvm_param_block.firmware; + if fw_info.memory_map_page_count == 0 { + return Ok(()); + } + + // Map the guest memory map area into the address space. + let mem_map_gpa = PhysAddr::from(fw_info.memory_map_page as u64 * PAGE_SIZE as u64); + let mem_map_region = MemoryRegion::::new( + mem_map_gpa, + fw_info.memory_map_page_count as usize * PAGE_SIZE, + ); + log::info!( + "Filling guest IGVM memory map at {:#018x} size {:#018x}", + mem_map_region.start(), + mem_map_region.len(), + ); + + let mem_map_mapping = + PerCPUPageMappingGuard::create(mem_map_region.start(), mem_map_region.end(), 0)?; + let mem_map_va = mem_map_mapping.virt_addr(); + + // The guest expects the pages in the memory map to be treated like + // host-provided IGVM parameters, which requires the pages to be + // validated. 
Since the memory was not declared as part of the guest + // firmware image, the pages must be validated here. + if self.page_state_change_required() { + SVSM_PLATFORM.page_state_change( + mem_map_region, + PageSize::Regular, + PageStateChangeOp::Private, + )?; + } + + let mem_map_va_region = MemoryRegion::::new(mem_map_va, mem_map_region.len()); + SVSM_PLATFORM.validate_virtual_page_range(mem_map_va_region, PageValidateOp::Validate)?; + + // Calculate the maximum number of entries that can be inserted. + let max_entries = fw_info.memory_map_page_count as usize * PAGE_SIZE + / size_of::(); + + // Generate a guest pointer range to hold the memory map. + let mem_map = GuestPtr::new(mem_map_va); + + for (i, entry) in map.iter().enumerate() { + // Return an error if an overflow occurs. + if i >= max_entries { + return Err(SvsmError::Firmware); + } + + // SAFETY: mem_map_va points to newly mapped memory, whose physical + // address is defined in the IGVM config. + unsafe { + mem_map + .offset(i as isize) + .write(IGVM_VHS_MEMORY_MAP_ENTRY { + starting_gpa_page_number: u64::from(entry.start()) / PAGE_SIZE as u64, + number_of_pages: entry.len() as u64 / PAGE_SIZE as u64, + entry_type: MemoryMapEntryType::default(), + flags: 0, + reserved: 0, + })?; + } + } + + // Write a zero page count into the last entry to terminate the list. + let index = map.len(); + if index < max_entries { + // SAFETY: mem_map_va points to newly mapped memory, whose physical + // address is defined in the IGVM config. + unsafe { + mem_map + .offset(index as isize) + .write(IGVM_VHS_MEMORY_MAP_ENTRY { + starting_gpa_page_number: 0, + number_of_pages: 0, + entry_type: MemoryMapEntryType::default(), + flags: 0, + reserved: 0, + })?; + } + } + + Ok(()) + } + + pub fn load_cpu_info(&self) -> Result, SvsmError> { + let mut cpus: Vec = Vec::new(); + log::info!("CPU count is {}", { self.igvm_param_page.cpu_count }); + for i in 0..self.igvm_param_page.cpu_count { + let cpu = ACPICPUInfo { + apic_id: i, + enabled: true, + }; + cpus.push(cpu); + } + Ok(cpus) + } + + pub fn should_launch_fw(&self) -> bool { + self.igvm_param_block.firmware.size != 0 + } + + pub fn debug_serial_port(&self) -> u16 { + self.igvm_param_block.debug_serial_port + } + + pub fn get_fw_metadata(&self) -> Option { + if !self.should_launch_fw() { + return None; + } + + let mut fw_meta = SevFWMetaData::new(); + + if self.igvm_param_block.firmware.caa_page != 0 { + fw_meta.caa_page = Some(PhysAddr::new( + self.igvm_param_block.firmware.caa_page.try_into().unwrap(), + )); + } + + if self.igvm_param_block.firmware.secrets_page != 0 { + fw_meta.secrets_page = Some(PhysAddr::new( + self.igvm_param_block + .firmware + .secrets_page + .try_into() + .unwrap(), + )); + } + + if self.igvm_param_block.firmware.cpuid_page != 0 { + fw_meta.cpuid_page = Some(PhysAddr::new( + self.igvm_param_block + .firmware + .cpuid_page + .try_into() + .unwrap(), + )); + } + + let preval_count = self.igvm_param_block.firmware.prevalidated_count as usize; + for preval in self + .igvm_param_block + .firmware + .prevalidated + .iter() + .take(preval_count) + { + let base = PhysAddr::from(preval.base as usize); + fw_meta.add_valid_mem(base, preval.size as usize); + } + + Some(fw_meta) + } + + pub fn get_fw_regions(&self) -> Vec> { + assert!(self.should_launch_fw()); + + let mut regions = Vec::new(); + + if self.igvm_param_block.firmware.in_low_memory != 0 { + // Add the stage 2 region to the firmware region list so + // permissions can be granted to the guest VMPL for that range. 
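+            // STAGE2_END_ADDR is 0xA0000, so this region spans conventional
+            // low memory from physical address 0 up to (but not including)
+            // the legacy VGA hole.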
+ regions.push(MemoryRegion::new(PhysAddr::new(0), STAGE2_END_ADDR)); + } + + regions.push(MemoryRegion::new( + PhysAddr::new(self.igvm_param_block.firmware.start as usize), + self.igvm_param_block.firmware.size as usize, + )); + + regions + } + + pub fn fw_in_low_memory(&self) -> bool { + self.igvm_param_block.firmware.in_low_memory != 0 + } + + pub fn initialize_guest_vmsa(&self, vmsa: &mut VMSA) -> Result<(), SvsmError> { + let Some(guest_context) = self.igvm_guest_context else { + return Ok(()); + }; + + // Copy the specified registers into the VMSA. + vmsa.cr0 = guest_context.cr0; + vmsa.cr3 = guest_context.cr3; + vmsa.cr4 = guest_context.cr4; + vmsa.efer = guest_context.efer; + vmsa.rip = guest_context.rip; + vmsa.rax = guest_context.rax; + vmsa.rcx = guest_context.rcx; + vmsa.rdx = guest_context.rdx; + vmsa.rbx = guest_context.rbx; + vmsa.rsp = guest_context.rsp; + vmsa.rbp = guest_context.rbp; + vmsa.rsi = guest_context.rsi; + vmsa.rdi = guest_context.rdi; + vmsa.r8 = guest_context.r8; + vmsa.r9 = guest_context.r9; + vmsa.r10 = guest_context.r10; + vmsa.r11 = guest_context.r11; + vmsa.r12 = guest_context.r12; + vmsa.r13 = guest_context.r13; + vmsa.r14 = guest_context.r14; + vmsa.r15 = guest_context.r15; + vmsa.gdt.base = guest_context.gdt_base; + vmsa.gdt.limit = guest_context.gdt_limit; + + // If a non-zero code selector is specified, then set the code + // segment attributes based on EFER.LMA. + if guest_context.code_selector != 0 { + vmsa.cs.selector = guest_context.code_selector; + let efer_lma = EFERFlags::LMA; + if (vmsa.efer & efer_lma.bits()) != 0 { + vmsa.cs.flags = 0xA9B; + } else { + vmsa.cs.flags = 0xC9B; + vmsa.cs.limit = 0xFFFFFFFF; + } + } + + let efer_svme = EFERFlags::SVME; + vmsa.efer &= !efer_svme.bits(); + + // If a non-zero data selector is specified, then modify the data + // segment attributes to be compatible with protected mode. + if guest_context.data_selector != 0 { + vmsa.ds.selector = guest_context.data_selector; + vmsa.ds.flags = 0xA93; + vmsa.ds.limit = 0xFFFFFFFF; + vmsa.ss = vmsa.ds; + vmsa.es = vmsa.ds; + vmsa.fs = vmsa.ds; + vmsa.gs = vmsa.ds; + } + + // Configure vTOM if requested. + if self.igvm_param_block.vtom != 0 { + vmsa.vtom = self.igvm_param_block.vtom; + vmsa.sev_features |= 2; // VTOM feature + } + + Ok(()) + } + + pub fn get_vtom(&self) -> u64 { + self.igvm_param_block.vtom + } + + pub fn use_alternate_injection(&self) -> bool { + self.igvm_param_block.use_alternate_injection != 0 + } +} diff --git a/stage2/src/insn_decode/decode.rs b/stage2/src/insn_decode/decode.rs new file mode 100644 index 000000000..c6e3daf90 --- /dev/null +++ b/stage2/src/insn_decode/decode.rs @@ -0,0 +1,1555 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2024 Intel Corporation. +// +// Author: Chuanxiao Dong +// +// The instruction decoding is implemented by refering instr_emul.c +// from the Arcn project, with some modifications. A copy of license +// is included below: +// +// Copyright (c) 2012 Sandvine, Inc. +// Copyright (c) 2012 NetApp, Inc. +// Copyright (c) 2017-2022 Intel Corporation. +// +// Aedistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// 2. 
Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +// OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) +// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +// OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF +// SUCH DAMAGE. +// +// $FreeBSD$ +// +// The original file can be found in this repository: +// https://github.com/projectacrn/acrn-hypervisor/blob/master/hypervisor/ +// arch/x86/guest/instr_emul.c + +extern crate alloc; + +use super::insn::{DecodedInsn, Immediate, Operand, MAX_INSN_SIZE}; +use super::opcode::{OpCodeClass, OpCodeDesc, OpCodeFlags}; +use super::{InsnError, Register, SegRegister}; +use crate::cpu::control_regs::{CR0Flags, CR4Flags}; +use crate::cpu::efer::EFERFlags; +use crate::cpu::registers::{RFlags, SegDescAttrFlags}; +use crate::types::Bytes; +use alloc::boxed::Box; +use bitflags::bitflags; + +/// Represents the raw bytes of an instruction and +/// tracks the number of bytes being processed. +#[derive(Clone, Copy, Debug, Default, PartialEq)] +pub struct InsnBytes { + /// Raw instruction bytes + bytes: [u8; MAX_INSN_SIZE], + /// Number of instruction bytes being processed + nr_processed: usize, +} + +impl InsnBytes { + /// Creates a new `OpCodeBytes` instance with the provided instruction bytes. + /// + /// # Arguments + /// + /// * `bytes` - An array of raw instruction bytes + /// + /// # Returns + /// + /// A new instance of `OpCodeBytes` with the `bytes` set to the provided + /// array and the `nr_processed` field initialized to zero. + pub const fn new(bytes: [u8; MAX_INSN_SIZE]) -> Self { + Self { + bytes, + nr_processed: 0, + } + } + + /// Retrieves a single unprocessed instruction byte. + /// + /// # Returns + /// + /// An instruction byte if success or an [`InsnError`] otherwise. + pub fn peek(&self) -> Result { + self.bytes + .get(self.nr_processed) + .copied() + .ok_or(InsnError::InsnPeek) + } + + /// Increases the count by one after a peeked byte being processed. + pub fn advance(&mut self) { + self.nr_processed += 1 + } + + /// Retrieves the number of processed instruction bytes. + /// + /// # Returns + /// + /// Returns the number of processed bytes as a `usize`. 
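+    /// For example, a freshly created [`InsnBytes`] reports 0, and after
+    /// three calls to [`Self::advance`] it reports 3.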
+ pub fn processed(&self) -> usize { + self.nr_processed + } +} + +/// The instruction bytes specifically for OpCode decoding +#[derive(Clone, Copy, Debug)] +pub struct OpCodeBytes(pub InsnBytes); + +// The instruction bytes specifically for prefix decoding +#[derive(Clone, Copy, Debug)] +struct PrefixBytes(InsnBytes); +// The instruction bytes specifically for ModR/M decoding +#[derive(Clone, Copy, Debug)] +struct ModRmBytes(InsnBytes); +// The instruction bytes specifically for SIB decoding +#[derive(Clone, Copy, Debug)] +struct SibBytes(InsnBytes); +// The instruction bytes specifically for displacement decoding +#[derive(Clone, Copy, Debug)] +struct DisBytes(InsnBytes); +// The instruction bytes specifically for immediate decoding +#[derive(Clone, Copy, Debug)] +struct ImmBytes(InsnBytes); +// The instruction bytes specifically for Mem-Offset decoding +#[derive(Clone, Copy, Debug)] +struct MoffBytes(InsnBytes); +// The instruction bytes specifically after decoding completed +#[derive(Clone, Copy, Debug)] +struct DecodedBytes(InsnBytes); + +/// This trait provides the necessary context for an instruction decoder +/// to decode instructions based on the current state of the machine +/// that executed them. It abstracts the interfaces through which an +/// instruction decoder can access specific registers and state that may +/// influence the decoding from the machine (such as a CPU or VMM). +pub trait InsnMachineCtx: core::fmt::Debug { + /// Read EFER register + fn read_efer(&self) -> u64; + /// Read a code segment register + fn read_seg(&self, seg: SegRegister) -> u64; + /// Read CR0 register + fn read_cr0(&self) -> u64; + /// Read CR4 register + fn read_cr4(&self) -> u64; + + /// Read a register + fn read_reg(&self, _reg: Register) -> usize { + unimplemented!("Reading register is not implemented"); + } + + /// Read rflags register + fn read_flags(&self) -> usize { + unimplemented!("Reading flags is not implemented"); + } + + /// Write a register + fn write_reg(&mut self, _reg: Register, _val: usize) { + unimplemented!("Writing register is not implemented"); + } + + /// Read the current privilege level + fn read_cpl(&self) -> usize { + unimplemented!("Reading CPL is not implemented"); + } + + /// Map the given linear address region to a machine memory object + /// which provides access to the memory of this linear address region. + /// + /// # Arguments + /// + /// * `la` - The linear address of the region to map. + /// * `write` - Whether write access is allowed to the mapped region. + /// * `fetch` - Whether fetch access is allowed to the mapped region. + /// + /// # Returns + /// + /// A `Result` containing a boxed trait object representing the mapped + /// memory, or an `InsnError` if mapping fails. + fn map_linear_addr( + &self, + _la: usize, + _write: bool, + _fetch: bool, + ) -> Result>, InsnError> { + Err(InsnError::MapLinearAddr) + } + + /// Check IO permission bitmap. + /// + /// # Arguments + /// + /// * `port` - The I/O port to check. + /// * `size` - The size of the I/O operation. + /// * `io_read` - Whether the I/O operation is a read operation. + /// + /// # Returns + /// + /// A `Result` containing true if the port is permitted otherwise false. + fn ioio_perm(&self, _port: u16, _size: Bytes, _io_read: bool) -> bool { + unimplemented!("Checking IO permission bitmap is not implemented"); + } + + /// Handle an I/O in operation. + /// + /// # Arguments + /// + /// * `port` - The I/O port to read from. + /// * `size` - The size of the data to read. 
+ /// + /// # Returns + /// + /// A `Result` containing the read data if success or an `InsnError` if + /// the operation fails. + fn ioio_in(&self, _port: u16, _size: Bytes) -> Result { + Err(InsnError::IoIoIn) + } + + /// Handle an I/O out operation. + /// + /// # Arguments + /// + /// * `port` - The I/O port to write to. + /// * `size` - The size of the data to write. + /// * `data` - The data to write to the I/O port. + /// + /// # Returns + /// + /// A `Result` indicating success or an `InsnError` if the operation fails. + fn ioio_out(&mut self, _port: u16, _size: Bytes, _data: u64) -> Result<(), InsnError> { + Err(InsnError::IoIoOut) + } + + /// Translate the given linear address to a physical address. + /// + /// # Arguments + /// + /// * `la` - The linear address to translate. + /// * `write` - Whether the translation is for a write operation. + /// * `fetch` - Whether the translation is for a fetch operation. + /// + /// # Returns + /// + /// A `Result` containing the translated physical address and a boolean + /// indicating whether the physical address is shared or an `InsnError` if + /// the translation fails. + fn translate_linear_addr( + &self, + _la: usize, + _write: bool, + _fetch: bool, + ) -> Result<(usize, bool), InsnError> { + Err(InsnError::TranslateLinearAddr) + } + + /// Handle a memory-mapped I/O read operation. + /// + /// # Arguments + /// + /// * `pa` - The MMIO physical address to read from. + /// * `shared` - Whether the MMIO address is shared. + /// * `size` - The size of the data to read. + /// + /// # Returns + /// + /// A `Result` containing the read data if success or an `InsnError` if + /// the operation fails. + fn handle_mmio_read(&self, _pa: usize, _shared: bool, _size: Bytes) -> Result { + Err(InsnError::HandleMmioRead) + } + + /// Handle a memory-mapped I/O write operation. + /// + /// # Arguments + /// + /// * `pa` - The MMIO physical address to write to. + /// * `shared` - Whether the MMIO address is shared. + /// * `size` - The size of the data to write. + /// * `data` - The data to write to the MMIO. + /// + /// # Returns + /// + /// A `Result` indicating success or an `InsnError` if the operation fails. + fn handle_mmio_write( + &mut self, + _pa: usize, + _shared: bool, + _size: Bytes, + _data: u64, + ) -> Result<(), InsnError> { + Err(InsnError::HandleMmioWrite) + } +} + +/// Trait representing a machine memory for instruction decoding. +pub trait InsnMachineMem { + type Item; + + /// Read data from the memory at the specified offset. + /// + /// # Safety + /// + /// The caller must verify not to read data from arbitrary memory. The object implements this + /// trait should guarantee the memory region is readable. + /// + /// # Returns + /// + /// Returns the read data on success, or an `InsnError` if the read + /// operation fails. + unsafe fn mem_read(&self) -> Result { + Err(InsnError::MemRead) + } + + /// Write data to the memory at the specified offset. + /// + /// # Safety + /// + /// The caller must verify not to write data to corrupt arbitrary memory. The object implements + /// this trait should guarantee the memory region is writable. + /// + /// # Arguments + /// + /// * `data` - The data to write to the memory. + /// + /// # Returns + /// + /// Returns `Ok`on success, or an `InsnError` if the write operation fails. 
+ unsafe fn mem_write(&mut self, _data: Self::Item) -> Result<(), InsnError> { + Err(InsnError::MemWrite) + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +enum PagingLevel { + Level4, + Level5, +} + +#[derive(Clone, Copy, Debug, Default, PartialEq)] +enum CpuMode { + #[default] + Real, + Protected, + Compatibility, + Bit64(PagingLevel), +} + +impl CpuMode { + fn is_bit64(&self) -> bool { + matches!(self, CpuMode::Bit64(_)) + } +} + +fn get_cpu_mode(mctx: &I) -> CpuMode { + if (mctx.read_efer() & EFERFlags::LMA.bits()) != 0 { + // EFER.LMA = 1 + if (mctx.read_seg(SegRegister::CS) & SegDescAttrFlags::L.bits()) != 0 { + // CS.L = 1 represents 64bit mode. + // While this sub-mode produces 64-bit linear addresses, the processor + // enforces canonicality, meaning that the upper bits of such an address + // are identical: bits 63:47 for 4-level paging and bits 63:56 for + // 5-level paging. 4-level paging (respectively, 5-level paging) does not + // use bits 63:48 (respectively, bits 63:57) of such addresses + let level = if (mctx.read_cr4() & CR4Flags::LA57.bits()) != 0 { + PagingLevel::Level5 + } else { + PagingLevel::Level4 + }; + CpuMode::Bit64(level) + } else { + CpuMode::Compatibility + } + } else if (mctx.read_cr0() & CR0Flags::PE.bits()) != 0 { + // CR0.PE = 1 + CpuMode::Protected + } else { + CpuMode::Real + } +} + +// Translate the decoded number from the instruction ModR/M +// or SIB to the corresponding register +struct RegCode(u8); +impl TryFrom for Register { + type Error = InsnError; + + fn try_from(val: RegCode) -> Result { + match val.0 { + 0 => Ok(Register::Rax), + 1 => Ok(Register::Rcx), + 2 => Ok(Register::Rdx), + 3 => Ok(Register::Rbx), + 4 => Ok(Register::Rsp), + 5 => Ok(Register::Rbp), + 6 => Ok(Register::Rsi), + 7 => Ok(Register::Rdi), + 8 => Ok(Register::R8), + 9 => Ok(Register::R9), + 10 => Ok(Register::R10), + 11 => Ok(Register::R11), + 12 => Ok(Register::R12), + 13 => Ok(Register::R13), + 14 => Ok(Register::R14), + 15 => Ok(Register::R15), + // Rip is not represented by ModR/M or SIB + _ => Err(InsnError::InvalidRegister), + } + } +} + +const PREFIX_SIZE: usize = 4; + +bitflags! { + #[derive(Copy, Clone, Debug, Default, PartialEq)] + struct PrefixFlags: u16 { + const REX_W = 1 << 0; + const REX_R = 1 << 1; + const REX_X = 1 << 2; + const REX_B = 1 << 3; + const REX_P = 1 << 4; + const REPZ_P = 1 << 5; + const REPNZ_P = 1 << 6; + const OPSIZE_OVERRIDE = 1 << 7; + const ADDRSIZE_OVERRIDE = 1 << 8; + } +} + +bitflags! 
{ + #[derive(Copy, Clone, Debug, Default, PartialEq)] + struct RexPrefix: u8 { + const B = 1 << 0; + const X = 1 << 1; + const R = 1 << 2; + const W = 1 << 3; + } +} + +#[derive(Copy, Clone, Default, Debug, PartialEq)] +struct ModRM(u8); + +const MOD_INDIRECT: u8 = 0; +const MOD_INDIRECT_DISP8: u8 = 1; +const MOD_INDIRECT_DISP32: u8 = 2; +const MOD_DIRECT: u8 = 3; +const RM_SIB: u8 = 4; +const RM_DISP32: u8 = 5; + +impl From for ModRM { + fn from(val: u8) -> Self { + ModRM(val) + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +enum RM { + Reg(Register), + Sib, + Disp32, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +enum Mod { + Indirect, + IndirectDisp8, + IndirectDisp32, + Direct, +} + +impl ModRM { + fn get_mod(&self) -> Mod { + let v = (self.0 >> 6) & 0x3; + + match v { + MOD_INDIRECT => Mod::Indirect, + MOD_INDIRECT_DISP8 => Mod::IndirectDisp8, + MOD_INDIRECT_DISP32 => Mod::IndirectDisp32, + MOD_DIRECT => Mod::Direct, + _ => { + unreachable!("Mod has only two bits, so its value is always 0 ~ 3"); + } + } + } + + fn get_reg(&self) -> u8 { + (self.0 >> 3) & 0x7 + } + + fn get_rm(&self) -> RM { + let rm = self.0 & 0x7; + let r#mod = self.get_mod(); + + // RM depends on the Mod value + if r#mod == Mod::Indirect && rm == RM_DISP32 { + RM::Disp32 + } else if r#mod != Mod::Direct && rm == RM_SIB { + RM::Sib + } else { + RM::Reg(Register::try_from(RegCode(rm)).unwrap()) + } + } +} + +#[derive(Copy, Clone, Default, Debug, PartialEq)] +struct Sib(u8); + +impl From for Sib { + fn from(val: u8) -> Self { + Sib(val) + } +} + +impl Sib { + fn get_scale(&self) -> u8 { + (self.0 >> 6) & 0x3 + } + + fn get_index(&self) -> u8 { + (self.0 >> 3) & 0x7 + } + + fn get_base(&self) -> u8 { + self.0 & 0x7 + } +} + +#[inline] +fn read_reg(mctx: &I, reg: Register, size: Bytes) -> usize { + mctx.read_reg(reg) & size.mask() as usize +} + +#[inline] +fn write_reg(mctx: &mut I, reg: Register, data: usize, size: Bytes) { + mctx.write_reg( + reg, + match size { + Bytes::Zero => return, + // Writing 8bit or 16bit register will not affect the upper bits. + Bytes::One | Bytes::Two => { + let old = mctx.read_reg(reg); + (data & size.mask() as usize) | (old & !size.mask() as usize) + } + // Writing 32bit register will zero out the upper bits. + Bytes::Four => data & size.mask() as usize, + Bytes::Eight => data, + }, + ); +} + +#[inline] +fn segment_base(segment: u64) -> u32 { + // Segment base bits 0 ~ 23: raw value bits 16 ~ 39 + // Segment base bits 24 ~ 31: raw value bits 56 ~ 63 + (((segment >> 16) & 0xffffff) | ((segment >> 56) << 24)) as u32 +} + +#[inline] +fn segment_limit(segment: u64) -> u32 { + // Segment limit bits 0 ~ 15: raw value bits 0 ~ 15 + // Segment limit bits 16 ~ 19: raw value bits 48 ~ 51 + let limit = ((segment & 0xffff) | ((segment >> 32) & 0xf0000)) as u32; + + if SegDescAttrFlags::from_bits_truncate(segment).contains(SegDescAttrFlags::G) { + (limit << 12) | 0xfff + } else { + limit + } +} + +fn ioio_perm(mctx: &I, port: u16, size: Bytes, io_read: bool) -> bool { + if mctx.read_cr0() & CR0Flags::PE.bits() != 0 + && (mctx.read_cpl() > ((mctx.read_flags() >> 12) & 3) + || mctx.read_cr4() & CR4Flags::VME.bits() != 0) + { + // In protected mode with CPL > IOPL or virtual-8086 mode, if + // any I/O Permission Bit for I/O port being accessed = 1, the I/O + // operation is not allowed. 
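+        // `(read_flags() >> 12) & 3` extracts IOPL (RFLAGS bits 13:12). Only
+        // in these cases is the TSS I/O permission bitmap consulted, which is
+        // what the machine context's `ioio_perm()` callback models.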
+ mctx.ioio_perm(port, size, io_read) + } else { + true + } +} + +#[inline] +fn read_bytereg(mctx: &I, reg: Register, lhbr: bool) -> u8 { + let data = mctx.read_reg(reg); + // To obtain the value of a legacy high byte register shift the + // base register right by 8 bits (%ah = %rax >> 8). + (if lhbr { data >> 8 } else { data }) as u8 +} + +#[inline] +fn write_bytereg(mctx: &mut I, reg: Register, lhbr: bool, data: u8) { + let old = mctx.read_reg(reg); + let mask = (Bytes::One).mask() as usize; + + let new = if lhbr { + (data as usize) << 8 | (old & !(mask << 8)) + } else { + (data as usize) | (old & !mask) + }; + + mctx.write_reg(reg, new); +} + +/// Represents the context of a decoded instruction, which is used to +/// interpret the instruction. It holds the decoded instruction, its +/// length and various components that are decoded from the instruction +/// bytes. +#[derive(Clone, Copy, Debug, Default, PartialEq)] +pub struct DecodedInsnCtx { + insn: Option, + insn_len: usize, + cpu_mode: CpuMode, + + // Prefix + prefix: PrefixFlags, + override_seg: Option, + + // Opcode description + opdesc: Option, + opsize: Bytes, + addrsize: Bytes, + + // ModR/M byte + modrm: ModRM, + modrm_reg: Option, + + // SIB byte + sib: Sib, + scale: u8, + index_reg: Option, + base_reg: Option, + + // Optional addr displacement + displacement: i64, + + // Optional immediate operand + immediate: i64, + + // Instruction repeat count + repeat: usize, +} + +impl DecodedInsnCtx { + /// Constructs a new `DecodedInsnCtx` by decoding the given + /// instruction bytes using the provided machine context. + /// + /// # Arguments + /// + /// * `bytes` - The raw bytes of the instruction to be decoded. + /// * `mctx` - A reference to an object implementing the + /// `InsnMachineCtx` trait to provide the necessary machine context + /// for decoding. + /// + /// # Returns + /// + /// A `DecodedInsnCtx` if decoding is successful or an `InsnError` + /// otherwise. + pub(super) fn new( + bytes: &[u8; MAX_INSN_SIZE], + mctx: &I, + ) -> Result { + let mut insn_ctx = Self { + cpu_mode: get_cpu_mode(mctx), + ..Default::default() + }; + + insn_ctx.decode(bytes, mctx).map(|_| insn_ctx) + } + + /// Retrieves the decoded instruction, if available. + /// + /// # Returns + /// + /// An `Option` containing the DecodedInsn. + pub fn insn(&self) -> Option { + self.insn + } + + /// Retrieves the length of the decoded instruction in bytes. + /// + /// # Returns + /// + /// The length of the decoded instruction as a `usize`. If the + /// repeat count is greater than 1, then return 0 to indicate not to + /// skip this instruction. If the repeat count is less than 1, then + /// return instruction len to indicate this instruction can be skipped. + pub fn size(&self) -> usize { + if self.repeat > 1 { + 0 + } else { + self.insn_len + } + } + + /// Emulates the decoded instruction using the provided machine context. + /// + /// # Arguments + /// + /// * `mctx` - A mutable reference to an object implementing the + /// `InsnMachineCtx` trait to provide the necessary machine context + /// for emulation. + /// + /// # Returns + /// + /// An `Ok(())` if emulation is successful or an `InsnError` otherwise. 
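+ ///
+ /// # Example
+ ///
+ /// A minimal usage sketch, assuming `bytes` holds the raw instruction
+ /// bytes and `mctx` implements the `InsnMachineCtx` trait:
+ ///
+ /// ```ignore
+ /// let decoded = Instruction::new(bytes).decode(&mctx)?;
+ /// decoded.emulate(&mut mctx)?;
+ /// ```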
+ pub fn emulate(&self, mctx: &mut I) -> Result<(), InsnError> { + self.insn + .ok_or(InsnError::UnSupportedInsn) + .and_then(|insn| match insn { + DecodedInsn::In(port, opsize) => self.emulate_in_out(port, opsize, mctx, true), + DecodedInsn::Out(port, opsize) => self.emulate_in_out(port, opsize, mctx, false), + DecodedInsn::Ins => self.emulate_ins_outs(mctx, true), + DecodedInsn::Outs => self.emulate_ins_outs(mctx, false), + DecodedInsn::Mov => self.emulate_mov(mctx), + _ => Err(InsnError::UnSupportedInsn), + }) + } + + fn decode( + &mut self, + bytes: &[u8; MAX_INSN_SIZE], + mctx: &I, + ) -> Result<(), InsnError> { + self.decode_prefixes(bytes, mctx) + .and_then(|insn| self.decode_opcode(insn)) + .and_then(|insn| self.decode_modrm_sib(insn)) + .and_then(|(insn, disp_bytes)| self.decode_displacement(insn, disp_bytes)) + .and_then(|insn| self.decode_immediate(insn)) + .and_then(|insn| self.decode_moffset(insn)) + .and_then(|insn| self.complete_decode(insn, mctx)) + } + + #[inline] + fn get_opdesc(&self) -> Result { + self.opdesc.ok_or(InsnError::NoOpCodeDesc) + } + + fn decode_rex_prefix(&mut self, code: u8) -> bool { + if !self.cpu_mode.is_bit64() { + return false; + } + + match code { + 0x40..=0x4F => { + let rex = RexPrefix::from_bits_truncate(code); + self.prefix.insert(PrefixFlags::REX_P); + if rex.contains(RexPrefix::W) { + self.prefix.insert(PrefixFlags::REX_W); + } + if rex.contains(RexPrefix::R) { + self.prefix.insert(PrefixFlags::REX_R); + } + if rex.contains(RexPrefix::X) { + self.prefix.insert(PrefixFlags::REX_X); + } + if rex.contains(RexPrefix::B) { + self.prefix.insert(PrefixFlags::REX_B); + } + true + } + _ => false, + } + } + + fn decode_op_addr_size(&mut self, cs: u64) { + (self.addrsize, self.opsize) = if self.cpu_mode.is_bit64() { + ( + if self.prefix.contains(PrefixFlags::ADDRSIZE_OVERRIDE) { + Bytes::Four + } else { + Bytes::Eight + }, + if self.prefix.contains(PrefixFlags::REX_W) { + Bytes::Eight + } else if self.prefix.contains(PrefixFlags::OPSIZE_OVERRIDE) { + Bytes::Two + } else { + Bytes::Four + }, + ) + } else if (cs & SegDescAttrFlags::DB.bits()) != 0 { + // Default address and operand sizes are 32-bits + ( + if self.prefix.contains(PrefixFlags::ADDRSIZE_OVERRIDE) { + Bytes::Two + } else { + Bytes::Four + }, + if self.prefix.contains(PrefixFlags::OPSIZE_OVERRIDE) { + Bytes::Two + } else { + Bytes::Four + }, + ) + } else { + // Default address and operand sizes are 16-bits + ( + if self.prefix.contains(PrefixFlags::ADDRSIZE_OVERRIDE) { + Bytes::Four + } else { + Bytes::Two + }, + if self.prefix.contains(PrefixFlags::OPSIZE_OVERRIDE) { + Bytes::Four + } else { + Bytes::Two + }, + ) + }; + } + + fn decode_prefixes( + &mut self, + bytes: &[u8; MAX_INSN_SIZE], + mctx: &I, + ) -> Result { + let mut insn = PrefixBytes(InsnBytes::new(*bytes)); + for _ in 0..PREFIX_SIZE { + match insn.0.peek()? 
{ + 0x66 => self.prefix.insert(PrefixFlags::OPSIZE_OVERRIDE), + 0x67 => self.prefix.insert(PrefixFlags::ADDRSIZE_OVERRIDE), + 0xF3 => self.prefix.insert(PrefixFlags::REPZ_P), + 0xF2 => self.prefix.insert(PrefixFlags::REPNZ_P), + 0x2E => self.override_seg = Some(SegRegister::CS), + 0x36 => self.override_seg = Some(SegRegister::SS), + 0x3E => self.override_seg = Some(SegRegister::DS), + 0x26 => self.override_seg = Some(SegRegister::ES), + 0x64 => self.override_seg = Some(SegRegister::FS), + 0x65 => self.override_seg = Some(SegRegister::GS), + _ => break, + } + insn.0.advance(); + } + + // From section 2.2.1, "REX Prefixes", Intel SDM Vol 2: + // - Only one REX prefix is allowed per instruction. + // - The REX prefix must immediately precede the opcode byte or the + // escape opcode byte. + // - If an instruction has a mandatory prefix (0x66, 0xF2 or 0xF3) + // the mandatory prefix must come before the REX prefix. + if self.decode_rex_prefix(insn.0.peek()?) { + insn.0.advance(); + } + + self.decode_op_addr_size(mctx.read_seg(SegRegister::CS)); + + Ok(OpCodeBytes(insn.0)) + } + + fn decode_opcode(&mut self, mut insn: OpCodeBytes) -> Result { + let opdesc = OpCodeDesc::decode(&mut insn).ok_or(InsnError::DecodeOpCode)?; + + if opdesc.flags.contains(OpCodeFlags::BYTE_OP) { + self.opsize = Bytes::One; + } else if opdesc.flags.contains(OpCodeFlags::WORD_OP) { + self.opsize = Bytes::Two; + } + + self.opdesc = Some(opdesc); + + Ok(ModRmBytes(insn.0)) + } + + fn decode_modrm_sib(&mut self, mut insn: ModRmBytes) -> Result<(DisBytes, Bytes), InsnError> { + if self.get_opdesc()?.flags.contains(OpCodeFlags::NO_MODRM) { + return Ok((DisBytes(insn.0), Bytes::Zero)); + } + + if self.cpu_mode == CpuMode::Real { + return Err(InsnError::DecodeModRM); + } + + self.modrm = ModRM::from(insn.0.peek()?); + + if self.get_opdesc()?.flags.contains(OpCodeFlags::OP_NONE) { + insn.0.advance(); + return Ok((DisBytes(insn.0), Bytes::Zero)); + } + + let r#mod = self.modrm.get_mod(); + let reg = self.modrm.get_reg() | ((self.prefix.contains(PrefixFlags::REX_R) as u8) << 3); + self.modrm_reg = Some(Register::try_from(RegCode(reg))?); + + // As the modrm decoding is majorly for MMIO instructions which requires + // a memory access, a direct addressing mode makes no sense in the context. + // There has to be a memory access involved to trap the MMIO instruction. + if r#mod == Mod::Direct { + return Err(InsnError::DecodeModRM); + } + + // SDM Vol2 Table 2-5: Special Cases of REX Encodings + // For mod=0 r/m=5 and mod!=3 r/m=4, the 'b' bit in the REX + // prefix is 'don't care' in these two cases. + // + // RM::Disp32 represent mod=0 r/m=5 + // RM::Sib represent mod!=3 r/m=4 + // RM::Reg(r) represent the other cases. + let disp_bytes = match self.modrm.get_rm() { + RM::Reg(r) => { + let ext_r = Register::try_from(RegCode( + r as u8 | ((self.prefix.contains(PrefixFlags::REX_B) as u8) << 3), + ))?; + self.base_reg = Some(ext_r); + match r#mod { + Mod::IndirectDisp8 => Bytes::One, + Mod::IndirectDisp32 => Bytes::Four, + Mod::Indirect | Mod::Direct => Bytes::Zero, + } + } + RM::Disp32 => { + // SDM Vol2 Table 2-7: RIP-Relative Addressing + // In 64bit mode, mod=0 r/m=5 implies [rip] + disp32 + // whereas in compatibility mode it just implies disp32. 
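+ // Only the base register is resolved here; the 4-byte displacement
+ // itself is consumed later in decode_displacement().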
+ self.base_reg = if self.cpu_mode.is_bit64() { + Some(Register::Rip) + } else { + None + }; + Bytes::Four + } + RM::Sib => { + insn.0.advance(); + return self.decode_sib(SibBytes(insn.0)); + } + }; + + insn.0.advance(); + Ok((DisBytes(insn.0), disp_bytes)) + } + + fn decode_sib(&mut self, mut insn: SibBytes) -> Result<(DisBytes, Bytes), InsnError> { + // Process only if SIB byte is present + if self.modrm.get_rm() != RM::Sib { + return Err(InsnError::DecodeSib); + } + + self.sib = Sib::from(insn.0.peek()?); + let index = self.sib.get_index() | ((self.prefix.contains(PrefixFlags::REX_X) as u8) << 3); + let base = self.sib.get_base() | ((self.prefix.contains(PrefixFlags::REX_B) as u8) << 3); + + let r#mod = self.modrm.get_mod(); + let disp_bytes = match r#mod { + Mod::IndirectDisp8 => { + self.base_reg = Some(Register::try_from(RegCode(base))?); + Bytes::One + } + Mod::IndirectDisp32 => { + self.base_reg = Some(Register::try_from(RegCode(base))?); + Bytes::Four + } + Mod::Indirect => { + let mut disp_bytes = Bytes::Zero; + // SMD Vol 2 Table 2-5 Special Cases of REX Encoding + // Base register is unused if mod=0 base=RBP/R13. + self.base_reg = if base == Register::Rbp as u8 || base == Register::R13 as u8 { + disp_bytes = Bytes::Four; + None + } else { + Some(Register::try_from(RegCode(base))?) + }; + disp_bytes + } + Mod::Direct => Bytes::Zero, + }; + + // SMD Vol 2 Table 2-5 Special Cases of REX Encoding + // Index register not used when index=RSP + if index != Register::Rsp as u8 { + self.index_reg = Some(Register::try_from(RegCode(index))?); + // 'scale' makes sense only in the context of an index register + self.scale = 1 << self.sib.get_scale(); + } + + insn.0.advance(); + Ok((DisBytes(insn.0), disp_bytes)) + } + + fn decode_displacement( + &mut self, + mut insn: DisBytes, + disp_bytes: Bytes, + ) -> Result { + match disp_bytes { + Bytes::Zero => Ok(ImmBytes(insn.0)), + Bytes::One | Bytes::Four => { + let mut buf = [0; 4]; + + for v in buf.iter_mut().take(disp_bytes as usize) { + *v = insn.0.peek()?; + insn.0.advance(); + } + + self.displacement = if disp_bytes == Bytes::One { + buf[0] as i8 as i64 + } else { + i32::from_le_bytes(buf) as i64 + }; + + Ok(ImmBytes(insn.0)) + } + _ => Err(InsnError::DecodeDisp), + } + } + + fn decode_immediate(&mut self, mut insn: ImmBytes) -> Result { + // Figure out immediate operand size (if any) + let imm_bytes = if self.get_opdesc()?.flags.contains(OpCodeFlags::IMM) { + match self.opsize { + // SDM Vol 2 2.2.1.5 "Immediates" + // In 64-bit mode the typical size of immediate operands + // remains 32-bits. When the operand size if 64-bits, the + // processor sign-extends all immediates to 64-bits prior + // to their use. 
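+ // Hence 32-bit and 64-bit operand sizes both take a 4-byte immediate;
+ // other operand sizes fall back to a 2-byte immediate.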
+ Bytes::Four | Bytes::Eight => Bytes::Four, + _ => Bytes::Two, + } + } else if self.get_opdesc()?.flags.contains(OpCodeFlags::IMM8) { + Bytes::One + } else { + // No flags on immediate operand size + return Ok(MoffBytes(insn.0)); + }; + + let mut buf = [0; 4]; + + for v in buf.iter_mut().take(imm_bytes as usize) { + *v = insn.0.peek()?; + insn.0.advance(); + } + + self.immediate = match imm_bytes { + Bytes::One => buf[0] as i8 as i64, + Bytes::Two => i16::from_le_bytes([buf[0], buf[1]]) as i64, + Bytes::Four => i32::from_le_bytes(buf) as i64, + _ => return Err(InsnError::DecodeImm), + }; + + Ok(MoffBytes(insn.0)) + } + + fn decode_moffset(&mut self, mut insn: MoffBytes) -> Result { + if !self.get_opdesc()?.flags.contains(OpCodeFlags::MOFFSET) { + return Ok(DecodedBytes(insn.0)); + } + + match self.addrsize { + Bytes::Zero | Bytes::One => Err(InsnError::DecodeMOffset), + _ => { + // SDM Vol 2 Section 2.2.1.4, "Direct Memory-Offset MOVs" + // In 64-bit mode, direct memory-offset forms of the MOV + // instruction are extended to specify a 64-bit immediate + // absolute address. + // + // The memory offset size follows the address-size of the instruction. + let mut buf = [0; 8]; + for v in buf.iter_mut().take(self.addrsize as usize) { + *v = insn.0.peek()?; + insn.0.advance(); + } + self.displacement = i64::from_le_bytes(buf); + Ok(DecodedBytes(insn.0)) + } + } + } + + fn complete_decode( + &mut self, + insn: DecodedBytes, + mctx: &I, + ) -> Result<(), InsnError> { + self.insn_len = insn.0.processed(); + self.decoded_insn(mctx) + .map(|decoded_insn| self.insn = Some(decoded_insn)) + } + + fn decoded_insn(&mut self, mctx: &I) -> Result { + let opdesc = self.get_opdesc()?; + Ok(match opdesc.class { + OpCodeClass::Cpuid => DecodedInsn::Cpuid, + OpCodeClass::In => { + if opdesc.flags.contains(OpCodeFlags::IMM8) { + DecodedInsn::In( + Operand::Imm(Immediate::U8(self.immediate as u8)), + self.opsize, + ) + } else { + DecodedInsn::In(Operand::rdx(), self.opsize) + } + } + OpCodeClass::Out => { + if opdesc.flags.contains(OpCodeFlags::IMM8) { + DecodedInsn::Out( + Operand::Imm(Immediate::U8(self.immediate as u8)), + self.opsize, + ) + } else { + DecodedInsn::Out(Operand::rdx(), self.opsize) + } + } + OpCodeClass::Ins | OpCodeClass::Outs => { + if self.prefix.contains(PrefixFlags::REPZ_P) { + // The prefix REPZ(F3h) actually represents REP for ins/outs. + // The count register is depending on the address size of the + // instruction. + self.repeat = read_reg(mctx, Register::Rcx, self.addrsize); + }; + + if opdesc.class == OpCodeClass::Ins { + DecodedInsn::Ins + } else { + DecodedInsn::Outs + } + } + OpCodeClass::Rdmsr => DecodedInsn::Rdmsr, + OpCodeClass::Rdtsc => DecodedInsn::Rdtsc, + OpCodeClass::Rdtscp => DecodedInsn::Rdtscp, + OpCodeClass::Wrmsr => DecodedInsn::Wrmsr, + OpCodeClass::Mov => DecodedInsn::Mov, + _ => return Err(InsnError::UnSupportedInsn), + }) + } + + #[inline] + fn get_modrm_reg(&self) -> Result { + self.modrm_reg.ok_or(InsnError::InvalidDecode) + } + + fn cal_modrm_bytereg(&self) -> Result<(Register, bool), InsnError> { + let reg = self.get_modrm_reg()?; + // 64-bit mode imposes limitations on accessing legacy high byte + // registers (lhbr). + // + // The legacy high-byte registers cannot be addressed if the REX + // prefix is present. In this case the values 4, 5, 6 and 7 of the + // 'ModRM:reg' field address %spl, %bpl, %sil and %dil respectively. 
+ // + // If the REX prefix is not present then the values 4, 5, 6 and 7 + // of the 'ModRM:reg' field address the legacy high-byte registers, + // %ah, %ch, %dh and %bh respectively. + Ok( + if !self.prefix.contains(PrefixFlags::REX_P) && (reg as u8 & 0x4) != 0 { + (Register::try_from(RegCode(reg as u8 & 0x3))?, true) + } else { + (reg, false) + }, + ) + } + + fn canonical_check(&self, la: usize) -> Option { + if match self.cpu_mode { + CpuMode::Bit64(level) => { + let virtaddr_bits = if level == PagingLevel::Level4 { 48 } else { 57 }; + let mask = !((1 << virtaddr_bits) - 1); + if la & (1 << (virtaddr_bits - 1)) != 0 { + la & mask == mask + } else { + la & mask == 0 + } + } + _ => true, + } { + Some(la) + } else { + None + } + } + + fn alignment_check(&self, la: usize, size: Bytes) -> Option { + match size { + // Zero size is not allowed + Bytes::Zero => None, + // One byte is always aligned + Bytes::One => Some(la), + // Two/Four/Eight bytes must be aligned on a boundary + _ => { + if la & (size as usize - 1) != 0 { + None + } else { + Some(la) + } + } + } + } + + fn cal_linear_addr( + &self, + mctx: &I, + seg: SegRegister, + ea: usize, + writable: bool, + ) -> Option { + let segment = mctx.read_seg(seg); + + let addrsize = if self.cpu_mode.is_bit64() { + Bytes::Eight + } else { + let attr = SegDescAttrFlags::from_bits_truncate(segment); + // Invalid if is system segment + if !attr.contains(SegDescAttrFlags::S) { + return None; + } + + if writable { + // Writing to a code segment, or writing to a read-only + // data segment is not allowed. + if attr.contains(SegDescAttrFlags::C_D) || !attr.contains(SegDescAttrFlags::R_W) { + return None; + } + } else { + // Data segment is always read-able, but code segment + // may be execute only. Invalid if read an execute only + // code segment. + if attr.contains(SegDescAttrFlags::C_D) && !attr.contains(SegDescAttrFlags::R_W) { + return None; + } + } + + let mut limit = segment_limit(segment) as usize; + + if !attr.contains(SegDescAttrFlags::C_D) && attr.contains(SegDescAttrFlags::C_E) { + // Expand-down segment, check low limit + if ea <= limit { + return None; + } + + limit = if attr.contains(SegDescAttrFlags::DB) { + u32::MAX as usize + } else { + u16::MAX as usize + } + } + + // Check high limit for each byte + for i in 0..self.opsize as usize { + if ea + i > limit { + return None; + } + } + + Bytes::Four + }; + + self.canonical_check( + if self.cpu_mode.is_bit64() && seg != SegRegister::FS && seg != SegRegister::GS { + ea & (addrsize.mask() as usize) + } else { + (segment_base(segment) as usize + ea) & addrsize.mask() as usize + }, + ) + } + + fn get_linear_addr( + &self, + mctx: &I, + seg: SegRegister, + ea: usize, + writable: bool, + ) -> Result { + self.cal_linear_addr(mctx, seg, ea, writable) + .ok_or(if seg == SegRegister::SS { + InsnError::ExceptionSS + } else { + InsnError::ExceptionGP(0) + }) + .and_then(|la| { + if (mctx.read_cpl() == 3) + && (mctx.read_cr0() & CR0Flags::AM.bits()) != 0 + && (mctx.read_flags() & RFlags::AC.bits()) != 0 + { + self.alignment_check(la, self.opsize) + .ok_or(InsnError::ExceptionAC) + } else { + Ok(la) + } + }) + } + + fn emulate_ins_outs( + &self, + mctx: &mut I, + io_read: bool, + ) -> Result<(), InsnError> { + // I/O port number is stored in DX. + let port = mctx.read_reg(Register::Rdx) as u16; + + // Check the IO permission bit map. 
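+ // A denied access is reported as #GP(0), matching the fault the
+ // CPU raises for INS/OUTS without sufficient I/O permission.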
+ if !ioio_perm(mctx, port, self.opsize, io_read) { + return Err(InsnError::ExceptionGP(0)); + } + + let (seg, reg) = if io_read { + // Input byte from I/O port specified in DX into + // memory location specified with ES:(E)DI or + // RDI. + (SegRegister::ES, Register::Rdi) + } else { + // Output byte/word/doubleword from memory location specified in + // DS:(E)SI (The DS segment may be overridden with a segment + // override prefix.) or RSI to I/O port specified in DX. + ( + self.override_seg.map_or(SegRegister::DS, |s| s), + Register::Rsi, + ) + }; + + // Decoed the linear addresses and map as a memory object + // which allows accessing to the memory represented by the + // linear addresses. + let linear_addr = + self.get_linear_addr(mctx, seg, read_reg(mctx, reg, self.addrsize), io_read)?; + if io_read { + // Read data from IO port and then write to the memory location. + let data = mctx.ioio_in(port, self.opsize)?; + // Safety: The linear address is decoded from the instruction and checked. It can be + // remapped to a memory object with the write permission successfully, and the remapped + // memory size matches the operand size of the instruction. + unsafe { + match self.opsize { + Bytes::One => mctx + .map_linear_addr::(linear_addr, io_read, false)? + .mem_write(data as u8)?, + Bytes::Two => mctx + .map_linear_addr::(linear_addr, io_read, false)? + .mem_write(data as u16)?, + Bytes::Four => mctx + .map_linear_addr::(linear_addr, io_read, false)? + .mem_write(data as u32)?, + _ => return Err(InsnError::IoIoIn), + }; + } + } else { + // Read data from memory location and then write to the IO port + // + // Safety: The linear address is decoded from the instruction and checked. It can be + // remapped to a memory object with the read permission successfully, and the remapped + // memory size matches the operand size of the instruction. + let data = unsafe { + match self.opsize { + Bytes::One => mctx + .map_linear_addr::(linear_addr, io_read, false)? + .mem_read()? as u64, + Bytes::Two => mctx + .map_linear_addr::(linear_addr, io_read, false)? + .mem_read()? as u64, + Bytes::Four => mctx + .map_linear_addr::(linear_addr, io_read, false)? + .mem_read()? as u64, + _ => return Err(InsnError::IoIoOut), + } + }; + mctx.ioio_out(port, self.opsize, data)?; + } + + let rflags = RFlags::from_bits_truncate(mctx.read_flags()); + if rflags.contains(RFlags::DF) { + // The DF flag is 1, the (E)SI/DI register is decremented. + write_reg( + mctx, + reg, + read_reg(mctx, reg, self.addrsize) + .checked_sub(self.opsize as usize) + .ok_or(InsnError::IoIoOut)?, + self.addrsize, + ); + } else { + // The DF flag is 0, the (E)SI/DI register is incremented. + write_reg( + mctx, + reg, + read_reg(mctx, reg, self.addrsize) + .checked_add(self.opsize as usize) + .ok_or(InsnError::IoIoOut)?, + self.addrsize, + ); + } + + if self.repeat != 0 { + // Update the count register with the left count which are not + // emulated yet. + write_reg(mctx, Register::Rcx, self.repeat - 1, self.addrsize); + } + + Ok(()) + } + + fn emulate_in_out( + &self, + port: Operand, + opsize: Bytes, + mctx: &mut I, + io_read: bool, + ) -> Result<(), InsnError> { + let port = match port { + Operand::Reg(Register::Rdx) => mctx.read_reg(Register::Rdx) as u16, + Operand::Reg(..) 
=> unreachable!("Port value is always in DX"), + Operand::Imm(imm) => match imm { + Immediate::U8(val) => val as u16, + _ => unreachable!("Port value in immediate is always 1 byte"), + }, + }; + + // Check the IO permission bit map + if !ioio_perm(mctx, port, opsize, io_read) { + return Err(InsnError::ExceptionGP(0)); + } + + if io_read { + // Read data from IO port and then write to AL/AX/EAX. + write_reg( + mctx, + Register::Rax, + mctx.ioio_in(port, opsize)? as usize, + opsize, + ); + } else { + // Read data from AL/AX/EAX and then write to the IO port. + mctx.ioio_out(port, opsize, read_reg(mctx, Register::Rax, opsize) as u64)?; + } + + Ok(()) + } + + fn cal_effective_addr(&self, mctx: &I) -> Result { + let base = if let Some(reg) = self.base_reg { + match reg { + Register::Rip => { + // RIP relative addressing is used in 64bit mode and + // starts from the following instruction + mctx.read_reg(reg) + self.insn_len + } + _ => mctx.read_reg(reg), + } + } else { + 0 + }; + + let index = if let Some(reg) = self.index_reg { + mctx.read_reg(reg) + } else { + 0 + }; + + Ok(base + .checked_add(index << (self.scale as usize)) + .and_then(|v| v.checked_add(self.displacement as usize)) + .ok_or(InsnError::InvalidDecode)? + & self.addrsize.mask() as usize) + } + + #[inline] + fn emulate_mmio_read( + &self, + mctx: &I, + seg: SegRegister, + ea: usize, + ) -> Result { + mctx.translate_linear_addr(self.get_linear_addr(mctx, seg, ea, false)?, false, false) + .and_then(|(addr, shared)| mctx.handle_mmio_read(addr, shared, self.opsize)) + } + + #[inline] + fn emulate_mmio_write( + &self, + mctx: &mut I, + seg: SegRegister, + ea: usize, + data: u64, + ) -> Result<(), InsnError> { + mctx.translate_linear_addr(self.get_linear_addr(mctx, seg, ea, true)?, true, false) + .and_then(|(addr, shared)| mctx.handle_mmio_write(addr, shared, self.opsize, data)) + } + + fn emulate_mov(&self, mctx: &mut I) -> Result<(), InsnError> { + if self.prefix.contains(PrefixFlags::REPZ_P) { + return Err(InsnError::UnSupportedInsn); + } + + let seg = if let Some(s) = self.override_seg { + s + } else if self.base_reg == Some(Register::Rsp) || self.base_reg == Some(Register::Rbp) { + SegRegister::SS + } else { + SegRegister::DS + }; + let ea = self.cal_effective_addr(mctx)?; + + match self.get_opdesc()?.code { + 0x88 => { + // Mov byte from reg (ModRM:reg) to mem (ModRM:r/m) + // 88/r: mov r/m8, r8 + // REX + 88/r: mov r/m8, r8 (%ah, %ch, %dh, %bh not available) + let (reg, lhbr) = self.cal_modrm_bytereg()?; + let data = read_bytereg(mctx, reg, lhbr); + self.emulate_mmio_write(mctx, seg, ea, data as u64)?; + } + 0x89 => { + // MOV from reg (ModRM:reg) to mem (ModRM:r/m) + // 89/r: mov r/m16, r16 + // 89/r: mov r/m32, r32 + // REX.W + 89/r mov r/m64, r64 + let data = read_reg(mctx, self.get_modrm_reg()?, self.opsize); + self.emulate_mmio_write(mctx, seg, ea, data as u64)?; + } + 0x8A => { + // MOV byte from mem (ModRM:r/m) to reg (ModRM:reg) + // 8A/r: mov r8, r/m8 + // REX + 8A/r: mov r8, r/m8 + let data = self.emulate_mmio_read(mctx, seg, ea)?; + let (reg, lhbr) = self.cal_modrm_bytereg()?; + write_bytereg(mctx, reg, lhbr, data as u8); + } + 0x8B => { + // MOV from mem (ModRM:r/m) to reg (ModRM:reg) + // 8B/r: mov r16, r/m16 + // 8B/r: mov r32, r/m32 + // REX.W 8B/r: mov r64, r/m64 + let data = self.emulate_mmio_read(mctx, seg, ea)?; + write_reg(mctx, self.get_modrm_reg()?, data as usize, self.opsize); + } + 0xA1 => { + // MOV from seg:moffset to AX/EAX/RAX + // A1: mov AX, moffs16 + // A1: mov EAX, moffs32 + // REX.W + A1: mov 
RAX, moffs64 + let data = self.emulate_mmio_read(mctx, seg, ea)?; + write_reg(mctx, Register::Rax, data as usize, self.opsize); + } + 0xA3 => { + // MOV from AX/EAX/RAX to seg:moffset + // A3: mov moffs16, AX + // A3: mov moffs32, EAX + // REX.W + A3: mov moffs64, RAX + let data = read_reg(mctx, Register::Rax, self.opsize); + self.emulate_mmio_write(mctx, seg, ea, data as u64)?; + } + 0xC6 | 0xC7 => { + // MOV from imm8 to mem (ModRM:r/m) + // C6/0 mov r/m8, imm8 + // REX + C6/0 mov r/m8, imm8 + // MOV from imm16/imm32 to mem (ModRM:r/m) + // C7/0 mov r/m16, imm16 + // C7/0 mov r/m32, imm32 + // REX.W + C7/0 mov r/m64, imm32 (sign-extended to 64-bits) + self.emulate_mmio_write(mctx, seg, ea, self.immediate as u64 & self.opsize.mask())?; + } + _ => return Err(InsnError::UnSupportedInsn), + } + + Ok(()) + } +} diff --git a/stage2/src/insn_decode/insn.rs b/stage2/src/insn_decode/insn.rs new file mode 100644 index 000000000..efc412657 --- /dev/null +++ b/stage2/src/insn_decode/insn.rs @@ -0,0 +1,1362 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Thomas Leroy + +use super::decode::DecodedInsnCtx; +use super::{InsnError, InsnMachineCtx}; +use crate::types::Bytes; + +/// An immediate value in an instruction +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Immediate { + U8(u8), + U16(u16), + U32(u32), +} + +/// A register in an instruction +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Register { + Rax, + Rcx, + Rdx, + Rbx, + Rsp, + Rbp, + Rsi, + Rdi, + R8, + R9, + R10, + R11, + R12, + R13, + R14, + R15, + Rip, +} + +/// A Segment register in instruction +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum SegRegister { + CS, + SS, + DS, + ES, + FS, + GS, +} + +/// An operand in an instruction, which might be a register or an immediate. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum Operand { + Reg(Register), + Imm(Immediate), +} + +impl Operand { + #[inline] + pub const fn rdx() -> Self { + Self::Reg(Register::Rdx) + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DecodedInsn { + Cpuid, + In(Operand, Bytes), + Ins, + Mov, + Out(Operand, Bytes), + Outs, + Wrmsr, + Rdmsr, + Rdtsc, + Rdtscp, +} + +pub const MAX_INSN_SIZE: usize = 15; + +/// A view of an x86 instruction. +#[derive(Default, Debug, Copy, Clone, PartialEq)] +pub struct Instruction([u8; MAX_INSN_SIZE]); + +impl Instruction { + pub const fn new(bytes: [u8; MAX_INSN_SIZE]) -> Self { + Self(bytes) + } + + /// Decode the instruction with the given InsnMachineCtx. + /// + /// # Returns + /// + /// A [`DecodedInsnCtx`] if the instruction is supported, or an [`InsnError`] otherwise. + pub fn decode(&self, mctx: &I) -> Result { + DecodedInsnCtx::new(&self.0, mctx) + } +} + +#[cfg(any(test, fuzzing))] +pub mod test_utils { + extern crate alloc; + + use crate::cpu::control_regs::{CR0Flags, CR4Flags}; + use crate::cpu::efer::EFERFlags; + use crate::insn_decode::*; + use crate::types::Bytes; + use alloc::boxed::Box; + + pub const TEST_PORT: u16 = 0xE0; + + /// A dummy struct to implement InsnMachineCtx for testing purposes. 
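+ ///
+ /// Register, I/O and MMIO state is backed by plain fields, so decoded
+ /// instructions can be emulated against it and the results inspected
+ /// directly in tests.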
+ #[derive(Copy, Clone, Debug)] + pub struct TestCtx { + pub efer: u64, + pub cr0: u64, + pub cr4: u64, + + pub rax: usize, + pub rdx: usize, + pub rcx: usize, + pub rbx: usize, + pub rsp: usize, + pub rbp: usize, + pub rdi: usize, + pub rsi: usize, + pub r8: usize, + pub r9: usize, + pub r10: usize, + pub r11: usize, + pub r12: usize, + pub r13: usize, + pub r14: usize, + pub r15: usize, + pub rip: usize, + pub flags: usize, + + pub ioport: u16, + pub iodata: u64, + + pub mmio_reg: u64, + } + + impl Default for TestCtx { + fn default() -> Self { + Self { + efer: EFERFlags::LMA.bits(), + cr0: CR0Flags::PE.bits(), + cr4: CR4Flags::LA57.bits(), + rax: 0, + rdx: 0, + rcx: 0, + rbx: 0, + rsp: 0, + rbp: 0, + rdi: 0, + rsi: 0, + r8: 0, + r9: 0, + r10: 0, + r11: 0, + r12: 0, + r13: 0, + r14: 0, + r15: 0, + rip: 0, + flags: 0, + ioport: TEST_PORT, + iodata: u64::MAX, + mmio_reg: 0, + } + } + } + + #[cfg_attr(not(test), expect(dead_code))] + struct TestMem { + ptr: *mut T, + } + + impl InsnMachineCtx for TestCtx { + fn read_efer(&self) -> u64 { + self.efer + } + + fn read_seg(&self, seg: SegRegister) -> u64 { + match seg { + SegRegister::CS => 0x00af9a000000ffffu64, + _ => 0x00cf92000000ffffu64, + } + } + + fn read_cr0(&self) -> u64 { + self.cr0 + } + + fn read_cr4(&self) -> u64 { + self.cr4 + } + + fn read_reg(&self, reg: Register) -> usize { + match reg { + Register::Rax => self.rax, + Register::Rdx => self.rdx, + Register::Rcx => self.rcx, + Register::Rbx => self.rdx, + Register::Rsp => self.rsp, + Register::Rbp => self.rbp, + Register::Rdi => self.rdi, + Register::Rsi => self.rsi, + Register::R8 => self.r8, + Register::R9 => self.r9, + Register::R10 => self.r10, + Register::R11 => self.r11, + Register::R12 => self.r12, + Register::R13 => self.r13, + Register::R14 => self.r14, + Register::R15 => self.r15, + Register::Rip => self.rip, + } + } + + fn write_reg(&mut self, reg: Register, val: usize) { + match reg { + Register::Rax => self.rax = val, + Register::Rdx => self.rdx = val, + Register::Rcx => self.rcx = val, + Register::Rbx => self.rdx = val, + Register::Rsp => self.rsp = val, + Register::Rbp => self.rbp = val, + Register::Rdi => self.rdi = val, + Register::Rsi => self.rsi = val, + Register::R8 => self.r8 = val, + Register::R9 => self.r9 = val, + Register::R10 => self.r10 = val, + Register::R11 => self.r11 = val, + Register::R12 => self.r12 = val, + Register::R13 => self.r13 = val, + Register::R14 => self.r14 = val, + Register::R15 => self.r15 = val, + Register::Rip => self.rip = val, + } + } + + fn read_cpl(&self) -> usize { + 0 + } + + fn read_flags(&self) -> usize { + self.flags + } + + fn map_linear_addr( + &self, + la: usize, + _write: bool, + _fetch: bool, + ) -> Result>, InsnError> { + Ok(Box::new(TestMem { ptr: la as *mut T })) + } + + fn ioio_in(&self, _port: u16, size: Bytes) -> Result { + match size { + Bytes::One => Ok(self.iodata as u8 as u64), + Bytes::Two => Ok(self.iodata as u16 as u64), + Bytes::Four => Ok(self.iodata as u32 as u64), + _ => Err(InsnError::IoIoIn), + } + } + + fn ioio_out(&mut self, _port: u16, size: Bytes, data: u64) -> Result<(), InsnError> { + match size { + Bytes::One => self.iodata = data as u8 as u64, + Bytes::Two => self.iodata = data as u16 as u64, + Bytes::Four => self.iodata = data as u32 as u64, + _ => return Err(InsnError::IoIoOut), + } + + Ok(()) + } + + fn translate_linear_addr( + &self, + la: usize, + _write: bool, + _fetch: bool, + ) -> Result<(usize, bool), InsnError> { + Ok((la, false)) + } + + fn handle_mmio_read( + &self, + pa: usize, + 
_shared: bool, + size: Bytes, + ) -> Result { + if pa != &raw const self.mmio_reg as usize { + return Ok(0); + } + + match size { + Bytes::One => Ok(unsafe { *(pa as *const u8) } as u64), + Bytes::Two => Ok(unsafe { *(pa as *const u16) } as u64), + Bytes::Four => Ok(unsafe { *(pa as *const u32) } as u64), + Bytes::Eight => Ok(unsafe { *(pa as *const u64) }), + _ => Err(InsnError::HandleMmioRead), + } + } + + fn handle_mmio_write( + &mut self, + pa: usize, + _shared: bool, + size: Bytes, + data: u64, + ) -> Result<(), InsnError> { + if pa != &raw const self.mmio_reg as usize { + return Ok(()); + } + + match size { + Bytes::One => unsafe { *(pa as *mut u8) = data as u8 }, + Bytes::Two => unsafe { *(pa as *mut u16) = data as u16 }, + Bytes::Four => unsafe { *(pa as *mut u32) = data as u32 }, + Bytes::Eight => unsafe { *(pa as *mut u64) = data }, + _ => return Err(InsnError::HandleMmioWrite), + } + Ok(()) + } + } + + #[cfg(test)] + impl InsnMachineMem for TestMem { + type Item = T; + + unsafe fn mem_read(&self) -> Result { + Ok(*(self.ptr)) + } + + unsafe fn mem_write(&mut self, data: Self::Item) -> Result<(), InsnError> { + *(self.ptr) = data; + Ok(()) + } + } + + #[cfg(fuzzing)] + impl InsnMachineMem for TestMem { + type Item = T; + + unsafe fn mem_read(&self) -> Result { + Err(InsnError::MemRead) + } + + unsafe fn mem_write(&mut self, _data: Self::Item) -> Result<(), InsnError> { + Ok(()) + } + } +} + +#[cfg(test)] +mod tests { + use super::test_utils::*; + use super::*; + use crate::cpu::registers::RFlags; + + #[test] + fn test_decode_inb() { + let mut testctx = TestCtx { + iodata: 0xab, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xE4, + TEST_PORT as u8, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::In(Operand::Imm(Immediate::U8(TEST_PORT as u8)), Bytes::One) + ); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.rax as u64, testctx.iodata); + + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + iodata: 0xab, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xEC, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::In(Operand::rdx(), Bytes::One) + ); + assert_eq!(decoded.size(), 1); + assert_eq!(testctx.rax as u64, testctx.iodata); + } + + #[test] + fn test_decode_inw() { + let mut testctx = TestCtx { + iodata: 0xabcd, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x66, + 0xE5, + TEST_PORT as u8, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::In(Operand::Imm(Immediate::U8(TEST_PORT as u8)), Bytes::Two) + ); + assert_eq!(decoded.size(), 3); + assert_eq!(testctx.rax as u64, testctx.iodata); + + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + iodata: 0xabcd, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x66, 0xED, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let decoded = 
Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::In(Operand::rdx(), Bytes::Two) + ); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.rax as u64, testctx.iodata); + } + + #[test] + fn test_decode_inl() { + let mut testctx = TestCtx { + iodata: 0xabcdef01, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xE5, + TEST_PORT as u8, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::In(Operand::Imm(Immediate::U8(TEST_PORT as u8)), Bytes::Four) + ); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.rax as u64, testctx.iodata); + + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + iodata: 0xabcdef01, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xED, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::In(Operand::rdx(), Bytes::Four) + ); + assert_eq!(decoded.size(), 1); + assert_eq!(testctx.rax as u64, testctx.iodata); + } + + #[test] + fn test_decode_outb() { + let mut testctx = TestCtx { + rax: 0xab, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xE6, + TEST_PORT as u8, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::Out(Operand::Imm(Immediate::U8(TEST_PORT as u8)), Bytes::One) + ); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.rax as u64, testctx.iodata); + + let mut testctx = TestCtx { + rax: 0xab, + rdx: TEST_PORT as usize, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xEE, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::Out(Operand::rdx(), Bytes::One) + ); + assert_eq!(decoded.size(), 1); + assert_eq!(testctx.rax as u64, testctx.iodata); + } + + #[test] + fn test_decode_outw() { + let mut testctx = TestCtx { + rax: 0xabcd, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x66, + 0xE7, + TEST_PORT as u8, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::Out(Operand::Imm(Immediate::U8(TEST_PORT as u8)), Bytes::Two) + ); + assert_eq!(decoded.size(), 3); + assert_eq!(testctx.rax as u64, testctx.iodata); + + let mut testctx = TestCtx { + rax: 0xabcd, + rdx: TEST_PORT as usize, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x66, 0xEF, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + 
assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::Out(Operand::rdx(), Bytes::Two) + ); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.rax as u64, testctx.iodata); + } + + #[test] + fn test_decode_outl() { + let mut testctx = TestCtx { + rax: 0xabcdef01, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xE7, + TEST_PORT as u8, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::Out(Operand::Imm(Immediate::U8(TEST_PORT as u8)), Bytes::Four) + ); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.rax as u64, testctx.iodata); + + let mut testctx = TestCtx { + rax: 0xabcdef01, + rdx: TEST_PORT as usize, + ..Default::default() + }; + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xEF, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!( + decoded.insn().unwrap(), + DecodedInsn::Out(Operand::rdx(), Bytes::Four) + ); + assert_eq!(decoded.size(), 1); + assert_eq!(testctx.rax as u64, testctx.iodata); + } + + #[test] + fn test_decode_cpuid() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x0F, 0xA2, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let insn = Instruction::new(raw_insn); + let decoded = insn.decode(&TestCtx::default()).unwrap(); + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Cpuid); + assert_eq!(decoded.size(), 2); + } + + #[test] + fn test_decode_wrmsr() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x0F, 0x30, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let insn = Instruction::new(raw_insn); + let decoded = insn.decode(&TestCtx::default()).unwrap(); + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Wrmsr); + assert_eq!(decoded.size(), 2); + } + + #[test] + fn test_decode_rdmsr() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x0F, 0x32, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let insn = Instruction::new(raw_insn); + let decoded = insn.decode(&TestCtx::default()).unwrap(); + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Rdmsr); + assert_eq!(decoded.size(), 2); + } + + #[test] + fn test_decode_rdtsc() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x0F, 0x31, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let insn = Instruction::new(raw_insn); + let decoded = insn.decode(&TestCtx::default()).unwrap(); + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Rdtsc); + assert_eq!(decoded.size(), 2); + } + + #[test] + fn test_decode_rdtscp() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x0F, 0x01, 0xF9, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let insn = Instruction::new(raw_insn); + let decoded = insn.decode(&TestCtx::default()).unwrap(); + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Rdtscp); + assert_eq!(decoded.size(), 3); + } + + #[test] + fn test_decode_ins_u8() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xF3, 0x6C, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + let iodata: [u8; 4] = [0x12, 0x34, 0x56, 0x78]; + + let mut i = 0usize; + let mut testdata: [u8; 4] = [0; 4]; + let mut testctx = TestCtx { + rdx: TEST_PORT 
as usize, + rcx: testdata.len(), + rdi: testdata.as_ptr() as usize, + ..Default::default() + }; + loop { + testctx.iodata = *iodata.get(i).unwrap() as u64; + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + if decoded.size() == 0 { + i += 1; + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Ins); + assert_eq!(decoded.size(), 2); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize + testdata.len() * Bytes::One as usize, + testctx.rdi + ); + assert_eq!(i, testdata.len() - 1); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + + i = iodata.len() - 1; + testdata = [0; 4]; + testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rdi: &raw const testdata[testdata.len() - 1] as usize, + flags: RFlags::DF.bits(), + ..Default::default() + }; + loop { + testctx.iodata = *iodata.get(i).unwrap() as u64; + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + if decoded.size() == 0 { + i = i.checked_sub(1).unwrap(); + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Ins); + assert_eq!(decoded.size(), 2); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize - Bytes::One as usize, + testctx.rdi + ); + assert_eq!(i, 0); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + } + + #[test] + fn test_decode_ins_u16() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x66, 0xF3, 0x6D, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + let iodata: [u16; 4] = [0x1234, 0x5678, 0x9abc, 0xdef0]; + + let mut i = 0usize; + let mut testdata: [u16; 4] = [0; 4]; + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rdi: testdata.as_ptr() as usize, + ..Default::default() + }; + loop { + testctx.iodata = *iodata.get(i).unwrap() as u64; + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + if decoded.size() == 0 { + i += 1; + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Ins); + assert_eq!(decoded.size(), 3); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize + testdata.len() * Bytes::Two as usize, + testctx.rdi + ); + assert_eq!(i, testdata.len() - 1); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + + i = iodata.len() - 1; + testdata = [0; 4]; + testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rdi: &raw const testdata[testdata.len() - 1] as usize, + flags: RFlags::DF.bits(), + ..Default::default() + }; + loop { + testctx.iodata = *iodata.get(i).unwrap() as u64; + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + if decoded.size() == 0 { + i = i.checked_sub(1).unwrap(); + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Ins); + assert_eq!(decoded.size(), 3); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize - Bytes::Two as usize, + testctx.rdi + ); + assert_eq!(i, 0); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + } + + #[test] + fn test_decode_ins_u32() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xF3, 0x6D, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + let iodata: [u32; 4] 
= [0x12345678, 0x9abcdef0, 0x87654321, 0x0fedcba9]; + + let mut i = 0usize; + let mut testdata: [u32; 4] = [0; 4]; + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rdi: testdata.as_ptr() as usize, + ..Default::default() + }; + loop { + testctx.iodata = *iodata.get(i).unwrap() as u64; + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + if decoded.size() == 0 { + i += 1; + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Ins); + assert_eq!(decoded.size(), 2); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize + testdata.len() * Bytes::Four as usize, + testctx.rdi + ); + assert_eq!(i, testdata.len() - 1); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + + i = iodata.len() - 1; + testdata = [0; 4]; + testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rdi: &raw const testdata[testdata.len() - 1] as usize, + flags: RFlags::DF.bits(), + ..Default::default() + }; + loop { + testctx.iodata = *iodata.get(i).unwrap() as u64; + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + if decoded.size() == 0 { + i = i.checked_sub(1).unwrap(); + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Ins); + assert_eq!(decoded.size(), 2); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize - Bytes::Four as usize, + testctx.rdi + ); + assert_eq!(i, 0); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + } + + #[test] + fn test_decode_outs_u8() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xF3, 0x6E, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + let testdata: [u8; 4] = [0x12, 0x34, 0x56, 0x78]; + + let mut i = 0usize; + let mut iodata: [u8; 4] = [0; 4]; + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rsi: testdata.as_ptr() as usize, + ..Default::default() + }; + loop { + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + *iodata.get_mut(i).unwrap() = testctx.iodata as u8; + if decoded.size() == 0 { + i += 1; + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Outs); + assert_eq!(decoded.size(), 2); + assert_eq!(0, testctx.rcx); + assert_eq!(testdata.as_ptr() as usize + testdata.len(), testctx.rsi); + assert_eq!(i, testdata.len() - 1); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + + i = iodata.len() - 1; + iodata = [0; 4]; + testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rsi: &raw const testdata[testdata.len() - 1] as usize, + flags: RFlags::DF.bits(), + ..Default::default() + }; + loop { + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + *iodata.get_mut(i).unwrap() = testctx.iodata as u8; + if decoded.size() == 0 { + i = i.checked_sub(1).unwrap(); + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Outs); + assert_eq!(decoded.size(), 2); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize - Bytes::One as usize, + testctx.rsi + ); + assert_eq!(i, 0); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + } + + #[test] + fn test_decode_outs_u16() { + let raw_insn: [u8; 
MAX_INSN_SIZE] = [ + 0x66, 0xF3, 0x6F, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + let testdata: [u16; 4] = [0x1234, 0x5678, 0x9abc, 0xdef0]; + + let mut i = 0usize; + let mut iodata: [u16; 4] = [0; 4]; + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rsi: testdata.as_ptr() as usize, + ..Default::default() + }; + loop { + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + *iodata.get_mut(i).unwrap() = testctx.iodata as u16; + if decoded.size() == 0 { + i += 1; + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Outs); + assert_eq!(decoded.size(), 3); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize + testdata.len() * Bytes::Two as usize, + testctx.rsi + ); + assert_eq!(i, testdata.len() - 1); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + + i = iodata.len() - 1; + iodata = [0; 4]; + testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rsi: &raw const testdata[testdata.len() - 1] as usize, + flags: RFlags::DF.bits(), + ..Default::default() + }; + loop { + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + *iodata.get_mut(i).unwrap() = testctx.iodata as u16; + if decoded.size() == 0 { + i = i.checked_sub(1).unwrap(); + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Outs); + assert_eq!(decoded.size(), 3); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize - Bytes::Two as usize, + testctx.rsi + ); + assert_eq!(i, 0); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + } + + #[test] + fn test_decode_outs_u32() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xF3, 0x6F, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + let testdata: [u32; 4] = [0x12345678, 0x9abcdef0, 0xdeadbeef, 0xfeedface]; + + let mut i = 0usize; + let mut iodata: [u32; 4] = [0; 4]; + let mut testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rsi: testdata.as_ptr() as usize, + ..Default::default() + }; + loop { + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + *iodata.get_mut(i).unwrap() = testctx.iodata as u32; + if decoded.size() == 0 { + i += 1; + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Outs); + assert_eq!(decoded.size(), 2); + assert_eq!(*testdata.last().unwrap() as u64, testctx.iodata); + assert_eq!(0, testctx.rcx); + assert_eq!( + testdata.as_ptr() as usize + testdata.len() * Bytes::Four as usize, + testctx.rsi + ); + assert_eq!(i, testdata.len() - 1); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + + i = iodata.len() - 1; + iodata = [0; 4]; + testctx = TestCtx { + rdx: TEST_PORT as usize, + rcx: testdata.len(), + rsi: &raw const testdata[testdata.len() - 1] as usize, + flags: RFlags::DF.bits(), + ..Default::default() + }; + loop { + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + *iodata.get_mut(i).unwrap() = testctx.iodata as u32; + if decoded.size() == 0 { + i = i.checked_sub(1).unwrap(); + continue; + } + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Outs); + assert_eq!(decoded.size(), 2); + assert_eq!(0, testctx.rcx); + assert_eq!( + 
testdata.as_ptr() as usize - Bytes::Four as usize, + testctx.rsi + ); + assert_eq!(i, 0); + for (i, d) in testdata.iter().enumerate() { + assert_eq!(d, iodata.get(i).unwrap()); + } + break; + } + } + + #[test] + fn test_decode_mov_reg_to_rm() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x88, 0x07, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + rax: 0xab, + ..Default::default() + }; + testctx.rdi = &raw const testctx.mmio_reg as usize; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.mmio_reg, testctx.rax as u64); + + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x48, 0x89, 0x07, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + rax: 0x1234567890abcdef, + ..Default::default() + }; + testctx.rdi = &raw const testctx.mmio_reg as usize; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 3); + assert_eq!(testctx.mmio_reg, testctx.rax as u64); + } + + #[test] + fn test_decode_mov_rm_to_reg() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x8A, 0x07, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + mmio_reg: 0xab, + ..Default::default() + }; + testctx.rdi = &raw const testctx.mmio_reg as usize; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 2); + assert_eq!(testctx.mmio_reg, testctx.rax as u64); + + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x48, 0x8B, 0x07, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + mmio_reg: 0x1234567890abcdef, + ..Default::default() + }; + testctx.rdi = &raw const testctx.mmio_reg as usize; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 3); + assert_eq!(testctx.mmio_reg, testctx.rax as u64); + } + + #[test] + fn test_decode_mov_moffset_to_reg() { + let mut raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xA1, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + mmio_reg: 0x12345678, + ..Default::default() + }; + let addr = (&raw const testctx.mmio_reg as usize).to_le_bytes(); + raw_insn[1..9].copy_from_slice(&addr); + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 9); + assert_eq!(testctx.mmio_reg, testctx.rax as u64); + } + + #[test] + fn test_decode_mov_reg_to_moffset() { + let mut raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xA3, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + rax: 0x12345678, + ..Default::default() + }; + let addr = (&raw const testctx.mmio_reg as usize).to_le_bytes(); + raw_insn[1..9].copy_from_slice(&addr); + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + 
decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 9); + assert_eq!(testctx.mmio_reg, testctx.rax as u64); + } + + #[test] + fn test_decode_mov_imm_to_reg() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0xC6, 0x07, 0xab, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + ..Default::default() + }; + testctx.rdi = &raw const testctx.mmio_reg as usize; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 3); + assert_eq!(testctx.mmio_reg, 0xab); + + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x48, 0xC7, 0x07, 0x78, 0x56, 0x34, 0x12, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let mut testctx = TestCtx { + ..Default::default() + }; + testctx.rdi = &raw const testctx.mmio_reg as usize; + + let decoded = Instruction::new(raw_insn).decode(&testctx).unwrap(); + decoded.emulate(&mut testctx).unwrap(); + + assert_eq!(decoded.insn().unwrap(), DecodedInsn::Mov); + assert_eq!(decoded.size(), 7); + assert_eq!(testctx.mmio_reg, 0x12345678); + } + + #[test] + fn test_decode_failed() { + let raw_insn: [u8; MAX_INSN_SIZE] = [ + 0x66, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, + 0x41, + ]; + + let insn = Instruction::new(raw_insn); + let err = insn.decode(&TestCtx::default()); + + assert!(err.is_err()); + } +} diff --git a/stage2/src/insn_decode/mod.rs b/stage2/src/insn_decode/mod.rs new file mode 100644 index 000000000..96c0c5520 --- /dev/null +++ b/stage2/src/insn_decode/mod.rs @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2024 Intel Corporation. +// +// Author: Chuanxiao Dong + +mod decode; +mod insn; +mod opcode; + +pub use decode::{DecodedInsnCtx, InsnMachineCtx, InsnMachineMem}; +#[cfg(any(test, fuzzing))] +pub use insn::test_utils::TestCtx; +pub use insn::{ + DecodedInsn, Immediate, Instruction, Operand, Register, SegRegister, MAX_INSN_SIZE, +}; + +/// An error that can occur during instruction decoding. +#[derive(Copy, Clone, Debug)] +pub enum InsnError { + /// Error while decoding the displacement bytes. + DecodeDisp, + /// Error while decoding the immediate bytes. + DecodeImm, + /// Error while decoding the Mem-Offset bytes. + DecodeMOffset, + /// Error while decoding the ModR/M byte. + DecodeModRM, + /// Error while decoding the OpCode bytes. + DecodeOpCode, + /// Error while decoding the prefix bytes. + DecodePrefix, + /// Error while decoding the SIB byte. + DecodeSib, + /// Error due to alignment check exception. + ExceptionAC, + /// Error due to general protection exception. + ExceptionGP(u8), + /// Error due to page fault exception. + ExceptionPF(usize, u32), + /// Error due to stack segment exception. + ExceptionSS, + /// Error while mapping linear addresses. + MapLinearAddr, + /// Error while reading from memory. + MemRead, + /// Error while writing to memory. + MemWrite, + /// No OpCodeDesc generated while decoding. + NoOpCodeDesc, + /// Error while peeking an instruction byte. + InsnPeek, + /// The instruction decoding is not invalid. + InvalidDecode, + /// Invalid RegCode for decoding Register. + InvalidRegister, + /// Error while handling input IO operation. + IoIoIn, + /// Error while handling output IO operation. + IoIoOut, + /// The decoded instruction is not supported. 
+ UnSupportedInsn, + /// Error while translating linear address. + TranslateLinearAddr, + /// Error while handling MMIO read operation. + HandleMmioRead, + /// Error while handling MMIO write operation. + HandleMmioWrite, +} + +impl From for crate::error::SvsmError { + fn from(e: InsnError) -> Self { + Self::Insn(e) + } +} diff --git a/stage2/src/insn_decode/opcode.rs b/stage2/src/insn_decode/opcode.rs new file mode 100644 index 000000000..dee32c6d3 --- /dev/null +++ b/stage2/src/insn_decode/opcode.rs @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2024 Intel Corporation +// +// Author: Chuanxiao Dong + +use super::decode::OpCodeBytes; +use bitflags::bitflags; + +bitflags! { + /// Defines a set of flags for opcode attributes. These flags provide + /// information about the characteristics of an opcode, such as the + /// presence of an immediate operand, operand size, and special decoding + /// requirements. + #[derive(Clone, Copy, Debug, Default, PartialEq)] + pub struct OpCodeFlags: u64 { + // Immediate operand with decoded size + const IMM = 1 << 0; + // U8 immediate operand + const IMM8 = 1 << 1; + // No need to decode ModRm + const NO_MODRM = 1 << 2; + // Operand size is one byte + const BYTE_OP = 1 << 3; + // Operand size is two byte + const WORD_OP = 1 << 4; + // Doesn't have an operand + const OP_NONE = 1 << 5; + // Need to decode Moffset + const MOFFSET = 1 << 6; + } +} + +/// Represents the classification of opcodes into distinct categories. +/// Each variant of the enum corresponds to a specific type of opcode +/// or a group of opcodes that share common characteristics or decoding +/// behaviors. +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum OpCodeClass { + Cpuid, + Group7, + Group7Rm7, + In, + Ins, + Mov, + Out, + Outs, + Rdmsr, + Rdtsc, + Rdtscp, + TwoByte, + Wrmsr, +} + +/// Descriptor for an opcode, which contains the raw instruction opcode +/// value, its corresponding class and flags for fully decoding the +/// instruction. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct OpCodeDesc { + /// The opcode value + pub code: u8, + /// The type of the opcode + pub class: OpCodeClass, + /// The flags for fully decoding the instruction + pub flags: OpCodeFlags, +} + +macro_rules! 
opcode { + ($class:expr) => { + Some(OpCodeDesc { + code: 0, + class: $class, + flags: OpCodeFlags::empty(), + }) + }; + ($code:expr, $class:expr) => { + Some(OpCodeDesc { + code: $code, + class: $class, + flags: OpCodeFlags::empty(), + }) + }; + ($code:expr, $class:expr, $flags:expr) => { + Some(OpCodeDesc { + code: $code, + class: $class, + flags: OpCodeFlags::from_bits_truncate($flags), + }) + }; +} + +static ONE_BYTE_TABLE: [Option; 256] = { + let mut table: [Option; 256] = [None; 256]; + + table[0x0F] = opcode!(OpCodeClass::TwoByte); + table[0x6C] = opcode!( + 0x6C, + OpCodeClass::Ins, + OpCodeFlags::BYTE_OP.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0x6D] = opcode!(0x6D, OpCodeClass::Ins, OpCodeFlags::NO_MODRM.bits()); + table[0x6E] = opcode!( + 0x6E, + OpCodeClass::Outs, + OpCodeFlags::BYTE_OP.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0x6F] = opcode!(0x6F, OpCodeClass::Outs, OpCodeFlags::NO_MODRM.bits()); + table[0x88] = opcode!(0x88, OpCodeClass::Mov, OpCodeFlags::BYTE_OP.bits()); + table[0x8A] = opcode!(0x8A, OpCodeClass::Mov, OpCodeFlags::BYTE_OP.bits()); + table[0x89] = opcode!(0x89, OpCodeClass::Mov); + table[0x8B] = opcode!(0x8B, OpCodeClass::Mov); + table[0xA1] = opcode!( + 0xA1, + OpCodeClass::Mov, + OpCodeFlags::MOFFSET.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xA3] = opcode!( + 0xA3, + OpCodeClass::Mov, + OpCodeFlags::MOFFSET.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xC6] = opcode!( + 0xC6, + OpCodeClass::Mov, + OpCodeFlags::BYTE_OP.bits() | OpCodeFlags::IMM8.bits() + ); + table[0xC7] = opcode!(0xC7, OpCodeClass::Mov, OpCodeFlags::IMM.bits()); + table[0xE4] = opcode!( + 0xE4, + OpCodeClass::In, + OpCodeFlags::IMM8.bits() | OpCodeFlags::BYTE_OP.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xE5] = opcode!( + 0xE5, + OpCodeClass::In, + OpCodeFlags::IMM8.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xE6] = opcode!( + 0xE6, + OpCodeClass::Out, + OpCodeFlags::IMM8.bits() | OpCodeFlags::BYTE_OP.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xE7] = opcode!( + 0xE7, + OpCodeClass::Out, + OpCodeFlags::IMM8.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xEC] = opcode!( + 0xEC, + OpCodeClass::In, + OpCodeFlags::BYTE_OP.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xED] = opcode!(0xED, OpCodeClass::In, OpCodeFlags::NO_MODRM.bits()); + table[0xEE] = opcode!( + 0xEE, + OpCodeClass::Out, + OpCodeFlags::BYTE_OP.bits() | OpCodeFlags::NO_MODRM.bits() + ); + table[0xEF] = opcode!(0xEF, OpCodeClass::Out, OpCodeFlags::NO_MODRM.bits()); + + table +}; + +static GROUP7_RM7_TABLE: [Option; 8] = { + let mut table = [None; 8]; + + table[1] = opcode!(0xF9, OpCodeClass::Rdtscp, OpCodeFlags::OP_NONE.bits()); + + table +}; + +static GROUP7_TABLE: [Option; 16] = { + let mut table = [None; 16]; + + table[15] = opcode!(OpCodeClass::Group7Rm7); + + table +}; + +static TWO_BYTE_TABLE: [Option; 256] = { + let mut table = [None; 256]; + + table[0x01] = opcode!(OpCodeClass::Group7); + table[0x30] = opcode!(0x30, OpCodeClass::Wrmsr, OpCodeFlags::NO_MODRM.bits()); + table[0x31] = opcode!(0x31, OpCodeClass::Rdtsc, OpCodeFlags::NO_MODRM.bits()); + table[0x32] = opcode!(0x32, OpCodeClass::Rdmsr, OpCodeFlags::NO_MODRM.bits()); + table[0xA2] = opcode!(0xA2, OpCodeClass::Cpuid, OpCodeFlags::NO_MODRM.bits()); + + table +}; + +impl OpCodeDesc { + fn one_byte(insn: &mut OpCodeBytes) -> Option { + if let Ok(byte) = insn.0.peek() { + // Advance the OpCodeBytes as this is a opcode byte + insn.0.advance(); + ONE_BYTE_TABLE.get(byte as usize).cloned().flatten() + 
} else { + None + } + } + + fn two_byte(insn: &mut OpCodeBytes) -> Option { + if let Ok(byte) = insn.0.peek() { + // Advance the OpCodeBytes as this is a opcode byte + insn.0.advance(); + TWO_BYTE_TABLE.get(byte as usize).cloned().flatten() + } else { + None + } + } + + fn group7(insn: &OpCodeBytes) -> Option { + if let Ok(modrm) = insn.0.peek() { + // Not to advance the OpCodeBytes as this is not a opcode byte + let r#mod = modrm >> 6; + let offset = (modrm >> 3) & 0x7; + let idx = if r#mod == 3 { 8 + offset } else { offset }; + GROUP7_TABLE.get(idx as usize).cloned().flatten() + } else { + None + } + } + + fn group7_rm7(insn: &OpCodeBytes) -> Option { + if let Ok(modrm) = insn.0.peek() { + // Not to advance the OpCodeBytes as this is not a opcode byte + let idx = modrm & 0x7; + GROUP7_RM7_TABLE.get(idx as usize).cloned().flatten() + } else { + None + } + } + + /// Decodes an opcode from the given `OpCodeBytes`. + /// + /// # Arguments + /// + /// * `insn` - A mutable reference to the `OpCodeBytes` representing + /// the bytes of the opcode to be decoded. + /// + /// # Returns + /// + /// A Some(OpCodeDesc) if the opcode is supported or None otherwise + pub fn decode(insn: &mut OpCodeBytes) -> Option { + let mut opdesc = Self::one_byte(insn); + + loop { + if let Some(desc) = opdesc { + opdesc = match desc.class { + OpCodeClass::TwoByte => Self::two_byte(insn), + OpCodeClass::Group7 => Self::group7(insn), + OpCodeClass::Group7Rm7 => Self::group7_rm7(insn), + _ => return opdesc, + } + } else { + return None; + } + } + } +} diff --git a/stage2/src/io.rs b/stage2/src/io.rs new file mode 100644 index 000000000..39c8a920c --- /dev/null +++ b/stage2/src/io.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use core::arch::asm; +use core::fmt::Debug; + +pub trait IOPort: Sync + Debug { + fn outb(&self, port: u16, value: u8) { + unsafe { asm!("outb %al, %dx", in("al") value, in("dx") port, options(att_syntax)) } + } + + fn inb(&self, port: u16) -> u8 { + unsafe { + let ret: u8; + asm!("inb %dx, %al", in("dx") port, out("al") ret, options(att_syntax)); + ret + } + } + + fn outw(&self, port: u16, value: u16) { + unsafe { asm!("outw %ax, %dx", in("ax") value, in("dx") port, options(att_syntax)) } + } + + fn inw(&self, port: u16) -> u16 { + unsafe { + let ret: u16; + asm!("inw %dx, %ax", in("dx") port, out("ax") ret, options(att_syntax)); + ret + } + } + + fn outl(&self, port: u16, value: u32) { + unsafe { asm!("outl %eax, %dx", in("eax") value, in("dx") port, options(att_syntax)) } + } + + fn inl(&self, port: u16) -> u32 { + unsafe { + let ret: u32; + asm!("inl %dx, %eax", in("dx") port, out("eax") ret, options(att_syntax)); + ret + } + } +} + +#[derive(Default, Debug, Clone, Copy)] +pub struct DefaultIOPort {} + +impl IOPort for DefaultIOPort {} + +pub static DEFAULT_IO_DRIVER: DefaultIOPort = DefaultIOPort {}; diff --git a/stage2/src/kernel_region.rs b/stage2/src/kernel_region.rs new file mode 100644 index 000000000..39d7c1ae5 --- /dev/null +++ b/stage2/src/kernel_region.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange (jlange@microsoft.com) + +use crate::address::PhysAddr; +use crate::utils::MemoryRegion; +use bootlib::kernel_launch::KernelLaunchInfo; + +pub fn new_kernel_region(launch_info: &KernelLaunchInfo) -> MemoryRegion { + let start = PhysAddr::from(launch_info.kernel_region_phys_start); + let end = 
PhysAddr::from(launch_info.kernel_region_phys_end); + MemoryRegion::from_addresses(start, end) +} diff --git a/stage2/src/lib.rs b/stage2/src/lib.rs new file mode 100644 index 000000000..250b94e0b --- /dev/null +++ b/stage2/src/lib.rs @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Nicolai Stange + +#![no_std] +#![cfg_attr(all(test, test_in_svsm), no_main)] +#![cfg_attr(all(test, test_in_svsm), feature(custom_test_frameworks))] +#![cfg_attr(all(test, test_in_svsm), test_runner(crate::testing::svsm_test_runner))] +#![cfg_attr(all(test, test_in_svsm), reexport_test_harness_main = "test_main")] + +pub mod acpi; +pub mod address; +pub mod config; +pub mod console; +pub mod cpu; +pub mod crypto; +pub mod debug; +pub mod error; +pub mod fs; +pub mod fw_cfg; +pub mod fw_meta; +pub mod greq; +pub mod igvm_params; +pub mod insn_decode; +pub mod io; +pub mod kernel_region; +pub mod locking; +pub mod mm; +pub mod platform; +pub mod protocols; +pub mod requests; +pub mod serial; +pub mod sev; +pub mod string; +pub mod svsm_paging; +pub mod syscall; +pub mod task; +pub mod types; +pub mod utils; +#[cfg(all(feature = "mstpm", not(test)))] +pub mod vtpm; + +#[test] +fn test_nop() {} + +// When running tests inside the SVSM: +// Build the kernel entrypoint. +#[cfg(all(test, test_in_svsm))] +#[path = "svsm.rs"] +pub mod svsm_bin; +// The kernel expects to access this crate as svsm, so reexport. +#[cfg(all(test, test_in_svsm))] +extern crate self as svsm; +// Include a module containing the test runner. +#[cfg(all(test, test_in_svsm))] +pub mod testing; diff --git a/stage2/src/locking/common.rs b/stage2/src/locking/common.rs new file mode 100644 index 000000000..229ff3e92 --- /dev/null +++ b/stage2/src/locking/common.rs @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Joerg Roedel +use crate::cpu::IrqGuard; +use core::marker::PhantomData; + +/// Abstracts IRQ state handling when taking and releasing locks. There are two +/// implemenations: +/// +/// * [IrqUnsafeLocking] implements the methods as no-ops and does not change +/// any IRQ state. +/// * [IrqSafeLocking] actually disables and enables IRQs in the methods, +/// making a lock IRQ-safe by using this structure. +pub trait IrqLocking { + /// Associated helper function to disable IRQs and create an instance of + /// the implementing struct. This is used by lock implementations. + /// + /// # Returns + /// + /// New instance of implementing struct. + fn irqs_disable() -> Self; +} + +/// Implements the IRQ state handling methods as no-ops. For use it IRQ-unsafe +/// locks. +#[derive(Debug, Default)] +pub struct IrqUnsafeLocking; + +impl IrqLocking for IrqUnsafeLocking { + fn irqs_disable() -> Self { + Self {} + } +} + +/// Properly implements the IRQ state handling methods. For use it IRQ-safe +/// locks. +#[derive(Debug, Default)] +pub struct IrqSafeLocking { + /// IrqGuard to keep track of IRQ state. IrqGuard implements Drop, which + /// will re-enable IRQs when the struct goes out of scope. 
+ _guard: IrqGuard, + /// Make type explicitly !Send + !Sync + phantom: PhantomData<*const ()>, +} + +impl IrqLocking for IrqSafeLocking { + fn irqs_disable() -> Self { + Self { + _guard: IrqGuard::new(), + phantom: PhantomData, + } + } +} diff --git a/stage2/src/locking/mod.rs b/stage2/src/locking/mod.rs new file mode 100644 index 000000000..3c2fbcb71 --- /dev/null +++ b/stage2/src/locking/mod.rs @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +pub mod common; +pub mod rwlock; +pub mod spinlock; + +pub use common::{IrqLocking, IrqSafeLocking, IrqUnsafeLocking}; +pub use rwlock::{ + RWLock, RWLockIrqSafe, ReadLockGuard, ReadLockGuardIrqSafe, WriteLockGuard, + WriteLockGuardIrqSafe, +}; +pub use spinlock::{LockGuard, LockGuardIrqSafe, RawLockGuard, SpinLock, SpinLockIrqSafe}; diff --git a/stage2/src/locking/rwlock.rs b/stage2/src/locking/rwlock.rs new file mode 100644 index 000000000..224a6fc0f --- /dev/null +++ b/stage2/src/locking/rwlock.rs @@ -0,0 +1,398 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::common::*; +use core::cell::UnsafeCell; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use core::sync::atomic::{AtomicU64, Ordering}; + +/// A guard that provides read access to the data protected by [`RWLock`] +#[derive(Debug)] +#[must_use = "if unused the RWLock will immediately unlock"] +pub struct RawReadLockGuard<'a, T, I> { + /// Reference to the associated `AtomicU64` in the [`RWLock`] + rwlock: &'a AtomicU64, + /// Reference to the protected data + data: &'a T, + /// IRQ state before and after critical section + _irq_state: I, +} + +/// Implements the behavior of the [`ReadLockGuard`] when it is dropped +impl Drop for RawReadLockGuard<'_, T, I> { + /// Release the read lock + fn drop(&mut self) { + self.rwlock.fetch_sub(1, Ordering::Release); + } +} + +/// Implements the behavior of dereferencing the [`ReadLockGuard`] to +/// access the protected data. +impl Deref for RawReadLockGuard<'_, T, I> { + type Target = T; + /// Allow reading the protected data through deref + fn deref(&self) -> &T { + self.data + } +} + +pub type ReadLockGuard<'a, T> = RawReadLockGuard<'a, T, IrqUnsafeLocking>; +pub type ReadLockGuardIrqSafe<'a, T> = RawReadLockGuard<'a, T, IrqSafeLocking>; + +/// A guard that provides exclusive write access to the data protected by [`RWLock`] +#[derive(Debug)] +#[must_use = "if unused the RWLock will immediately unlock"] +pub struct RawWriteLockGuard<'a, T, I> { + /// Reference to the associated `AtomicU64` in the [`RWLock`] + rwlock: &'a AtomicU64, + /// Reference to the protected data (mutable) + data: &'a mut T, + /// IRQ state before and after critical section + _irq_state: I, +} + +/// Implements the behavior of the [`WriteLockGuard`] when it is dropped +impl Drop for RawWriteLockGuard<'_, T, I> { + fn drop(&mut self) { + // There are no readers - safe to just set lock to 0 + self.rwlock.store(0, Ordering::Release); + } +} + +/// Implements the behavior of dereferencing the [`WriteLockGuard`] to +/// access the protected data. +impl Deref for RawWriteLockGuard<'_, T, I> { + type Target = T; + fn deref(&self) -> &T { + self.data + } +} + +/// Implements the behavior of dereferencing the [`WriteLockGuard`] to +/// access the protected data in a mutable way. 
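The two locking policies above boil down to a marker-type pattern: a lock is generic over a policy trait, and the policy decides at compile time whether taking the lock also disables interrupts for the critical section. The sketch below distills that idea with a stand-in guard type (`StubIrqGuard` and the other names are illustrative only, since the real `IrqGuard` manipulates CPU state):

```rust
// Distilled policy pattern: `NoIrqHandling` is a no-op like IrqUnsafeLocking,
// `DisableIrqs` carries an RAII guard like IrqSafeLocking.
trait IrqPolicy {
    fn enter() -> Self;
}

struct NoIrqHandling; // zero-sized, does nothing
impl IrqPolicy for NoIrqHandling {
    fn enter() -> Self { Self }
}

struct StubIrqGuard; // stand-in for the real IrqGuard
impl StubIrqGuard {
    fn new() -> Self { println!("irqs disabled"); Self }
}
impl Drop for StubIrqGuard {
    fn drop(&mut self) { println!("irqs re-enabled"); }
}

struct DisableIrqs(StubIrqGuard);
impl IrqPolicy for DisableIrqs {
    fn enter() -> Self { Self(StubIrqGuard::new()) }
}

// A guard that simply carries the policy object; dropping it drops the
// policy, which (for DisableIrqs) re-enables interrupts via Drop.
struct CriticalSection<I: IrqPolicy>(I);

fn enter_critical<I: IrqPolicy>() -> CriticalSection<I> {
    CriticalSection(I::enter())
}

fn main() {
    let _plain = enter_critical::<NoIrqHandling>(); // nothing printed
    let _safe = enter_critical::<DisableIrqs>();    // prints disable, then re-enable on drop
}
```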
+impl DerefMut for RawWriteLockGuard<'_, T, I> { + fn deref_mut(&mut self) -> &mut T { + self.data + } +} + +pub type WriteLockGuard<'a, T> = RawWriteLockGuard<'a, T, IrqUnsafeLocking>; +pub type WriteLockGuardIrqSafe<'a, T> = RawWriteLockGuard<'a, T, IrqSafeLocking>; + +/// A simple Read-Write Lock (RWLock) that allows multiple readers or +/// one exclusive writer. +#[derive(Debug)] +pub struct RawRWLock { + /// An atomic 64-bit integer used for synchronization + rwlock: AtomicU64, + /// An UnsafeCell for interior mutability + data: UnsafeCell, + /// Silence unused type warning + phantom: PhantomData, +} + +/// Implements the trait `Sync` for the [`RWLock`], allowing safe +/// concurrent access across threads. +unsafe impl Send for RawRWLock {} +unsafe impl Sync for RawRWLock {} + +/// Splits a 64-bit value into two parts: readers (low 32 bits) and +/// writers (high 32 bits). +/// +/// # Parameters +/// +/// - `val`: A 64-bit unsigned integer value to be split. +/// +/// # Returns +/// +/// A tuple containing two 32-bit unsigned integer values. The first +/// element of the tuple is the lower 32 bits of input value, and the +/// second is the upper 32 bits. +#[inline] +fn split_val(val: u64) -> (u64, u64) { + (val & 0xffff_ffffu64, val >> 32) +} + +/// Composes a 64-bit value by combining the number of readers (low 32 +/// bits) and writers (high 32 bits). This function is used to create a +/// 64-bit synchronization value that represents the current state of the +/// RWLock, including the count of readers and writers. +/// +/// # Parameters +/// +/// - `readers`: The number of readers (low 32 bits) currently holding read locks. +/// - `writers`: The number of writers (high 32 bits) currently holding write locks. +/// +/// # Returns +/// +/// A 64-bit value representing the combined state of readers and writers in the RWLock. +#[inline] +fn compose_val(readers: u64, writers: u64) -> u64 { + (readers & 0xffff_ffffu64) | (writers << 32) +} + +/// A reader-writer lock that allows multiple readers or a single writer +/// to access the protected data. [`RWLock`] provides exclusive access for +/// writers and shared access for readers, for efficient synchronization. +impl RawRWLock { + /// Creates a new [`RWLock`] instance with the provided initial data. + /// + /// # Parameters + /// + /// - `data`: The initial data to be protected by the [`RWLock`]. + /// + /// # Returns + /// + /// A new [`RWLock`] instance with the specified initial data. + /// + /// # Example + /// + /// ```rust + /// use svsm::locking::RWLock; + /// + /// #[derive(Debug)] + /// struct MyData { + /// value: i32, + /// } + /// + /// let data = MyData { value: 42 }; + /// let rwlock = RWLock::new(data); + /// ``` + pub const fn new(data: T) -> Self { + Self { + rwlock: AtomicU64::new(0), + data: UnsafeCell::new(data), + phantom: PhantomData, + } + } + + /// This function is used to wait until all writers have finished their + /// operations and retrieve the current state of the [`RWLock`]. + /// + /// # Returns + /// + /// A 64-bit value representing the current state of the [`RWLock`], + /// including the count of readers and writers. + #[inline] + fn wait_for_writers(&self) -> u64 { + loop { + let val: u64 = self.rwlock.load(Ordering::Relaxed); + let (_, writers) = split_val(val); + + if writers == 0 { + return val; + } + core::hint::spin_loop(); + } + } + + /// This function is used to wait until all readers have finished their + /// operations and retrieve the current state of the [`RWLock`]. 
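The lock word layout that `split_val`/`compose_val` implement is easy to sanity-check in isolation: readers occupy the low 32 bits, the writer flag the high 32 bits. A minimal round-trip check, mirroring the functions above:

```rust
// Round-trip of the RWLock word layout: low 32 bits = reader count,
// high 32 bits = writer flag.
fn compose_val(readers: u64, writers: u64) -> u64 {
    (readers & 0xffff_ffff) | (writers << 32)
}

fn split_val(val: u64) -> (u64, u64) {
    (val & 0xffff_ffff, val >> 32)
}

fn main() {
    // Three concurrent readers, no writer.
    assert_eq!(split_val(compose_val(3, 0)), (3, 0));

    // Writer flag set while readers drain: lock_write() first publishes the
    // writer bit, then spins until the reader count drops to zero.
    assert_eq!(split_val(compose_val(2, 1)), (2, 1));
    assert_eq!(compose_val(0, 1), 1u64 << 32);
}
```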
+ /// + /// # Returns + /// + /// A 64-bit value representing the current state of the [`RWLock`], + /// including the count of readers and writers. + #[inline] + fn wait_for_readers(&self) -> u64 { + loop { + let val: u64 = self.rwlock.load(Ordering::Relaxed); + let (readers, _) = split_val(val); + + if readers == 0 { + return val; + } + core::hint::spin_loop(); + } + } + + /// This function allows multiple readers to access the data concurrently. + /// + /// # Returns + /// + /// A [`ReadLockGuard`] that provides read access to the protected data. + pub fn lock_read(&self) -> RawReadLockGuard<'_, T, I> { + let irq_state = I::irqs_disable(); + loop { + let val = self.wait_for_writers(); + let (readers, _) = split_val(val); + let new_val = compose_val(readers + 1, 0); + + if self + .rwlock + .compare_exchange(val, new_val, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + break; + } + core::hint::spin_loop(); + } + + RawReadLockGuard { + rwlock: &self.rwlock, + data: unsafe { &*self.data.get() }, + _irq_state: irq_state, + } + } + + /// This function ensures exclusive access for a single writer and waits + /// for all readers to finish before granting access to the writer. + /// + /// # Returns + /// + /// A [`WriteLockGuard`] that provides write access to the protected data. + pub fn lock_write(&self) -> RawWriteLockGuard<'_, T, I> { + let irq_state = I::irqs_disable(); + + // Waiting for current writer to finish + loop { + let val = self.wait_for_writers(); + let (readers, _) = split_val(val); + let new_val = compose_val(readers, 1); + + if self + .rwlock + .compare_exchange(val, new_val, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + break; + } + core::hint::spin_loop(); + } + + // Now locked for write - wait until all readers finished + let val: u64 = self.wait_for_readers(); + assert!(val == compose_val(0, 1)); + + RawWriteLockGuard { + rwlock: &self.rwlock, + data: unsafe { &mut *self.data.get() }, + _irq_state: irq_state, + } + } +} + +pub type RWLock = RawRWLock; +pub type RWLockIrqSafe = RawRWLock; + +mod tests { + #[test] + fn test_lock_rw() { + use crate::locking::*; + + let rwlock = RWLock::new(42); + + // Acquire a read lock and check the initial value + let read_guard = rwlock.lock_read(); + assert_eq!(*read_guard, 42); + + drop(read_guard); + + let read_guard2 = rwlock.lock_read(); + assert_eq!(*read_guard2, 42); + + // Create another RWLock instance for modification + let rwlock_modify = RWLock::new(0); + + let mut write_guard = rwlock_modify.lock_write(); + *write_guard = 99; + assert_eq!(*write_guard, 99); + + drop(write_guard); + + let read_guard = rwlock.lock_read(); + assert_eq!(*read_guard, 42); + } + + #[test] + fn test_concurrent_readers() { + use crate::locking::*; + + // Let's test two concurrent readers on a new RWLock instance + let rwlock_concurrent = RWLock::new(123); + + let read_guard1 = rwlock_concurrent.lock_read(); + let read_guard2 = rwlock_concurrent.lock_read(); + + // Assert that both readers can access the same value (123) + assert_eq!(*read_guard1, 123); + assert_eq!(*read_guard2, 123); + + drop(read_guard1); + drop(read_guard2); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn rw_lock_irq_unsafe() { + use crate::cpu::irq_state::{raw_irqs_disable, raw_irqs_enable}; + use crate::cpu::irqs_enabled; + use crate::locking::*; + + unsafe { + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + let lock = RWLock::new(0); + + // Lock for write + let guard = lock.lock_write(); + // IRQs must 
still be enabled; + assert!(irqs_enabled()); + // Unlock + drop(guard); + + // Lock for read + let guard = lock.lock_read(); + // IRQs must still be enabled; + assert!(irqs_enabled()); + // Unlock + drop(guard); + + // IRQs must still be enabled + assert!(irqs_enabled()); + if !was_enabled { + raw_irqs_disable(); + } + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn rw_lock_irq_safe() { + use crate::cpu::irq_state::{raw_irqs_disable, raw_irqs_enable}; + use crate::cpu::{irqs_disabled, irqs_enabled}; + use crate::locking::*; + + unsafe { + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + let lock = RWLockIrqSafe::new(0); + + // Lock for write + let guard = lock.lock_write(); + // IRQs must be disabled + assert!(irqs_disabled()); + // Unlock + drop(guard); + + assert!(irqs_enabled()); + + // Lock for read + let guard = lock.lock_read(); + // IRQs must still be enabled; + assert!(irqs_disabled()); + // Unlock + drop(guard); + + // IRQs must still be enabled + assert!(irqs_enabled()); + if !was_enabled { + raw_irqs_disable(); + } + } + } +} diff --git a/stage2/src/locking/spinlock.rs b/stage2/src/locking/spinlock.rs new file mode 100644 index 000000000..1a91338ee --- /dev/null +++ b/stage2/src/locking/spinlock.rs @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::common::*; +use core::cell::UnsafeCell; +use core::marker::PhantomData; +use core::ops::{Deref, DerefMut}; +use core::sync::atomic::{AtomicU64, Ordering}; + +/// A lock guard obtained from a [`SpinLock`]. This lock guard +/// provides exclusive access to the data protected by a [`SpinLock`], +/// ensuring that the lock is released when it goes out of scope. +/// +/// # Examples +/// +/// ``` +/// use svsm::locking::SpinLock; +/// +/// let data = 42; +/// let spin_lock = SpinLock::new(data); +/// +/// { +/// let mut guard = spin_lock.lock(); +/// *guard += 1; // Modify the protected data. +/// }; // Lock is automatically released when `guard` goes out of scope. +/// ``` +#[derive(Debug)] +#[must_use = "if unused the SpinLock will immediately unlock"] +pub struct RawLockGuard<'a, T, I = IrqUnsafeLocking> { + holder: &'a AtomicU64, + data: &'a mut T, + #[expect(dead_code)] + irq_state: I, +} + +/// Implements the behavior of the [`LockGuard`] when it is dropped +impl Drop for RawLockGuard<'_, T, I> { + /// Automatically releases the lock when the guard is dropped + fn drop(&mut self) { + self.holder.fetch_add(1, Ordering::Release); + } +} + +/// Implements the behavior of dereferencing the [`LockGuard`] to +/// access the protected data. +impl Deref for RawLockGuard<'_, T, I> { + type Target = T; + /// Provides read-only access to the protected data + fn deref(&self) -> &T { + self.data + } +} + +/// Implements the behavior of dereferencing the [`LockGuard`] to +/// access the protected data in a mutable way. +impl DerefMut for RawLockGuard<'_, T, I> { + /// Provides mutable access to the protected data + fn deref_mut(&mut self) -> &mut T { + self.data + } +} + +pub type LockGuard<'a, T> = RawLockGuard<'a, T, IrqUnsafeLocking>; +pub type LockGuardIrqSafe<'a, T> = RawLockGuard<'a, T, IrqSafeLocking>; + +/// A simple ticket-spinlock implementation for protecting concurrent data +/// access. +/// +/// Two variants are derived from this implementation: +/// +/// * [`SpinLock`] for general use. This implementation is not safe for use in +/// IRQ handlers. 
+/// * [`SpinLockIrqSafe`] for protecting data that is accessed in IRQ context. +/// +/// # Examples +/// +/// ``` +/// use svsm::locking::SpinLock; +/// +/// let data = 42; +/// let spin_lock = SpinLock::new(data); +/// +/// // Acquire the lock and modify the protected data. +/// { +/// let mut guard = spin_lock.lock(); +/// *guard += 1; +/// }; // Lock is automatically released when `guard` goes out of scope. +/// +/// // Try to acquire the lock without blocking +/// if let Some(mut guard) = spin_lock.try_lock() { +/// *guard += 2; +/// }; +/// ``` +#[derive(Debug, Default)] +pub struct RawSpinLock { + /// This atomic counter is incremented each time a thread attempts to + /// acquire the lock. It helps to determine the order in which threads + /// acquire the lock. + current: AtomicU64, + /// This counter represents the thread that currently holds the lock + /// and has access to the protected data. + holder: AtomicU64, + /// This `UnsafeCell` is used to provide interior mutability of the + /// protected data. That is, it allows the data to be accessed/modified + /// while enforcing the locking mechanism. + data: UnsafeCell, + /// Use generic type I in the struct without consuming space. + phantom: PhantomData, +} + +unsafe impl Send for RawSpinLock {} +unsafe impl Sync for RawSpinLock {} + +impl RawSpinLock { + /// Creates a new SpinLock instance with the specified initial data. + /// + /// # Examples + /// + /// ``` + /// use svsm::locking::SpinLock; + /// + /// let data = 42; + /// let spin_lock = SpinLock::new(data); + /// ``` + pub const fn new(data: T) -> Self { + Self { + current: AtomicU64::new(0), + holder: AtomicU64::new(0), + data: UnsafeCell::new(data), + phantom: PhantomData, + } + } + + /// Acquires the lock, providing access to the protected data. + /// + /// # Examples + /// + /// ``` + /// use svsm::locking::SpinLock; + /// + /// let spin_lock = SpinLock::new(42); + /// + /// // Acquire the lock and modify the protected data. + /// { + /// let mut guard = spin_lock.lock(); + /// *guard += 1; + /// }; // Lock is automatically released when `guard` goes out of scope. + /// ``` + pub fn lock(&self) -> RawLockGuard<'_, T, I> { + let irq_state = I::irqs_disable(); + + let ticket = self.current.fetch_add(1, Ordering::Relaxed); + loop { + let h = self.holder.load(Ordering::Acquire); + if h == ticket { + break; + } + core::hint::spin_loop(); + } + RawLockGuard { + holder: &self.holder, + data: unsafe { &mut *self.data.get() }, + irq_state, + } + } + + /// This method tries to acquire the lock without blocking. If the + /// lock is not available, it returns `None`. If the lock is + /// successfully acquired, it returns a [`LockGuard`] that automatically + /// releases the lock when it goes out of scope. 
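The fairness of the ticket scheme is easiest to see without atomics. The following single-threaded model (plain integers, purely conceptual) mirrors the `current`/`holder` counters used above: `current` hands out tickets, `holder` is the ticket being served, and unlocking bumps `holder` so the next ticket in line proceeds:

```rust
// Conceptual model of the ticket spinlock counters (no atomics, no spinning).
struct TicketState {
    current: u64, // next ticket to hand out
    holder: u64,  // ticket that currently owns the lock
}

impl TicketState {
    fn take_ticket(&mut self) -> u64 {
        let t = self.current;
        self.current += 1;
        t
    }
    fn may_enter(&self, ticket: u64) -> bool {
        self.holder == ticket
    }
    fn unlock(&mut self) {
        self.holder += 1;
    }
}

fn main() {
    let mut s = TicketState { current: 0, holder: 0 };
    let a = s.take_ticket(); // ticket 0
    let b = s.take_ticket(); // ticket 1
    assert!(s.may_enter(a) && !s.may_enter(b)); // FIFO: a goes first
    s.unlock();
    assert!(s.may_enter(b)); // now it is b's turn
}
```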
+ pub fn try_lock(&self) -> Option> { + let irq_state = I::irqs_disable(); + + let current = self.current.load(Ordering::Relaxed); + let holder = self.holder.load(Ordering::Acquire); + + if current == holder { + let result = self.current.compare_exchange( + current, + current + 1, + Ordering::Acquire, + Ordering::Relaxed, + ); + if result.is_ok() { + return Some(RawLockGuard { + holder: &self.holder, + data: unsafe { &mut *self.data.get() }, + irq_state, + }); + } + } + + None + } +} + +pub type SpinLock = RawSpinLock; +pub type SpinLockIrqSafe = RawSpinLock; + +#[cfg(test)] +mod tests { + use super::*; + use crate::cpu::irq_state::{raw_irqs_disable, raw_irqs_enable}; + use crate::cpu::{irqs_disabled, irqs_enabled}; + + #[test] + fn test_spin_lock() { + let spin_lock = SpinLock::new(0); + + let mut guard = spin_lock.lock(); + *guard += 1; + + // Ensure the locked data is updated. + assert_eq!(*guard, 1); + + // Try to lock again; it should fail and return None. + let try_lock_result = spin_lock.try_lock(); + assert!(try_lock_result.is_none()); + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn spin_lock_irq_unsafe() { + unsafe { + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + + let spin_lock = SpinLock::new(0); + let guard = spin_lock.lock(); + assert!(irqs_enabled()); + drop(guard); + assert!(irqs_enabled()); + + if !was_enabled { + raw_irqs_disable(); + } + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn spin_lock_irq_safe() { + unsafe { + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + + let spin_lock = SpinLockIrqSafe::new(0); + let guard = spin_lock.lock(); + assert!(irqs_disabled()); + drop(guard); + assert!(irqs_enabled()); + + if !was_enabled { + raw_irqs_disable(); + } + } + } + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn spin_trylock_irq_safe() { + unsafe { + let was_enabled = irqs_enabled(); + raw_irqs_enable(); + + let spin_lock = SpinLockIrqSafe::new(0); + + // IRQs are enabled - taking the lock must succeed and disable IRQs + let g1 = spin_lock.try_lock(); + assert!(g1.is_some()); + assert!(irqs_disabled()); + + // Release lock and check if that enables IRQs + drop(g1); + assert!(irqs_enabled()); + + // Leave with IRQs configured as test was entered. 
+ if !was_enabled { + raw_irqs_disable(); + } + } + } +} diff --git a/stage2/src/mm/address_space.rs b/stage2/src/mm/address_space.rs new file mode 100644 index 000000000..6a18b5751 --- /dev/null +++ b/stage2/src/mm/address_space.rs @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{PhysAddr, VirtAddr}; +use crate::mm::pagetable::{PageFrame, PageTable}; +use crate::utils::immut_after_init::ImmutAfterInitCell; + +#[derive(Debug, Copy, Clone)] +#[cfg_attr(not(any(test, target_os = "none")), expect(dead_code))] +pub struct FixedAddressMappingRange { + virt_start: VirtAddr, + virt_end: VirtAddr, + phys_start: PhysAddr, +} + +impl FixedAddressMappingRange { + pub fn new(virt_start: VirtAddr, virt_end: VirtAddr, phys_start: PhysAddr) -> Self { + Self { + virt_start, + virt_end, + phys_start, + } + } + + #[cfg(target_os = "none")] + fn phys_to_virt(&self, paddr: PhysAddr) -> Option { + if paddr < self.phys_start { + None + } else { + let size: usize = self.virt_end - self.virt_start; + if paddr >= self.phys_start + size { + None + } else { + let offset: usize = paddr - self.phys_start; + Some(self.virt_start + offset) + } + } + } +} + +#[derive(Debug, Copy, Clone)] +#[cfg_attr(not(target_os = "none"), expect(dead_code))] +pub struct FixedAddressMapping { + kernel_mapping: FixedAddressMappingRange, + heap_mapping: Option, +} + +static FIXED_MAPPING: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +pub fn init_kernel_mapping_info( + kernel_mapping: FixedAddressMappingRange, + heap_mapping: Option, +) { + let mapping = FixedAddressMapping { + kernel_mapping, + heap_mapping, + }; + FIXED_MAPPING + .init(&mapping) + .expect("Already initialized fixed mapping info"); +} + +#[cfg(target_os = "none")] +pub fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr { + match PageTable::virt_to_frame(vaddr) { + Some(paddr) => paddr.address(), + None => { + panic!("Invalid virtual address {:#018x}", vaddr); + } + } +} + +pub fn virt_to_frame(vaddr: VirtAddr) -> PageFrame { + match PageTable::virt_to_frame(vaddr) { + Some(paddr) => paddr, + None => { + panic!("Invalid virtual address {:#018x}", vaddr); + } + } +} + +#[cfg(target_os = "none")] +pub fn phys_to_virt(paddr: PhysAddr) -> VirtAddr { + if let Some(addr) = FIXED_MAPPING.kernel_mapping.phys_to_virt(paddr) { + return addr; + } + if let Some(ref mapping) = FIXED_MAPPING.heap_mapping { + if let Some(addr) = mapping.phys_to_virt(paddr) { + return addr; + } + } + + panic!("Invalid physical address {:#018x}", paddr); +} + +#[cfg(not(target_os = "none"))] +pub fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr { + use crate::address::Address; + PhysAddr::from(vaddr.bits()) +} + +#[cfg(not(target_os = "none"))] +pub fn phys_to_virt(paddr: PhysAddr) -> VirtAddr { + use crate::address::Address; + VirtAddr::from(paddr.bits()) +} + +// Address space definitions for SVSM virtual memory layout + +/// Size helpers +pub const SIZE_1K: usize = 1024; +pub const SIZE_1M: usize = SIZE_1K * 1024; +pub const SIZE_1G: usize = SIZE_1M * 1024; + +/// Pagesize definitions +pub const PAGE_SIZE: usize = SIZE_1K * 4; +pub const PAGE_SIZE_2M: usize = SIZE_1M * 2; + +/// More size helpers +pub const SIZE_LEVEL3: usize = 1usize << ((9 * 3) + 12); +pub const SIZE_LEVEL2: usize = 1usize << ((9 * 2) + 12); +#[expect(clippy::identity_op)] +pub const SIZE_LEVEL1: usize = 1usize << ((9 * 1) + 12); +#[expect(clippy::erasing_op, clippy::identity_op)] +pub const SIZE_LEVEL0: usize = 1usize << ((9 * 
0) + 12); + +// Stack definitions +pub const STACK_PAGES: usize = 8; +pub const STACK_SIZE: usize = PAGE_SIZE * STACK_PAGES; +pub const STACK_GUARD_SIZE: usize = STACK_SIZE; +pub const STACK_TOTAL_SIZE: usize = STACK_SIZE + STACK_GUARD_SIZE; + +const fn virt_from_idx(idx: usize) -> VirtAddr { + VirtAddr::new(idx << ((3 * 9) + 12)) +} + +/// Level3 page-table index shared between all CPUs +pub const PGTABLE_LVL3_IDX_SHARED: usize = 511; + +/// Base Address of shared memory region +pub const SVSM_SHARED_BASE: VirtAddr = virt_from_idx(PGTABLE_LVL3_IDX_SHARED); + +/// Mapping range for shared stacks +pub const SVSM_SHARED_STACK_BASE: VirtAddr = SVSM_SHARED_BASE.const_add(256 * SIZE_1G); +pub const SVSM_SHARED_STACK_END: VirtAddr = SVSM_SHARED_STACK_BASE.const_add(SIZE_1G); + +/// PerCPU mappings level 3 index +pub const PGTABLE_LVL3_IDX_PERCPU: usize = 510; + +/// Base Address of shared memory region +pub const SVSM_PERCPU_BASE: VirtAddr = virt_from_idx(PGTABLE_LVL3_IDX_PERCPU); + +/// End Address of per-cpu memory region +pub const SVSM_PERCPU_END: VirtAddr = SVSM_PERCPU_BASE.const_add(SIZE_LEVEL3); + +/// PerCPU CAA mappings +pub const SVSM_PERCPU_CAA_BASE: VirtAddr = SVSM_PERCPU_BASE.const_add(2 * SIZE_LEVEL0); + +/// PerCPU VMSA mappings +pub const SVSM_PERCPU_VMSA_BASE: VirtAddr = SVSM_PERCPU_BASE.const_add(4 * SIZE_LEVEL0); + +/// Region for PerCPU Stacks +pub const SVSM_PERCPU_STACKS_BASE: VirtAddr = SVSM_PERCPU_BASE.const_add(SIZE_LEVEL1); + +/// Stack address of the per-cpu init task +pub const SVSM_STACKS_INIT_TASK: VirtAddr = SVSM_PERCPU_STACKS_BASE; + +/// IST Stacks base address +pub const SVSM_STACKS_IST_BASE: VirtAddr = SVSM_STACKS_INIT_TASK.const_add(STACK_TOTAL_SIZE); + +/// DoubleFault IST stack base address +pub const SVSM_STACK_IST_DF_BASE: VirtAddr = SVSM_STACKS_IST_BASE; + +/// PerCPU XSave Context area base address +pub const SVSM_XSAVE_AREA_BASE: VirtAddr = SVSM_STACKS_IST_BASE.const_add(STACK_TOTAL_SIZE); + +/// Base Address for temporary mappings - used by page-table guards +pub const SVSM_PERCPU_TEMP_BASE: VirtAddr = SVSM_PERCPU_BASE.const_add(SIZE_LEVEL2); + +// Below is space for 512 temporary 4k mappings and 511 temporary 2M mappings + +/// Start and End for PAGE_SIZEed temporary mappings +pub const SVSM_PERCPU_TEMP_BASE_4K: VirtAddr = SVSM_PERCPU_TEMP_BASE; +pub const SVSM_PERCPU_TEMP_END_4K: VirtAddr = SVSM_PERCPU_TEMP_BASE_4K.const_add(SIZE_LEVEL1); + +/// Start and End for PAGE_SIZEed temporary mappings +pub const SVSM_PERCPU_TEMP_BASE_2M: VirtAddr = SVSM_PERCPU_TEMP_BASE.const_add(SIZE_LEVEL1); +pub const SVSM_PERCPU_TEMP_END_2M: VirtAddr = SVSM_PERCPU_TEMP_BASE.const_add(SIZE_LEVEL2); + +/// Task mappings level 3 index +pub const PGTABLE_LVL3_IDX_PERTASK: usize = 508; + +/// Base address of task memory region +pub const SVSM_PERTASK_BASE: VirtAddr = virt_from_idx(PGTABLE_LVL3_IDX_PERTASK); + +/// End address of task memory region +pub const SVSM_PERTASK_END: VirtAddr = SVSM_PERTASK_BASE.const_add(SIZE_LEVEL3); + +/// Kernel stack for a task +pub const SVSM_PERTASK_STACK_BASE: VirtAddr = SVSM_PERTASK_BASE; + +/// SSE context save area for a task +pub const SVSM_PERTASK_XSAVE_AREA_BASE: VirtAddr = + SVSM_PERTASK_STACK_BASE.const_add(STACK_TOTAL_SIZE); + +/// Page table self-map level 3 index +pub const PGTABLE_LVL3_IDX_PTE_SELFMAP: usize = 493; + +pub const SVSM_PTE_BASE: VirtAddr = virt_from_idx(PGTABLE_LVL3_IDX_PTE_SELFMAP); + +// +// User-space mapping constants +// + +/// Start of user memory address range +pub const USER_MEM_START: VirtAddr = 
VirtAddr::new(0); + +/// End of user memory address range +pub const USER_MEM_END: VirtAddr = USER_MEM_START.const_add(256 * SIZE_LEVEL3); + +#[cfg(test)] +mod tests { + use super::*; + use crate::locking::SpinLock; + + static KERNEL_MAPPING_TEST: ImmutAfterInitCell = + ImmutAfterInitCell::uninit(); + static INITIALIZED: SpinLock = SpinLock::new(false); + + #[test] + #[cfg_attr(test_in_svsm, ignore = "Offline testing")] + fn init_km_testing() { + let mut initialized = INITIALIZED.lock(); + if *initialized { + return; + } + let kernel_mapping = FixedAddressMappingRange::new( + VirtAddr::new(0x1000), + VirtAddr::new(0x2000), + PhysAddr::new(0x3000), + ); + let mapping = FixedAddressMapping { + kernel_mapping, + heap_mapping: None, + }; + KERNEL_MAPPING_TEST.init(&mapping).unwrap(); + *initialized = true; + } + + #[test] + #[cfg_attr(test_in_svsm, ignore = "Offline testing")] + fn test_init_kernel_mapping_info() { + init_km_testing(); + + let km = &KERNEL_MAPPING_TEST; + + assert_eq!(km.kernel_mapping.virt_start, VirtAddr::new(0x1000)); + assert_eq!(km.kernel_mapping.virt_end, VirtAddr::new(0x2000)); + assert_eq!(km.kernel_mapping.phys_start, PhysAddr::new(0x3000)); + } + + #[test] + #[cfg(target_os = "none")] + #[cfg_attr(test_in_svsm, ignore = "Offline testing")] + fn test_virt_to_phys() { + let vaddr = VirtAddr::new(0x1500); + let paddr = virt_to_phys(vaddr); + + assert_eq!(paddr, PhysAddr::new(0x4500)); + } + + #[test] + #[cfg(not(target_os = "none"))] + #[cfg_attr(test_in_svsm, ignore = "Offline testing")] + fn test_virt_to_phys() { + let vaddr = VirtAddr::new(0x1500); + let paddr = virt_to_phys(vaddr); + + assert_eq!(paddr, PhysAddr::new(0x1500)); + } + + #[test] + #[cfg(target_os = "none")] + #[cfg_attr(test_in_svsm, ignore = "Offline testing")] + fn test_phys_to_virt() { + let paddr = PhysAddr::new(0x4500); + let vaddr = phys_to_virt(paddr); + + assert_eq!(vaddr, VirtAddr::new(0x1500)); + } + + #[test] + #[cfg(not(target_os = "none"))] + #[cfg_attr(test_in_svsm, ignore = "Offline testing")] + fn test_phys_to_virt() { + let paddr = PhysAddr::new(0x4500); + let vaddr = phys_to_virt(paddr); + + assert_eq!(vaddr, VirtAddr::new(0x4500)); + } +} diff --git a/stage2/src/mm/alloc.rs b/stage2/src/mm/alloc.rs new file mode 100644 index 000000000..2f4648b2e --- /dev/null +++ b/stage2/src/mm/alloc.rs @@ -0,0 +1,2099 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::cpu::mem::{copy_bytes, write_bytes}; +use crate::error::SvsmError; +use crate::locking::SpinLock; +use crate::mm::virt_to_phys; +use crate::types::{PAGE_SHIFT, PAGE_SIZE}; +use crate::utils::{align_down, align_up, zero_mem_region}; +use core::alloc::{GlobalAlloc, Layout}; +use core::mem::size_of; +use core::ptr; + +#[cfg(any(test, fuzzing))] +use crate::locking::LockGuard; + +/// Represents possible errors that can occur during memory allocation. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AllocError { + /// The provided page type is invalid. + InvalidPageType, + /// The heap address is invalid. + InvalidHeapAddress(VirtAddr), + /// Out of memory error. + OutOfMemory, + /// The specified page order is invalid. + InvalidPageOrder(usize), + /// The file page has an invalid virtual address. + InvalidFilePage(VirtAddr), + /// The page frame number (PFN) is invalid. 
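The address-space unit tests above (virtual range 0x1000..0x2000 backed by physical 0x3000, so 0x1500 maps to 0x4500) exercise a fixed-offset translation. A standalone sketch of that arithmetic with plain `usize` addresses follows; the symmetric `virt_to_phys` here is only for illustration, since the kernel-side `virt_to_phys` in the crate actually walks the page table:

```rust
// Fixed-offset translation: a virtual range is backed by a contiguous
// physical range, so both directions are an offset from the range start.
struct FixedRange {
    virt_start: usize,
    virt_end: usize,
    phys_start: usize,
}

impl FixedRange {
    fn phys_to_virt(&self, paddr: usize) -> Option<usize> {
        let size = self.virt_end - self.virt_start;
        (self.phys_start..self.phys_start + size)
            .contains(&paddr)
            .then(|| self.virt_start + (paddr - self.phys_start))
    }
    fn virt_to_phys(&self, vaddr: usize) -> Option<usize> {
        (self.virt_start..self.virt_end)
            .contains(&vaddr)
            .then(|| self.phys_start + (vaddr - self.virt_start))
    }
}

fn main() {
    // Same numbers as the unit tests: [0x1000, 0x2000) virtual maps to 0x3000+.
    let r = FixedRange { virt_start: 0x1000, virt_end: 0x2000, phys_start: 0x3000 };
    assert_eq!(r.virt_to_phys(0x1500), Some(0x4500));
    assert_eq!(r.phys_to_virt(0x4500), Some(0x1500));
    assert_eq!(r.phys_to_virt(0x5000), None); // outside the backing range
}
```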
+ InvalidPfn(usize), +} + +impl From for SvsmError { + fn from(err: AllocError) -> Self { + Self::Alloc(err) + } +} + +/// Maximum order of page allocations (up to 128kb) +pub const MAX_ORDER: usize = 6; + +/// Calculates the order of a given size for page allocation. +/// +/// # Arguments +/// +/// * `size` - The size for which to calculate the order. +/// +/// # Returns +/// +/// The calculated order. +pub const fn get_order(size: usize) -> usize { + match size.checked_next_power_of_two() { + Some(v) => v.ilog2() as usize, + None => usize::BITS as usize, + } + .saturating_sub(PAGE_SHIFT) +} + +/// Enum representing the type of a memory page. +#[derive(Clone, Copy, Debug)] +#[repr(u64)] +enum PageType { + Free = 0, + Allocated = 1, + SlabPage = 2, + Compound = 3, + // File pages used for file and task data + File = 4, + Reserved = (1u64 << PageStorageType::TYPE_SHIFT) - 1, +} + +impl TryFrom for PageType { + type Error = AllocError; + fn try_from(val: u64) -> Result { + match val { + v if v == Self::Free as u64 => Ok(Self::Free), + v if v == Self::Allocated as u64 => Ok(Self::Allocated), + v if v == Self::SlabPage as u64 => Ok(Self::SlabPage), + v if v == Self::Compound as u64 => Ok(Self::Compound), + v if v == Self::File as u64 => Ok(Self::File), + v if v == Self::Reserved as u64 => Ok(Self::Reserved), + _ => Err(AllocError::InvalidPageType), + } + } +} + +/// Storage type of a memory page, including encoding and decoding methods +#[derive(Clone, Copy, Debug)] +#[repr(transparent)] +struct PageStorageType(u64); + +impl PageStorageType { + const TYPE_SHIFT: u64 = 4; + const TYPE_MASK: u64 = (1u64 << Self::TYPE_SHIFT) - 1; + const NEXT_SHIFT: u64 = 12; + const NEXT_MASK: u64 = !((1u64 << Self::NEXT_SHIFT) - 1); + const ORDER_MASK: u64 = (1u64 << (Self::NEXT_SHIFT - Self::TYPE_SHIFT)) - 1; + // Slab item sizes are encoded in a u16 + const SLAB_MASK: u64 = 0xffff; + + /// Creates a new [`PageStorageType`] with the specified page type. + /// + /// # Arguments + /// + /// * `t` - The page type. + /// + /// # Returns + /// + /// A new instance of [`PageStorageType`]. + const fn new(t: PageType) -> Self { + Self(t as u64) + } + + /// Encodes the order of the page. + /// + /// # Arguments + /// + /// * `order` - The order to encode. + /// + /// # Returns + /// + /// The updated [`PageStorageType`]. + fn encode_order(self, order: usize) -> Self { + Self(self.0 | ((order as u64) & Self::ORDER_MASK) << Self::TYPE_SHIFT) + } + + /// Encodes the index of the next page. + /// + /// # Arguments + /// + /// * `next_page` - The index of the next page. + /// + /// # Returns + /// + /// The updated [`PageStorageType`]. + fn encode_next(self, next_page: usize) -> Self { + Self(self.0 | (next_page as u64) << Self::NEXT_SHIFT) + } + + /// Encodes the virtual address of the slab + /// + /// # Arguments + /// + /// * `slab` - slab virtual address + /// + /// # Returns + /// + /// The updated [`PageStorageType`] + fn encode_slab(self, slab: u64) -> Self { + let item_size = slab & Self::SLAB_MASK; + Self(self.0 | (item_size << Self::TYPE_SHIFT)) + } + + /// Encodes the reference count. + /// + /// # Arguments + /// + /// * `refcount` - The reference count to encode. + /// + /// # Returns + /// + /// The updated [`PageStorageType`]. + fn encode_refcount(self, refcount: u64) -> Self { + Self(self.0 | refcount << Self::TYPE_SHIFT) + } + + /// Decodes the order of the page. + fn decode_order(&self) -> usize { + ((self.0 >> Self::TYPE_SHIFT) & Self::ORDER_MASK) as usize + } + + /// Decodes the index of the next page. 
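A few worked values make the `get_order` computation above concrete: round the size up to the next power of two, take log2, and subtract the page shift, saturating at zero. The sketch assumes the usual `PAGE_SHIFT` of 12 for 4 KiB pages:

```rust
// Worked examples for the allocation-order calculation.
const PAGE_SHIFT: usize = 12; // 4 KiB pages (assumed)

const fn get_order(size: usize) -> usize {
    match size.checked_next_power_of_two() {
        Some(v) => v.ilog2() as usize,
        None => usize::BITS as usize,
    }
    .saturating_sub(PAGE_SHIFT)
}

fn main() {
    assert_eq!(get_order(1), 0);          // anything up to 4 KiB is order 0
    assert_eq!(get_order(4096), 0);
    assert_eq!(get_order(4097), 1);       // rounds up to 8 KiB
    assert_eq!(get_order(32 * 1024), 3);  // 32 KiB = 8 pages
    assert_eq!(get_order(128 * 1024), 5); // order 5 is the largest tracked (MAX_ORDER = 6)
}
```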
+ fn decode_next(&self) -> usize { + ((self.0 & Self::NEXT_MASK) >> Self::NEXT_SHIFT) as usize + } + + /// Decodes the slab + fn decode_slab(&self) -> u64 { + (self.0 >> Self::TYPE_SHIFT) & Self::SLAB_MASK + } + + /// Decodes the reference count. + fn decode_refcount(&self) -> u64 { + self.0 >> Self::TYPE_SHIFT + } + + /// Retrieves the page type from the [`PageStorageType`]. + fn page_type(&self) -> Result { + PageType::try_from(self.0 & Self::TYPE_MASK) + } +} + +/// Struct representing information about a free memory page. +#[derive(Clone, Copy, Debug)] +struct FreeInfo { + /// Index of the next free page. + next_page: usize, + /// Order of the free page. + order: usize, +} + +impl FreeInfo { + /// Encodes the [`FreeInfo`] into a [`PageStorageType`]. + fn encode(&self) -> PageStorageType { + PageStorageType::new(PageType::Free) + .encode_order(self.order) + .encode_next(self.next_page) + } + + /// Decodes a [`FreeInfo`] into a [`PageStorageType`]. + fn decode(mem: PageStorageType) -> Self { + let next_page = mem.decode_next(); + let order = mem.decode_order(); + Self { next_page, order } + } +} + +/// Struct representing information about an allocated memory page. +#[derive(Clone, Copy, Debug)] +struct AllocatedInfo { + order: usize, +} + +impl AllocatedInfo { + /// Encodes the [`AllocatedInfo`] into a [`PageStorageType`]. + fn encode(&self) -> PageStorageType { + PageStorageType::new(PageType::Allocated).encode_order(self.order) + } + + /// Decodes a [`PageStorageType`] into an [`AllocatedInfo`]. + fn decode(mem: PageStorageType) -> Self { + let order = mem.decode_order(); + Self { order } + } +} + +/// Struct representing information about a slab memory page. +#[derive(Clone, Copy, Debug)] +struct SlabPageInfo { + item_size: u64, +} + +impl SlabPageInfo { + /// Encodes the [`SlabPageInfo`] into a [`PageStorageType`]. + fn encode(&self) -> PageStorageType { + PageStorageType::new(PageType::SlabPage).encode_slab(self.item_size) + } + + /// Decodes a [`PageStorageType`] into a [`SlabPageInfo`]. + fn decode(mem: PageStorageType) -> Self { + let item_size = mem.decode_slab(); + Self { item_size } + } +} + +/// Struct representing information about a compound memory page. +#[derive(Clone, Copy, Debug)] +struct CompoundInfo { + order: usize, +} + +impl CompoundInfo { + /// Encodes the [`CompoundInfo`] into a [`PageStorageType`]. + fn encode(&self) -> PageStorageType { + PageStorageType::new(PageType::Compound).encode_order(self.order) + } + + /// Decodes a [`PageStorageType`] into a [`CompoundInfo`]. + fn decode(mem: PageStorageType) -> Self { + let order = mem.decode_order(); + Self { order } + } +} + +/// Struct representing information about a reserved memory page. +#[derive(Clone, Copy, Debug)] +struct ReservedInfo; + +impl ReservedInfo { + /// Encodes the [`ReservedInfo`] into a [`PageStorageType`]. + fn encode(&self) -> PageStorageType { + PageStorageType::new(PageType::Reserved) + } + + /// Decodes a [`PageStorageType`] into a [`ReservedInfo`]. + fn decode(_mem: PageStorageType) -> Self { + Self + } +} + +/// Struct representing information about a file memory page. +#[derive(Clone, Copy, Debug)] +struct FileInfo { + /// Reference count of the file page. + ref_count: u64, +} + +impl FileInfo { + /// Creates a new [`FileInfo`] with the specified reference count. + const fn new(ref_count: u64) -> Self { + Self { ref_count } + } + + /// Encodes the [`FileInfo`] into a [`PageStorageType`]. 
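The per-page metadata word packs several fields into one `u64`. A minimal round trip of the free-page layout described above (type in the low 4 bits, order in the bits below bit 12, next-free pfn from bit 12 upward), reusing the same shift and mask values:

```rust
// Round-trip of the free-page metadata word layout.
const TYPE_SHIFT: u64 = 4;
const TYPE_MASK: u64 = (1 << TYPE_SHIFT) - 1;
const NEXT_SHIFT: u64 = 12;
const ORDER_MASK: u64 = (1 << (NEXT_SHIFT - TYPE_SHIFT)) - 1;
const TYPE_FREE: u64 = 0;

fn encode_free(order: u64, next_page: u64) -> u64 {
    TYPE_FREE | ((order & ORDER_MASK) << TYPE_SHIFT) | (next_page << NEXT_SHIFT)
}

fn decode_free(word: u64) -> (u64, u64, u64) {
    let page_type = word & TYPE_MASK;
    let order = (word >> TYPE_SHIFT) & ORDER_MASK;
    let next_page = word >> NEXT_SHIFT;
    (page_type, order, next_page)
}

fn main() {
    // A free order-3 page whose next free page is pfn 42.
    let word = encode_free(3, 42);
    assert_eq!(decode_free(word), (TYPE_FREE, 3, 42));
}
```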
+ fn encode(&self) -> PageStorageType { + PageStorageType::new(PageType::File).encode_refcount(self.ref_count) + } + + /// Decodes a [`PageStorageType`] into a [`FileInfo`]. + fn decode(mem: PageStorageType) -> Self { + let ref_count = mem.decode_refcount(); + Self { ref_count } + } +} + +/// Enum representing different types of page information. +#[derive(Clone, Copy, Debug)] +enum PageInfo { + Free(FreeInfo), + Allocated(AllocatedInfo), + Slab(SlabPageInfo), + Compound(CompoundInfo), + File(FileInfo), + Reserved(ReservedInfo), +} + +impl PageInfo { + /// Converts [`PageInfo`] into a [`PageStorageType`]. + fn to_mem(self) -> PageStorageType { + match self { + Self::Free(fi) => fi.encode(), + Self::Allocated(ai) => ai.encode(), + Self::Slab(si) => si.encode(), + Self::Compound(ci) => ci.encode(), + Self::File(fi) => fi.encode(), + Self::Reserved(ri) => ri.encode(), + } + } + + /// Converts a [`PageStorageType`] into [`PageInfo`]. + fn from_mem(mem: PageStorageType) -> Self { + let Ok(page_type) = mem.page_type() else { + panic!("Unknown page type in {:?}", mem); + }; + + match page_type { + PageType::Free => Self::Free(FreeInfo::decode(mem)), + PageType::Allocated => Self::Allocated(AllocatedInfo::decode(mem)), + PageType::SlabPage => Self::Slab(SlabPageInfo::decode(mem)), + PageType::Compound => Self::Compound(CompoundInfo::decode(mem)), + PageType::File => Self::File(FileInfo::decode(mem)), + PageType::Reserved => Self::Reserved(ReservedInfo::decode(mem)), + } + } +} + +/// Represents info about allocated and free pages in different orders. +#[derive(Debug, Default, Clone, Copy)] +pub struct MemInfo { + total_pages: [usize; MAX_ORDER], + free_pages: [usize; MAX_ORDER], +} + +/// Memory region with its physical/virtual addresses, page count, as well +/// as other details. +#[derive(Debug, Default)] +struct MemoryRegion { + start_phys: PhysAddr, + start_virt: VirtAddr, + page_count: usize, + nr_pages: [usize; MAX_ORDER], + next_page: [usize; MAX_ORDER], + free_pages: [usize; MAX_ORDER], +} + +impl MemoryRegion { + /// Creates a new [`MemoryRegion`] with default values. + const fn new() -> Self { + Self { + start_phys: PhysAddr::null(), + start_virt: VirtAddr::null(), + page_count: 0, + nr_pages: [0; MAX_ORDER], + next_page: [0; MAX_ORDER], + free_pages: [0; MAX_ORDER], + } + } + + /// Converts a physical address within this memory region to a virtual address. + #[expect(dead_code)] + fn phys_to_virt(&self, paddr: PhysAddr) -> Option { + let end_phys = self.start_phys + (self.page_count * PAGE_SIZE); + + if paddr < self.start_phys || paddr >= end_phys { + // For the initial stage2 identity mapping, the root page table + // pages are static and outside of the heap memory region. + if VirtAddr::from(self.start_phys.bits()) == self.start_virt { + return Some(VirtAddr::from(paddr.bits())); + } + return None; + } + + let offset = paddr - self.start_phys; + + Some(self.start_virt + offset) + } + + /// Converts a virtual address to a physical address within the memory region. + #[expect(dead_code)] + fn virt_to_phys(&self, vaddr: VirtAddr) -> Option { + let offset = self.get_virt_offset(vaddr)?; + Some(self.start_phys + offset) + } + + /// Gets a mutable pointer to the page information for a given page frame + /// number. + /// + /// # Safety + /// + /// The caller must provide a valid pfn, otherwise the returned pointer is + /// undefined, as the compiler is allowed to optimize assuming there will + /// be no arithmetic overflows. 
+ unsafe fn page_info_mut_ptr(&mut self, pfn: usize) -> *mut PageStorageType { + self.start_virt.as_mut_ptr::().add(pfn) + } + + /// Gets a pointer to the page information for a given page frame number. + /// + /// # Safety + /// + /// The caller must provide a valid pfn, otherwise the returned pointer is + /// undefined, as the compiler is allowed to optimize assuming there will + /// be no arithmetic overflows. + unsafe fn page_info_ptr(&self, pfn: usize) -> *const PageStorageType { + self.start_virt.as_ptr::().add(pfn) + } + + /// Checks if a page frame number is valid. + /// + /// # Panics + /// + /// Panics if the page frame number is invalid. + fn check_pfn(&self, pfn: usize) { + if pfn >= self.page_count { + panic!("Invalid Page Number {}", pfn); + } + } + + /// Calculates the end virtual address of the memory region. + fn end_virt(&self) -> VirtAddr { + self.start_virt + (self.page_count * PAGE_SIZE) + } + + /// Writes page information for a given page frame number. + fn write_page_info(&mut self, pfn: usize, pi: PageInfo) { + self.check_pfn(pfn); + + let info: PageStorageType = pi.to_mem(); + // SAFETY: we have checked that the pfn is valid via check_pfn() above. + unsafe { self.page_info_mut_ptr(pfn).write(info) }; + } + + /// Reads page information for a given page frame number. + fn read_page_info(&self, pfn: usize) -> PageInfo { + self.check_pfn(pfn); + + // SAFETY: we have checked that the pfn is valid via check_pfn() above. + let info = unsafe { self.page_info_ptr(pfn).read() }; + PageInfo::from_mem(info) + } + + /// Gets the virtual offset of a virtual address within the memory region. + fn get_virt_offset(&self, vaddr: VirtAddr) -> Option { + (self.start_virt <= vaddr && vaddr < self.end_virt()).then(|| vaddr - self.start_virt) + } + + /// Gets the page frame number for a given virtual address. + fn get_pfn(&self, vaddr: VirtAddr) -> Result { + self.get_virt_offset(vaddr) + .map(|off| off / PAGE_SIZE) + .ok_or(AllocError::InvalidHeapAddress(vaddr)) + } + + /// Gets the next available page frame number for a given order. + fn get_next_page(&mut self, order: usize) -> Result { + let pfn = self.next_page[order]; + + if pfn == 0 { + return Err(AllocError::OutOfMemory); + } + + let pg = self.read_page_info(pfn); + let PageInfo::Free(fi) = pg else { + panic!( + "Unexpected page type in MemoryRegion::get_next_page() {:?}", + pg + ); + }; + + self.next_page[order] = fi.next_page; + + self.free_pages[order] -= 1; + + Ok(pfn) + } + + /// Marks a compound page and updates page information for neighboring pages. + fn mark_compound_page(&mut self, pfn: usize, order: usize) { + let nr_pages: usize = 1 << order; + let compound = PageInfo::Compound(CompoundInfo { order }); + for i in 1..nr_pages { + self.write_page_info(pfn + i, compound); + } + } + + /// Initializes a compound page with given page frame numbers and order. + fn init_compound_page(&mut self, pfn: usize, order: usize, next_pfn: usize) { + let head = PageInfo::Free(FreeInfo { + next_page: next_pfn, + order, + }); + self.write_page_info(pfn, head); + self.mark_compound_page(pfn, order); + } + + /// Splits a page into two pages of the next lower order. 
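The per-order free list above is intrusive: `next_page[order]` holds the head pfn, each free page's metadata word stores the pfn of the next free page, and pfn 0 doubles as the end-of-list marker. A toy model of that threading (a plain `Vec` stands in for the metadata array) mirroring `get_next_page` and `free_page_raw`:

```rust
// Toy model of the intrusive per-order free list.
const MAX_ORDER: usize = 6;

struct Region {
    next_meta: Vec<usize>,          // stand-in for the per-page metadata words
    next_page: [usize; MAX_ORDER],  // free-list heads, one per order
    free_pages: [usize; MAX_ORDER],
}

impl Region {
    fn free_page_raw(&mut self, pfn: usize, order: usize) {
        self.next_meta[pfn] = self.next_page[order]; // link to old head
        self.next_page[order] = pfn;                 // page becomes the new head
        self.free_pages[order] += 1;
    }

    fn get_next_page(&mut self, order: usize) -> Option<usize> {
        let pfn = self.next_page[order];
        if pfn == 0 {
            return None; // list empty
        }
        self.next_page[order] = self.next_meta[pfn]; // unlink the head
        self.free_pages[order] -= 1;
        Some(pfn)
    }
}

fn main() {
    let mut r = Region {
        next_meta: vec![0; 16],
        next_page: [0; MAX_ORDER],
        free_pages: [0; MAX_ORDER],
    };
    r.free_page_raw(5, 0);
    r.free_page_raw(9, 0);
    // LIFO: the most recently freed page is handed out first.
    assert_eq!(r.get_next_page(0), Some(9));
    assert_eq!(r.get_next_page(0), Some(5));
    assert_eq!(r.get_next_page(0), None);
}
```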
+ fn split_page(&mut self, pfn: usize, order: usize) -> Result<(), AllocError> { + if !(1..MAX_ORDER).contains(&order) { + return Err(AllocError::InvalidPageOrder(order)); + } + + let new_order = order - 1; + let pfn1 = pfn; + let pfn2 = pfn + (1usize << new_order); + + let next_pfn = self.next_page[new_order]; + self.init_compound_page(pfn1, new_order, pfn2); + self.init_compound_page(pfn2, new_order, next_pfn); + self.next_page[new_order] = pfn1; + + // Do the accounting + self.nr_pages[order] -= 1; + self.nr_pages[new_order] += 2; + self.free_pages[new_order] += 2; + + Ok(()) + } + + /// Refills the free page list for a given order. + fn refill_page_list(&mut self, order: usize) -> Result<(), AllocError> { + let next_page = *self + .next_page + .get(order) + .ok_or(AllocError::InvalidPageOrder(order))?; + if next_page != 0 { + return Ok(()); + } + + self.refill_page_list(order + 1)?; + let pfn = self.get_next_page(order + 1)?; + self.split_page(pfn, order + 1) + } + + /// Allocates pages with a specific order and page information. + fn allocate_pages_info(&mut self, order: usize, pg: PageInfo) -> Result { + self.refill_page_list(order)?; + let pfn = self.get_next_page(order)?; + self.write_page_info(pfn, pg); + Ok(self.start_virt + (pfn * PAGE_SIZE)) + } + + /// Allocates pages with a specific order. + fn allocate_pages(&mut self, order: usize) -> Result { + let pg = PageInfo::Allocated(AllocatedInfo { order }); + self.allocate_pages_info(order, pg) + } + + /// Allocates a single page. + fn allocate_page(&mut self) -> Result { + self.allocate_pages(0) + } + + /// Allocates a zeroed page. + fn allocate_zeroed_page(&mut self) -> Result { + let vaddr = self.allocate_page()?; + + zero_mem_region(vaddr, vaddr + PAGE_SIZE); + + Ok(vaddr) + } + + /// Allocates a slab page. + fn allocate_slab_page(&mut self, item_size: u16) -> Result { + self.refill_page_list(0)?; + + let pfn = self.get_next_page(0)?; + let pg = PageInfo::Slab(SlabPageInfo { + item_size: u64::from(item_size), + }); + self.write_page_info(pfn, pg); + Ok(self.start_virt + (pfn * PAGE_SIZE)) + } + + /// Allocates a file page with initial reference count. + fn allocate_file_page(&mut self) -> Result { + let pg = PageInfo::File(FileInfo::new(1)); + self.allocate_pages_info(0, pg) + } + + /// Gets a file page and increments its reference count. + fn get_file_page(&mut self, vaddr: VirtAddr) -> Result<(), AllocError> { + let pfn = self.get_pfn(vaddr)?; + let page = self.read_page_info(pfn); + let PageInfo::File(mut fi) = page else { + return Err(AllocError::InvalidFilePage(vaddr)); + }; + + assert!(fi.ref_count > 0); + fi.ref_count += 1; + self.write_page_info(pfn, PageInfo::File(fi)); + + Ok(()) + } + + /// Releases a file page and decrements its reference count. + fn put_file_page(&mut self, vaddr: VirtAddr) -> Result<(), AllocError> { + let pfn = self.get_pfn(vaddr)?; + let page = self.read_page_info(pfn); + let PageInfo::File(mut fi) = page else { + return Err(AllocError::InvalidFilePage(vaddr)); + }; + + fi.ref_count = fi + .ref_count + .checked_sub(1) + .expect("page refcount underflow"); + if fi.ref_count > 0 { + self.write_page_info(pfn, PageInfo::File(fi)); + } else { + self.free_page(vaddr) + } + + Ok(()) + } + + /// Finds the neighboring page frame number for a compound page. 
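split_page() above and compound_neighbor() below share the same buddy-allocator arithmetic: an order-aligned block of 2^order pages has its buddy at pfn ^ (1 << order), and splitting a block yields its two aligned halves one order down. A standalone illustration using bare usize values instead of the crate's types:

// Buddy of an order-aligned block: flip the bit that selects which half of
// the next-larger block it occupies.
fn buddy(pfn: usize, order: usize) -> usize {
    assert_eq!(pfn & ((1usize << order) - 1), 0, "pfn must be order-aligned");
    pfn ^ (1usize << order)
}

// Split an order-`order` block into its two order-(order-1) halves.
fn split(pfn: usize, order: usize) -> (usize, usize) {
    let new_order = order - 1;
    (pfn, pfn + (1usize << new_order))
}

fn main() {
    // An order-3 block at pfn 8 covers pfns 8..16; its buddy is the block at 0.
    assert_eq!(buddy(8, 3), 0);
    // Splitting it yields the two order-2 halves at pfns 8 and 12 ...
    assert_eq!(split(8, 3), (8, 12));
    // ... which are each other's buddies at the lower order.
    assert_eq!(buddy(8, 2), 12);
}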
+ fn compound_neighbor(&self, pfn: usize, order: usize) -> Result { + if order >= MAX_ORDER - 1 { + return Err(AllocError::InvalidPageOrder(order)); + } + + assert_eq!(pfn & ((1usize << order) - 1), 0); + let pfn = pfn ^ (1usize << order); + if pfn >= self.page_count { + return Err(AllocError::InvalidPfn(pfn)); + } + + Ok(pfn) + } + + /// Merges two pages of the same order into a new compound page. + fn merge_pages(&mut self, pfn1: usize, pfn2: usize, order: usize) -> Result { + if order >= MAX_ORDER - 1 { + return Err(AllocError::InvalidPageOrder(order)); + } + + let nr_pages: usize = 1 << (order + 1); + let pfn = pfn1.min(pfn2); + + // Write new compound head + let pg = PageInfo::Allocated(AllocatedInfo { order: order + 1 }); + self.write_page_info(pfn, pg); + + // Write compound pages + let pg = PageInfo::Compound(CompoundInfo { order: order + 1 }); + for i in 1..nr_pages { + self.write_page_info(pfn + i, pg); + } + + // Do the accounting - none of the pages is free yet, so free_pages is + // not updated here. + self.nr_pages[order] -= 2; + self.nr_pages[order + 1] += 1; + + Ok(pfn) + } + + /// Gets the next free page frame number from the free list. + fn next_free_pfn(&self, pfn: usize, order: usize) -> usize { + let page = self.read_page_info(pfn); + let PageInfo::Free(fi) = page else { + panic!("Unexpected page type in free-list for order {}", order); + }; + + fi.next_page + } + + /// Allocates a specific page frame number (`pfn`) within a given order. + /// If the page frame number is not found or is already allocated, an error + /// is returned. If the requested page frame number is the first in the + /// list, it is marked as allocated, and the next page in the list becomes + /// the new first page. + /// + /// # Panics + /// + /// Panics if `order` is greater than [`MAX_ORDER`]. + fn allocate_pfn(&mut self, pfn: usize, order: usize) -> Result<(), AllocError> { + let first_pfn = self.next_page[order]; + + // Handle special cases first + if first_pfn == 0 { + // No pages for that order + return Err(AllocError::OutOfMemory); + } else if first_pfn == pfn { + // Requested pfn is first in list + self.get_next_page(order).unwrap(); + return Ok(()); + } + + // Now walk the list + let mut old_pfn = first_pfn; + loop { + let current_pfn = self.next_free_pfn(old_pfn, order); + if current_pfn == 0 { + return Err(AllocError::OutOfMemory); + } + + if current_pfn != pfn { + old_pfn = current_pfn; + continue; + } + + let next_pfn = self.next_free_pfn(current_pfn, order); + let pg = PageInfo::Free(FreeInfo { + next_page: next_pfn, + order, + }); + self.write_page_info(old_pfn, pg); + + let pg = PageInfo::Allocated(AllocatedInfo { order }); + self.write_page_info(current_pfn, pg); + + self.free_pages[order] -= 1; + + return Ok(()); + } + } + + /// Frees a raw page by updating the free list and marking it as a free page. + /// + /// # Panics + /// + /// Panics if `order` is greater than [`MAX_ORDER`]. + fn free_page_raw(&mut self, pfn: usize, order: usize) { + let old_next = self.next_page[order]; + let pg = PageInfo::Free(FreeInfo { + next_page: old_next, + order, + }); + + self.write_page_info(pfn, pg); + self.next_page[order] = pfn; + + self.free_pages[order] += 1; + } + + /// Attempts to merge a given page with its neighboring page. + /// If successful, returns the new page frame number after merging. + /// If unsuccessful, the page remains unmerged, and an error is returned. 
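What try_to_merge_page() and free_page_order() (both below) implement is the classic merge-on-free cascade: when a block is freed, its buddy is checked, and if the buddy is free at the same order the two are combined and the test repeats one order higher. A deliberately simplified, runnable model that keeps free blocks in per-order sets instead of the intrusive lists used above; MAX_ORDER and all names are invented for the sketch.

use std::collections::BTreeSet;

const MAX_ORDER: usize = 4; // toy value

fn free_block(free: &mut [BTreeSet<usize>; MAX_ORDER], mut pfn: usize, mut order: usize) {
    while order < MAX_ORDER - 1 {
        let buddy = pfn ^ (1usize << order);
        if !free[order].remove(&buddy) {
            break; // buddy is not free at this order: stop merging
        }
        pfn = pfn.min(buddy); // the merged block starts at the lower pfn
        order += 1;
    }
    free[order].insert(pfn);
}

fn main() {
    let mut free: [BTreeSet<usize>; MAX_ORDER] = Default::default();
    free_block(&mut free, 0, 0);
    free_block(&mut free, 1, 0); // merges with pfn 0 into an order-1 block
    free_block(&mut free, 2, 1); // merges again into an order-2 block at pfn 0
    assert!(free[2].contains(&0));
    assert!(free[0].is_empty() && free[1].is_empty());
}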
+ fn try_to_merge_page(&mut self, pfn: usize, order: usize) -> Result { + let neighbor_pfn = self.compound_neighbor(pfn, order)?; + let neighbor_page = self.read_page_info(neighbor_pfn); + + let PageInfo::Free(fi) = neighbor_page else { + return Err(AllocError::InvalidPfn(neighbor_pfn)); + }; + + if fi.order != order { + return Err(AllocError::InvalidPageOrder(fi.order)); + } + + self.allocate_pfn(neighbor_pfn, order)?; + + let new_pfn = self.merge_pages(pfn, neighbor_pfn, order)?; + + Ok(new_pfn) + } + + /// Frees a page of a specific order. If merging is successful, it + /// continues merging until merging is no longer possible. If merging + /// fails, the page is marked as a free page. + fn free_page_order(&mut self, pfn: usize, order: usize) { + match self.try_to_merge_page(pfn, order) { + Err(_) => { + self.free_page_raw(pfn, order); + } + Ok(new_pfn) => { + self.free_page_order(new_pfn, order + 1); + } + } + } + + /// Frees a page based on its virtual address, determining the page + /// order and freeing accordingly. + fn free_page(&mut self, vaddr: VirtAddr) { + let Ok(pfn) = self.get_pfn(vaddr) else { + return; + }; + + let res = self.read_page_info(pfn); + + match res { + PageInfo::Allocated(ai) => { + self.free_page_order(pfn, ai.order); + } + PageInfo::Slab(_si) => { + self.free_page_order(pfn, 0); + } + PageInfo::Compound(ci) => { + let mask = (1usize << ci.order) - 1; + let start_pfn = pfn & !mask; + self.free_page_order(start_pfn, ci.order); + } + PageInfo::File(_) => { + self.free_page_order(pfn, 0); + } + _ => { + panic!("Unexpected page type in MemoryRegion::free_page()"); + } + } + } + + /// Retrieves information about memory, including total and free pages + /// in different orders. + fn memory_info(&self) -> MemInfo { + MemInfo { + total_pages: self.nr_pages, + free_pages: self.free_pages, + } + } + + /// Initializes memory by marking certain pages as reserved and the rest + /// as allocated. It then frees all pages and organizes them into their + /// respective order buckets. + fn init_memory(&mut self) { + let size = size_of::(); + let meta_pages = align_up(self.page_count * size, PAGE_SIZE) / PAGE_SIZE; + + /* Mark page storage as reserved */ + for i in 0..meta_pages { + let pg = PageInfo::Reserved(ReservedInfo); + self.write_page_info(i, pg); + } + + /* Mark all pages as allocated */ + for i in meta_pages..self.page_count { + let pg = PageInfo::Allocated(AllocatedInfo { order: 0 }); + self.write_page_info(i, pg); + } + + /* Now free all pages. 
Any runs of pages aligned to the maximum order + * will be freed directly into the maximum order bucket, and all other + * pages will be freed individually so the correct orders can be + * generated */ + let alignment = 1 << (MAX_ORDER - 1); + let first_aligned_page = align_up(meta_pages, alignment); + let last_aligned_page = align_down(self.page_count, alignment); + + if first_aligned_page < last_aligned_page { + self.nr_pages[MAX_ORDER - 1] += (last_aligned_page - first_aligned_page) / alignment; + for i in (first_aligned_page..last_aligned_page).step_by(alignment) { + self.mark_compound_page(i, MAX_ORDER - 1); + self.free_page_raw(i, MAX_ORDER - 1); + } + + if first_aligned_page < self.page_count { + self.nr_pages[0] += first_aligned_page - meta_pages; + for i in meta_pages..first_aligned_page { + self.free_page_order(i, 0); + } + } + + if last_aligned_page > meta_pages { + self.nr_pages[0] += self.page_count - last_aligned_page; + for i in last_aligned_page..self.page_count { + self.free_page_order(i, 0); + } + } + } else { + // Special case: Memory region size smaller than a MAX_ORDER allocation + self.nr_pages[0] = self.page_count - meta_pages; + for i in meta_pages..self.page_count { + self.free_page_order(i, 0); + } + } + } +} + +/// Represents a reference to a memory page, holding both virtual and +/// physical addresses. +#[derive(Debug)] +pub struct PageRef { + virt_addr: VirtAddr, + phys_addr: PhysAddr, +} + +impl PageRef { + /// Allocate a reference-counted file page. + pub fn new() -> Result { + let virt_addr = allocate_file_page()?; + let phys_addr = virt_to_phys(virt_addr); + + Ok(Self { + virt_addr, + phys_addr, + }) + } + + /// Returns the virtual address of the memory page. + pub fn virt_addr(&self) -> VirtAddr { + self.virt_addr + } + + /// Returns the physical address of the memory page. + pub fn phys_addr(&self) -> PhysAddr { + self.phys_addr + } + + pub fn try_copy_page(&self) -> Result { + let virt_addr = allocate_file_page()?; + + let src = self.virt_addr.bits(); + let dst = virt_addr.bits(); + let size = PAGE_SIZE; + unsafe { + // SAFETY: `src` and `dst` are both valid. + copy_bytes(src, dst, size); + } + + Ok(PageRef { + virt_addr, + phys_addr: virt_to_phys(virt_addr), + }) + } + + pub fn write(&self, offset: usize, buf: &[u8]) { + assert!(offset.checked_add(buf.len()).unwrap() <= PAGE_SIZE); + + let src = buf.as_ptr() as usize; + let dst = self.virt_addr.bits() + offset; + let size = buf.len(); + unsafe { + // SAFETY: `src` and `dst` are both valid. + copy_bytes(src, dst, size); + } + } + + pub fn read(&self, offset: usize, buf: &mut [u8]) { + assert!(offset.checked_add(buf.len()).unwrap() <= PAGE_SIZE); + + let src = self.virt_addr.bits() + offset; + let dst = buf.as_mut_ptr() as usize; + let size = buf.len(); + unsafe { + // SAFETY: `src` and `dst` are both valid. + copy_bytes(src, dst, size); + } + } + + pub fn fill(&self, offset: usize, value: u8) { + let dst = self.virt_addr.bits() + offset; + let size = PAGE_SIZE.checked_sub(offset).unwrap(); + + unsafe { + // SAFETY: `dst` is valid. + write_bytes(dst, size, value); + } + } +} + +impl Clone for PageRef { + /// Clones the [`PageRef`] instance, obtaining a new reference to the same memory page. + fn clone(&self) -> Self { + get_file_page(self.virt_addr).expect("Failed to get page reference"); + PageRef { + virt_addr: self.virt_addr, + phys_addr: self.phys_addr, + } + } +} + +impl Drop for PageRef { + /// Drops the [`PageRef`] instance, decreasing the reference count for + /// the associated memory page. 
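The Clone impl above and the Drop impl that follows give PageRef shared-ownership semantics backed by the allocator's per-page reference count. A toy model of that contract, with a HashMap standing in for the page metadata (the names are invented, nothing here is the SVSM API):

use std::cell::RefCell;
use std::collections::HashMap;

thread_local! {
    // pfn -> refcount; stands in for the FileInfo stored per page.
    static REFCOUNTS: RefCell<HashMap<usize, u64>> = RefCell::new(HashMap::new());
}

struct PageHandle {
    pfn: usize,
}

impl PageHandle {
    fn new(pfn: usize) -> Self {
        REFCOUNTS.with(|m| m.borrow_mut().insert(pfn, 1));
        Self { pfn }
    }
}

impl Clone for PageHandle {
    fn clone(&self) -> Self {
        // Like PageRef::clone(): take another reference on the same page.
        REFCOUNTS.with(|m| *m.borrow_mut().get_mut(&self.pfn).unwrap() += 1);
        Self { pfn: self.pfn }
    }
}

impl Drop for PageHandle {
    fn drop(&mut self) {
        REFCOUNTS.with(|m| {
            let mut map = m.borrow_mut();
            let remaining = {
                let count = map.get_mut(&self.pfn).expect("unknown page");
                *count -= 1;
                *count
            };
            if remaining == 0 {
                map.remove(&self.pfn); // the real code frees the page here
            }
        });
    }
}

fn main() {
    let a = PageHandle::new(42);
    let b = a.clone();
    drop(a);
    REFCOUNTS.with(|m| assert_eq!(m.borrow()[&42], 1));
    drop(b);
    REFCOUNTS.with(|m| assert!(!m.borrow().contains_key(&42)));
}

Dropping the last handle removes the entry, which is the point where put_file_page() hands the real page back to free_page().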
+ fn drop(&mut self) { + put_file_page(self.virt_addr).expect("Failed to drop page reference"); + } +} + +/// Prints memory information based on the provided [`MemInfo`] structure. +/// +/// # Arguments +/// +/// * `info` - Reference to [`MemInfo`] structure containing memory information. +pub fn print_memory_info(info: &MemInfo) { + let mut pages_4k = 0; + let mut free_pages_4k = 0; + + for i in 0..MAX_ORDER { + let nr_4k_pages: usize = 1 << i; + log::info!( + "Order-{:#02}: total pages: {:#5} free pages: {:#5}", + i, + info.total_pages[i], + info.free_pages[i] + ); + pages_4k += info.total_pages[i] * nr_4k_pages; + free_pages_4k += info.free_pages[i] * nr_4k_pages; + } + + log::info!( + "Total memory: {}KiB free memory: {}KiB", + (pages_4k * PAGE_SIZE) / 1024, + (free_pages_4k * PAGE_SIZE) / 1024 + ); +} + +/// Static spinlock-protected instance of [`MemoryRegion`] representing the +/// root memory region. +static ROOT_MEM: SpinLock = SpinLock::new(MemoryRegion::new()); + +/// Allocates a single memory page from the root memory region. +/// +/// # Returns +/// +/// Result containing the virtual address of the allocated page or an +/// `SvsmError` if allocation fails. +pub fn allocate_page() -> Result { + Ok(ROOT_MEM.lock().allocate_page()?) +} + +/// Allocates multiple memory pages with a specified order from the root +/// memory region. +/// +/// # Arguments +/// +/// * `order` - Order of the allocation, determining the number of pages (2^order). +/// +/// # Returns +/// +/// Result containing the virtual address of the allocated pages or an +/// `SvsmError` if allocation fails. +pub fn allocate_pages(order: usize) -> Result { + Ok(ROOT_MEM.lock().allocate_pages(order)?) +} + +/// Allocate a slab page. +/// +/// # Arguments +/// +/// `slab` - slab virtual address +/// +/// # Returns +/// +/// Result containing the virtual address of the allocated slab page or an +/// `SvsmError` if allocation fails. +pub fn allocate_slab_page(item_size: u16) -> Result { + Ok(ROOT_MEM.lock().allocate_slab_page(item_size)?) +} + +/// Allocate a zeroed page. +/// +/// # Returns +/// +/// Result containing the virtual address of the allocated zeroed page or an +/// `SvsmError` if allocation fails. +pub fn allocate_zeroed_page() -> Result { + Ok(ROOT_MEM.lock().allocate_zeroed_page()?) +} + +/// Allocate a file page. +/// +/// # Returns +/// +/// Result containing the virtual address of the allocated file page or an +/// `SvsmError` if allocation fails. +pub fn allocate_file_page() -> Result { + let vaddr = ROOT_MEM.lock().allocate_file_page()?; + zero_mem_region(vaddr, vaddr + PAGE_SIZE); + Ok(vaddr) +} + +fn get_file_page(vaddr: VirtAddr) -> Result<(), SvsmError> { + Ok(ROOT_MEM.lock().get_file_page(vaddr)?) +} + +fn put_file_page(vaddr: VirtAddr) -> Result<(), SvsmError> { + Ok(ROOT_MEM.lock().put_file_page(vaddr)?) +} + +/// Free the page at the given virtual address. +pub fn free_page(vaddr: VirtAddr) { + ROOT_MEM.lock().free_page(vaddr) +} + +/// Retrieve information about the root memory +pub fn memory_info() -> MemInfo { + ROOT_MEM.lock().memory_info() +} + +/// Represents a slab memory page, used for efficient allocation of +/// fixed-size objects. +#[derive(Debug, Default)] +struct SlabPage { + vaddr: VirtAddr, + free: u16, + used_bitmap: [u64; 2], + next_page: VirtAddr, +} + +impl SlabPage { + /// Creates a new [`SlabPage`] instance with default values. 
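The SlabPage defined above hands out fixed-size slots from a single page and tracks them in the 128-bit used_bitmap. The sketch below models just that bookkeeping with slot indices instead of virtual addresses; PAGE_SIZE and the names are assumptions made for the toy.

const PAGE_SIZE: usize = 4096; // assumed 4 KiB pages

struct ToySlabBitmap {
    used: [u64; 2], // 128 bits, enough for 4096 / 32 = 128 slots
    capacity: u16,
}

impl ToySlabBitmap {
    fn new(item_size: u16) -> Self {
        Self { used: [0; 2], capacity: (PAGE_SIZE as u16) / item_size }
    }

    // Find the first clear bit, set it and return the slot index
    // (the real code turns this into vaddr + slot * item_size).
    fn allocate(&mut self) -> Option<u16> {
        for i in 0..self.capacity {
            let (idx, mask) = ((i / 64) as usize, 1u64 << (i % 64));
            if self.used[idx] & mask == 0 {
                self.used[idx] |= mask;
                return Some(i);
            }
        }
        None // page is full
    }

    fn free(&mut self, slot: u16) {
        let (idx, mask) = ((slot / 64) as usize, 1u64 << (slot % 64));
        self.used[idx] &= !mask;
    }
}

fn main() {
    let mut bm = ToySlabBitmap::new(64); // 64-byte objects: 64 slots per page
    assert_eq!(bm.allocate(), Some(0));
    assert_eq!(bm.allocate(), Some(1));
    bm.free(0);
    assert_eq!(bm.allocate(), Some(0)); // the freed slot is handed out again
}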
+ const fn new() -> Self { + assert!(N <= (PAGE_SIZE / 2) as u16); + Self { + vaddr: VirtAddr::null(), + free: 0, + used_bitmap: [0; 2], + next_page: VirtAddr::null(), + } + } + + /// Initialize the [`SlabPage`]. + fn init(&mut self) -> Result<(), AllocError> { + if !self.vaddr.is_null() { + return Ok(()); + } + let vaddr = ROOT_MEM.lock().allocate_slab_page(N)?; + self.vaddr = vaddr; + self.free = self.get_capacity(); + + Ok(()) + } + + /// Free the memory (destroy) the [`SlabPage`] + #[expect(clippy::needless_pass_by_ref_mut)] + fn destroy(&mut self) { + if self.vaddr.is_null() { + return; + } + + free_page(self.vaddr); + } + + /// Get the capacity of the [`SlabPage`] + const fn get_capacity(&self) -> u16 { + (PAGE_SIZE as u16) / N + } + + fn get_free(&self) -> u16 { + self.free + } + + /// Get the virtual address of the next [`SlabPage`] + fn get_next_page(&self) -> VirtAddr { + self.next_page + } + + fn set_next_page(&mut self, next_page: VirtAddr) { + self.next_page = next_page; + } + + fn allocate(&mut self) -> Result { + if self.free == 0 { + return Err(AllocError::OutOfMemory); + } + + for i in 0..self.get_capacity() { + let idx = (i / 64) as usize; + let mask = 1u64 << (i % 64); + + if self.used_bitmap[idx] & mask == 0 { + self.used_bitmap[idx] |= mask; + self.free -= 1; + return Ok(self.vaddr + ((N * i) as usize)); + } + } + + Err(AllocError::OutOfMemory) + } + + fn free(&mut self, vaddr: VirtAddr) -> Result<(), AllocError> { + if vaddr < self.vaddr || vaddr >= self.vaddr + PAGE_SIZE { + return Err(AllocError::InvalidHeapAddress(vaddr)); + } + + let item_size = N as usize; + let offset = vaddr - self.vaddr; + let i = offset / item_size; + let idx = i / 64; + let mask = 1u64 << (i % 64); + + self.used_bitmap[idx] &= !mask; + self.free += 1; + + Ok(()) + } +} + +/// Represents common information shared among multiple slab pages. +#[derive(Debug, Default)] +#[repr(align(16))] +struct SlabCommon { + capacity: u32, + free: u32, + pages: u32, + full_pages: u32, + free_pages: u32, + page: SlabPage, +} + +impl SlabCommon { + const fn new() -> Self { + Self { + capacity: 0, + free: 0, + pages: 0, + full_pages: 0, + free_pages: 0, + page: SlabPage::new(), + } + } + + /// Initialize the [`SlabCommon`] with default values + fn init(&mut self) -> Result<(), AllocError> { + self.page.init()?; + self.capacity = self.page.get_capacity() as u32; + self.free = self.capacity; + self.pages = 1; + self.full_pages = 0; + self.free_pages = 1; + Ok(()) + } + + /// Add other [`SlabPage`]. + fn add_slab_page(&mut self, new_page: &mut SlabPage) { + let old_next_page = self.page.get_next_page(); + new_page.set_next_page(old_next_page); + + self.page + .set_next_page(VirtAddr::from(new_page as *mut SlabPage)); + + let capacity = new_page.get_capacity() as u32; + self.pages += 1; + self.free_pages += 1; + self.capacity += capacity; + self.free += capacity; + } + + /// Allocate other slot, caller must make sure there's at least one + /// free slot + fn allocate_slot(&mut self) -> VirtAddr { + // Caller must make sure there's at least one free slot. + assert_ne!(self.free, 0); + let mut page = &mut self.page; + loop { + let free = page.get_free(); + + if let Ok(vaddr) = page.allocate() { + let capacity = page.get_capacity(); + self.free -= 1; + + if free == capacity { + self.free_pages -= 1; + } else if free == 1 { + self.full_pages += 1; + } + + return vaddr; + } + + let next_page = page.get_next_page(); + // Cannot fail with free slots on entry. 
+ page = unsafe { next_page.aligned_mut().expect("Invalid next page") }; + } + } + + /// Deallocate a slot given its virtual address + fn deallocate_slot(&mut self, vaddr: VirtAddr) { + let mut page = &mut self.page; + loop { + let free = page.get_free(); + + if let Ok(_o) = page.free(vaddr) { + let capacity = page.get_capacity(); + self.free += 1; + + if free == 0 { + self.full_pages -= 1; + } else if free + 1 == capacity { + self.free_pages += 1; + } + + return; + } + + let next_page = page.get_next_page(); + // Will fail if the object does not belong to this slab. + page = unsafe { next_page.aligned_mut().expect("Invalid next page") }; + } + } + + /// Finds an unused slab page and removes it from the slab. + fn free_one_page(&mut self) -> *mut SlabPage { + let mut last_page = &mut self.page; + let mut next_page_vaddr = last_page.get_next_page(); + loop { + let slab_page = unsafe { + next_page_vaddr + .aligned_mut::>() + .expect("couldn't find page to free") + }; + next_page_vaddr = slab_page.get_next_page(); + + let capacity = slab_page.get_capacity(); + let free = slab_page.get_free(); + if free != capacity { + last_page = slab_page; + continue; + } + + let capacity = slab_page.get_capacity() as u32; + self.pages -= 1; + self.free_pages -= 1; + self.capacity -= capacity; + self.free -= capacity; + + last_page.set_next_page(slab_page.get_next_page()); + + slab_page.destroy(); + + return slab_page; + } + } +} + +// 32 is chosen arbitrarily here, it does not affect struct layout +const SLAB_PAGE_SIZE: u16 = size_of::>() as u16; + +/// Represents a slab page for the [`SlabPageSlab`] allocator. +#[derive(Debug)] +struct SlabPageSlab { + common: SlabCommon, +} + +impl SlabPageSlab { + /// Creates a new [`SlabPageSlab`] with a default [`SlabCommon`]. + const fn new() -> Self { + Self { + common: SlabCommon::new(), + } + } + + /// Initializes the [`SlabPageSlab`], allocating the first slab page if necessary. + fn init(&mut self) -> Result<(), AllocError> { + self.common.init() + } + + /// Grows the slab by allocating a new slab page. + fn grow_slab(&mut self) -> Result<(), AllocError> { + if self.common.capacity == 0 { + self.init()?; + return Ok(()); + } + + // Make sure there's always at least one SlabPage slot left for extending the SlabPageSlab itself. + if self.common.free >= 2 { + return Ok(()); + } + assert_ne!(self.common.free, 0); + + let page_vaddr = self.common.allocate_slot(); + let slab_page = unsafe { &mut *page_vaddr.as_mut_ptr::>() }; + + *slab_page = SlabPage::new(); + if let Err(e) = slab_page.init() { + self.common.deallocate_slot(page_vaddr); + return Err(e); + } + + self.common.add_slab_page(slab_page); + + Ok(()) + } + + /// Shrinks the slab by freeing unused slab pages. + fn shrink_slab(&mut self) { + // The SlabPageSlab uses SlabPages on its own and freeing a SlabPage can empty another SlabPage. + while self.common.free_pages > 1 { + let slab_page = self.common.free_one_page(); + self.common.deallocate_slot(VirtAddr::from(slab_page)); + } + } + + /// Allocates a slot in the slab. + /// + /// # Returns + /// + /// Result containing a pointer to the allocated [`SlabPage`] or an `AllocError`. + fn allocate(&mut self) -> Result<*mut SlabPage, AllocError> { + self.grow_slab()?; + Ok(self.common.allocate_slot().as_mut_ptr()) + } + + /// Deallocates a slab page, freeing the associated memory. + /// + /// # Arguments + /// + /// * `slab_page` - Pointer to the [`SlabPage`] to deallocate. 
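allocate_slot() and deallocate_slot() above both walk the chain of slab pages that add_slab_page() links together. This standalone toy keeps the pages in a Vec and uses indices for the next pointers, purely to show the walk-until-a-page-has-room shape; the real code chases raw virtual addresses and asserts up front that a free slot exists.

struct ToySlabPage {
    free: u16,
    next: Option<usize>, // index of the next page in the chain (None = end)
}

fn allocate_slot(pages: &mut [ToySlabPage], head: usize) -> Option<usize> {
    let mut idx = head;
    loop {
        let page = &mut pages[idx];
        if page.free > 0 {
            page.free -= 1;
            return Some(idx); // the real code returns a VirtAddr inside this page
        }
        idx = page.next?; // ran past the last page: allocation fails
    }
}

fn main() {
    let mut pages = vec![
        ToySlabPage { free: 0, next: Some(1) },
        ToySlabPage { free: 2, next: None },
    ];
    assert_eq!(allocate_slot(&mut pages, 0), Some(1));
    assert_eq!(pages[1].free, 1);
}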
+ fn deallocate(&mut self, slab_page: *mut SlabPage) { + self.common.deallocate_slot(VirtAddr::from(slab_page)); + self.shrink_slab(); + } +} + +/// Represents a slab allocator for fixed-size objects. +#[derive(Debug, Default)] +struct Slab { + common: SlabCommon, +} + +impl Slab { + const fn new() -> Self { + Self { + common: SlabCommon::new(), + } + } + + /// Initialize the [`Slab`] instance + fn init(&mut self) -> Result<(), AllocError> { + self.common.init() + } + + fn grow_slab(&mut self) -> Result<(), AllocError> { + if self.common.capacity == 0 { + return self.init(); + } + + if self.common.free != 0 { + return Ok(()); + } + + let slab_page_ptr = SLAB_PAGE_SLAB.lock().allocate()?; + let page_ptr = slab_page_ptr.cast::>(); + unsafe { page_ptr.write(SlabPage::::new()) }; + let page = unsafe { &mut *page_ptr }; + if let Err(e) = page.init() { + SLAB_PAGE_SLAB.lock().deallocate(slab_page_ptr); + return Err(e); + } + + self.common.add_slab_page(page); + Ok(()) + } + + fn shrink_slab(&mut self) { + if self.common.free_pages <= 1 || 2 * self.common.free < self.common.capacity { + return; + } + + let slab_page = self.common.free_one_page(); + SLAB_PAGE_SLAB.lock().deallocate(slab_page.cast()); + } + + fn allocate(&mut self) -> Result { + self.grow_slab()?; + Ok(self.common.allocate_slot()) + } + + fn deallocate(&mut self, vaddr: VirtAddr) { + self.common.deallocate_slot(vaddr); + self.shrink_slab(); + } +} + +/// Static spinlock-protected instance of [`SlabPageSlab`] representing the +/// slab page allocator. +static SLAB_PAGE_SLAB: SpinLock = SpinLock::new(SlabPageSlab::new()); + +/// Represents a simple virtual-to-physical memory allocator ([`SvsmAllocator`]) +/// implementing the [`GlobalAlloc`] trait. +/// +/// This allocator uses slab allocation for fixed-size objects and falls +/// back to page allocation for larger objects. +#[derive(Debug, Default)] +pub struct SvsmAllocator { + slab32: SpinLock>, + slab64: SpinLock>, + slab128: SpinLock>, + slab256: SpinLock>, + slab512: SpinLock>, + slab1024: SpinLock>, + slab2048: SpinLock>, +} + +impl SvsmAllocator { + /// Creates a new instance of [`SvsmAllocator`] with initialized slab + /// allocators. + pub const fn new() -> Self { + Self { + slab32: SpinLock::new(Slab::new()), + slab64: SpinLock::new(Slab::new()), + slab128: SpinLock::new(Slab::new()), + slab256: SpinLock::new(Slab::new()), + slab512: SpinLock::new(Slab::new()), + slab1024: SpinLock::new(Slab::new()), + slab2048: SpinLock::new(Slab::new()), + } + } + + fn allocate(&self, size: usize) -> Option> { + let size = size.checked_next_power_of_two()?; + match size { + ..=32 => Some(self.slab32.lock().allocate()), + 64 => Some(self.slab64.lock().allocate()), + 128 => Some(self.slab128.lock().allocate()), + 256 => Some(self.slab256.lock().allocate()), + 512 => Some(self.slab512.lock().allocate()), + 1024 => Some(self.slab1024.lock().allocate()), + 2048 => Some(self.slab2048.lock().allocate()), + _ => None, + } + } + + fn deallocate(&self, addr: VirtAddr, size: usize) -> Option<()> { + let size = size.checked_next_power_of_two()?; + match size { + ..=32 => self.slab32.lock().deallocate(addr), + 64 => self.slab64.lock().deallocate(addr), + 128 => self.slab128.lock().deallocate(addr), + 256 => self.slab256.lock().deallocate(addr), + 512 => self.slab512.lock().deallocate(addr), + 1024 => self.slab1024.lock().deallocate(addr), + 2048 => self.slab2048.lock().deallocate(addr), + _ => return None, + } + + Some(()) + } + + /// Resets the internal state. 
This is equivalent to reassigning `self` + /// with a newly created [`SvsmAllocator`] with `Self::new()`. + #[cfg(all(not(test_in_svsm), any(test, fuzzing)))] + fn reset(&self) { + *self.slab32.lock() = Slab::new(); + *self.slab64.lock() = Slab::new(); + *self.slab128.lock() = Slab::new(); + *self.slab256.lock() = Slab::new(); + *self.slab512.lock() = Slab::new(); + *self.slab1024.lock() = Slab::new(); + *self.slab2048.lock() = Slab::new(); + } +} + +unsafe impl GlobalAlloc for SvsmAllocator { + /// Allocates memory based on the specified layout. + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + let size = layout.size(); + let ret = match self.allocate(size) { + Some(v) => v.map_err(Into::into), + None => { + let order = get_order(size); + if order >= MAX_ORDER { + return ptr::null_mut(); + } + allocate_pages(order) + } + }; + ret.map_or_else(|_| ptr::null_mut(), |addr| addr.as_mut_ptr::()) + } + + /// Deallocates memory based on the specified pointer and layout. + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + let virt_addr = VirtAddr::from(ptr); + let size = layout.size(); + + let info = { + let mem = ROOT_MEM.lock(); + let pfn = mem.get_pfn(virt_addr).expect("Freeing unknown memory"); + mem.read_page_info(pfn) + }; + + match info { + PageInfo::Allocated(_ai) => { + free_page(virt_addr); + } + PageInfo::Slab(_) => { + self.deallocate(virt_addr, size).expect("Invalid page info"); + } + _ => { + panic!("Freeing memory on unsupported page type"); + } + } + } +} + +#[cfg_attr(any(target_os = "none"), global_allocator)] +static ALLOCATOR: SvsmAllocator = SvsmAllocator::new(); + +/// Initializes the root memory region with the specified physical start +/// address, virtual start address, and page count. +pub fn root_mem_init(pstart: PhysAddr, vstart: VirtAddr, page_count: usize) { + { + let mut region = ROOT_MEM.lock(); + region.start_phys = pstart; + region.start_virt = vstart; + region.page_count = page_count; + region.init_memory(); + // drop lock here so slab initialization does not deadlock + } + + SLAB_PAGE_SLAB + .lock() + .init() + .expect("Failed to initialize SLAB_PAGE_SLAB"); +} + +#[cfg(any(test, fuzzing))] +/// A global lock on global memory. Should only be acquired via +/// [`TestRootMem::setup()`]. +static TEST_ROOT_MEM_LOCK: SpinLock<()> = SpinLock::new(()); + +pub const MIN_ALIGN: usize = 32; + +pub fn layout_from_size(size: usize) -> Layout { + let align: usize = { + if (size % PAGE_SIZE) == 0 { + PAGE_SIZE + } else { + MIN_ALIGN + } + }; + Layout::from_size_align(size, align).unwrap() +} + +pub fn layout_from_ptr(ptr: *mut u8) -> Option { + let va = VirtAddr::from(ptr); + + let root = ROOT_MEM.lock(); + let pfn = root.get_pfn(va).ok()?; + let info = root.read_page_info(pfn); + + match info { + PageInfo::Allocated(ai) => { + let base: usize = 2; + let size: usize = base.pow(ai.order as u32) * PAGE_SIZE; + Some(Layout::from_size_align(size, PAGE_SIZE).unwrap()) + } + PageInfo::Slab(si) => { + let size = si.item_size as usize; + Some(Layout::from_size_align(size, size).unwrap()) + } + _ => None, + } +} + +#[cfg(test)] +pub const DEFAULT_TEST_MEMORY_SIZE: usize = 16usize * 1024 * 1024; + +/// A dummy struct to acquire a lock over global memory for tests. +#[cfg(any(test, fuzzing))] +#[derive(Debug)] +#[expect(dead_code)] +pub struct TestRootMem<'a>(LockGuard<'a, ()>); + +#[cfg(any(test, fuzzing))] +impl TestRootMem<'_> { + #[cfg(test_in_svsm)] + /// Sets up a test environment, returning a guard to ensure memory is + /// held for the test's duration. 
This test function is intended to + /// called inside a running SVSM. + /// + /// # Returns + /// + /// A guard that ensures the memory lock is held during the test. + #[must_use = "memory guard must be held for the whole test"] + pub fn setup(_size: usize) -> Self { + // We do not need to set up root memory if running inside the SVSM. + Self(TEST_ROOT_MEM_LOCK.lock()) + } + + /// Sets up a test environment, returning a guard to ensure memory is + /// held for the test's duration. This function does not run inside + /// the SVSM. + /// + /// # Returns + /// + /// A guard that ensures the memory lock is held during the test. + #[cfg(not(test_in_svsm))] + #[must_use = "memory guard must be held for the whole test"] + pub fn setup(size: usize) -> Self { + extern crate alloc; + use alloc::alloc::{alloc, handle_alloc_error}; + + let layout = Layout::from_size_align(size, PAGE_SIZE) + .unwrap() + .pad_to_align(); + let ptr = unsafe { alloc(layout) }; + if ptr.is_null() { + handle_alloc_error(layout); + } else if ptr as usize & (PAGE_SIZE - 1) != 0 { + panic!("test memory region allocation not aligned to page size"); + } + + let page_count = layout.size() / PAGE_SIZE; + let guard = Self(TEST_ROOT_MEM_LOCK.lock()); + let vaddr = VirtAddr::from(ptr); + let paddr = PhysAddr::from(vaddr.bits()); // Identity mapping + root_mem_init(paddr, vaddr, page_count); + guard + } +} + +#[cfg(all(not(test_in_svsm), any(test, fuzzing)))] +impl Drop for TestRootMem<'_> { + /// If running tests in userspace, destroy root memory before + /// dropping the lock over it. + fn drop(&mut self) { + extern crate alloc; + use alloc::alloc::dealloc; + + let mut root_mem = ROOT_MEM.lock(); + let layout = Layout::from_size_align(root_mem.page_count * PAGE_SIZE, PAGE_SIZE).unwrap(); + unsafe { dealloc(root_mem.start_virt.as_mut_ptr::(), layout) }; + *root_mem = MemoryRegion::new(); + + // Reset the Slabs + *SLAB_PAGE_SLAB.lock() = SlabPageSlab::new(); + ALLOCATOR.reset(); + } +} + +#[cfg(test)] +mod test { + extern crate alloc; + use super::*; + use crate::mm::PageBox; + use alloc::boxed::Box; + use core::sync::atomic::{AtomicUsize, Ordering}; + + /// Tests the setup of the root memory + #[test] + fn test_root_mem_setup() { + let test_mem_lock = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + drop(test_mem_lock); + } + + /// Tests the allocation and deallocation of a single page, verifying the + /// memory information. + #[test] + fn test_page_alloc_one() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let mut root_mem = ROOT_MEM.lock(); + + let info_before = root_mem.memory_info(); + let page = root_mem.allocate_page().unwrap(); + assert!(!page.is_null()); + assert_ne!(info_before.free_pages, root_mem.memory_info().free_pages); + root_mem.free_page(page); + assert_eq!(info_before.free_pages, root_mem.memory_info().free_pages); + } + + #[test] + #[cfg_attr(test_in_svsm, ignore = "FIXME")] + /// Allocate and free all available compound pages, verify that memory_info() + /// reflects it. 
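Before that compound-page test, a quick standalone sketch of the size-class routing that SvsmAllocator::allocate() and the GlobalAlloc::alloc() fallback above implement: round the request up to a power of two, serve anything up to 2048 bytes from the matching slab, and push larger requests to the page allocator by order. PAGE_SIZE, the MAX_ORDER value and the helper names are assumptions for this toy; only the selection logic is modeled.

const PAGE_SIZE: usize = 4096; // assumed 4 KiB pages
const MAX_ORDER: usize = 6;    // assumed limit, as in the null-return check above

#[derive(Debug, PartialEq)]
enum Route {
    Slab(usize),  // slab object size in bytes
    Pages(usize), // page order for allocate_pages()
    TooLarge,     // alloc() would return null
}

fn route(size: usize) -> Route {
    match size.next_power_of_two() {
        s @ ..=2048 => Route::Slab(s.max(32)), // 32 bytes is the smallest slab
        _ => {
            // Smallest order whose block covers `size` (the get_order() idea).
            let order = size.div_ceil(PAGE_SIZE).next_power_of_two().trailing_zeros() as usize;
            if order < MAX_ORDER {
                Route::Pages(order)
            } else {
                Route::TooLarge
            }
        }
    }
}

fn main() {
    assert_eq!(route(1), Route::Slab(32));
    assert_eq!(route(100), Route::Slab(128));
    assert_eq!(route(2048), Route::Slab(2048));
    assert_eq!(route(4096), Route::Pages(0));
    assert_eq!(route(9000), Route::Pages(2)); // needs 3 pages, rounded up to 4
    assert_eq!(route(PAGE_SIZE << MAX_ORDER), Route::TooLarge);
}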
+ fn test_page_alloc_all_compound() { + extern crate alloc; + use alloc::vec::Vec; + + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let mut root_mem = ROOT_MEM.lock(); + + let info_before = root_mem.memory_info(); + let mut allocs: [Vec; MAX_ORDER] = Default::default(); + for (o, alloc) in allocs.iter_mut().enumerate().take(MAX_ORDER) { + for _i in 0..info_before.free_pages[o] { + let pages = root_mem.allocate_pages(o).unwrap(); + assert!(!pages.is_null()); + alloc.push(pages); + } + } + let info_after = root_mem.memory_info(); + for o in 0..MAX_ORDER { + assert_eq!(info_after.free_pages[o], 0); + } + + for alloc in allocs.iter().take(MAX_ORDER) { + for pages in &alloc[..] { + root_mem.free_page(*pages); + } + } + assert_eq!(info_before.free_pages, root_mem.memory_info().free_pages); + } + + #[test] + #[cfg_attr(test_in_svsm, ignore = "FIXME")] + /// Allocate and free all available 4k pages, verify that memory_info() + /// reflects it. + fn test_page_alloc_all_single() { + extern crate alloc; + use alloc::vec::Vec; + + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let mut root_mem = ROOT_MEM.lock(); + + let info_before = root_mem.memory_info(); + let mut allocs: Vec = Vec::new(); + for o in 0..MAX_ORDER { + for _i in 0..info_before.free_pages[o] { + for _j in 0..(1usize << o) { + let page = root_mem.allocate_page().unwrap(); + assert!(!page.is_null()); + allocs.push(page); + } + } + } + let info_after = root_mem.memory_info(); + for o in 0..MAX_ORDER { + assert_eq!(info_after.free_pages[o], 0); + } + + for page in &allocs[..] { + root_mem.free_page(*page); + } + assert_eq!(info_before.free_pages, root_mem.memory_info().free_pages); + } + + #[test] + #[cfg_attr(test_in_svsm, ignore = "FIXME")] + /// Allocate and free all available compound pages, verify that any subsequent + /// allocation fails. + fn test_page_alloc_oom() { + extern crate alloc; + use alloc::vec::Vec; + + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let mut root_mem = ROOT_MEM.lock(); + + let info_before = root_mem.memory_info(); + let mut allocs: [Vec; MAX_ORDER] = Default::default(); + for (o, alloc) in allocs.iter_mut().enumerate().take(MAX_ORDER) { + for _i in 0..info_before.free_pages[o] { + let pages = root_mem.allocate_pages(o).unwrap(); + assert!(!pages.is_null()); + alloc.push(pages); + } + } + let info_after = root_mem.memory_info(); + for o in 0..MAX_ORDER { + assert_eq!(info_after.free_pages[o], 0); + } + + let page = root_mem.allocate_page(); + if page.is_ok() { + panic!("unexpected page allocation success after memory exhaustion"); + } + + for alloc in allocs.iter().take(MAX_ORDER) { + for pages in &alloc[..] 
{ + root_mem.free_page(*pages); + } + } + assert_eq!(info_before.free_pages, root_mem.memory_info().free_pages); + } + + #[test] + fn test_page_file() { + let _mem_lock = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let mut root_mem = ROOT_MEM.lock(); + + // Allocate page and check ref-count + let vaddr = root_mem.allocate_file_page().unwrap(); + let pfn = root_mem.get_pfn(vaddr).unwrap(); + let info = root_mem.read_page_info(pfn); + + assert!(matches!(info, PageInfo::File(ref fi) if fi.ref_count == 1)); + + // Get another reference and check ref-count + root_mem.get_file_page(vaddr).expect("Not a file page"); + let info = root_mem.read_page_info(pfn); + + assert!(matches!(info, PageInfo::File(ref fi) if fi.ref_count == 2)); + + // Drop reference and check ref-count + root_mem.put_file_page(vaddr).expect("Not a file page"); + let info = root_mem.read_page_info(pfn); + + assert!(matches!(info, PageInfo::File(ref fi) if fi.ref_count == 1)); + + // Drop last reference and check if page is released + root_mem.put_file_page(vaddr).expect("Not a file page"); + let info = root_mem.read_page_info(pfn); + + assert!(matches!(info, PageInfo::Free { .. })); + } + + const TEST_SLAB_SIZES: [usize; 7] = [32, 64, 128, 256, 512, 1024, 2048]; + + #[test] + /// Allocate and free a couple of objects for each slab size. + fn test_slab_alloc_free_many() { + extern crate alloc; + use alloc::vec::Vec; + + let _mem_lock = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + // Run it twice to make sure some objects will get freed and allocated again. + for _i in 0..2 { + let mut allocs: [Vec<*mut u8>; TEST_SLAB_SIZES.len()] = Default::default(); + let mut j = 0; + for size in TEST_SLAB_SIZES { + let layout = Layout::from_size_align(size, size).unwrap().pad_to_align(); + assert_eq!(layout.size(), size); + + // Allocate four pages worth of objects from each Slab. + let n = (4 * PAGE_SIZE + size - 1) / size; + for _k in 0..n { + let p = unsafe { ALLOCATOR.alloc(layout) }; + assert_ne!(p, ptr::null_mut()); + allocs[j].push(p); + } + j += 1; + } + + j = 0; + for size in TEST_SLAB_SIZES { + let layout = Layout::from_size_align(size, size).unwrap().pad_to_align(); + assert_eq!(layout.size(), size); + + for p in &allocs[j][..] { + unsafe { ALLOCATOR.dealloc(*p, layout) }; + } + j += 1; + } + } + } + + #[test] + #[cfg_attr(test_in_svsm, ignore = "FIXME")] + /// Allocate enough objects so that the SlabPageSlab will need a SlabPage for + /// itself twice. + fn test_slab_page_slab_for_self() { + extern crate alloc; + use alloc::vec::Vec; + + let _mem_lock = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + const OBJECT_SIZE: usize = TEST_SLAB_SIZES[0]; + const OBJECTS_PER_PAGE: usize = PAGE_SIZE / OBJECT_SIZE; + + const SLAB_PAGES_PER_PAGE: usize = PAGE_SIZE / SLAB_PAGE_SIZE as usize; + + let layout = Layout::from_size_align(OBJECT_SIZE, OBJECT_SIZE) + .unwrap() + .pad_to_align(); + assert_eq!(layout.size(), OBJECT_SIZE); + + let mut allocs: Vec<*mut u8> = Vec::new(); + for _i in 0..(2 * SLAB_PAGES_PER_PAGE * OBJECTS_PER_PAGE) { + let p = unsafe { ALLOCATOR.alloc(layout) }; + assert_ne!(p, ptr::null_mut()); + assert_ne!(SLAB_PAGE_SLAB.lock().common.capacity, 0); + allocs.push(p); + } + + for p in allocs { + unsafe { ALLOCATOR.dealloc(p, layout) }; + } + + assert_ne!(SLAB_PAGE_SLAB.lock().common.free, 0); + assert!(SLAB_PAGE_SLAB.lock().common.free_pages < 2); + } + + #[test] + #[cfg_attr(test_in_svsm, ignore = "FIXME")] + /// Allocate enough objects to hit an OOM situation and verify null gets + /// returned at some point. 
+ fn test_slab_oom() { + extern crate alloc; + use alloc::vec::Vec; + + const TEST_MEMORY_SIZE: usize = 256 * PAGE_SIZE; + let _mem_lock = TestRootMem::setup(TEST_MEMORY_SIZE); + + const OBJECT_SIZE: usize = TEST_SLAB_SIZES[0]; + let layout = Layout::from_size_align(OBJECT_SIZE, OBJECT_SIZE) + .unwrap() + .pad_to_align(); + assert_eq!(layout.size(), OBJECT_SIZE); + + let mut allocs: Vec<*mut u8> = Vec::new(); + let mut null_seen = false; + for _i in 0..((TEST_MEMORY_SIZE + OBJECT_SIZE - 1) / OBJECT_SIZE) { + let p = unsafe { ALLOCATOR.alloc(layout) }; + if p.is_null() { + null_seen = true; + break; + } + allocs.push(p); + } + + if !null_seen { + panic!("unexpected slab allocation success after memory exhaustion"); + } + + for p in allocs { + unsafe { ALLOCATOR.dealloc(p, layout) }; + } + } + + /// Helper to assert that a `PageBox` is properly dropped. + fn check_drop_page(page: PageBox) { + let vaddr = page.vaddr(); + { + let mem = ROOT_MEM.lock(); + let pfn = mem.get_pfn(vaddr).unwrap(); + let info = mem.read_page_info(pfn); + assert!(matches!(info, PageInfo::Allocated(..))); + } + drop(page); + { + let mem = ROOT_MEM.lock(); + let pfn = mem.get_pfn(vaddr).unwrap(); + let info = mem.read_page_info(pfn); + assert!(matches!(info, PageInfo::Free { .. })); + } + } + + #[test] + fn test_drop_pagebox() { + // Check that the inner contents of the [`PageBox`] are only dropped + // once. + static DROPPED: AtomicUsize = AtomicUsize::new(0); + + struct Thing(Box); + + impl Drop for Thing { + fn drop(&mut self) { + DROPPED.fetch_add(1, Ordering::Relaxed); + } + } + + let _mem_lock = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + let page = PageBox::try_new(Thing(Box::new(44))).unwrap(); + assert_eq!(*page.0, 44); + + check_drop_page(page); + assert_eq!(DROPPED.load(Ordering::Relaxed), 1); + } + + #[test] + fn test_drop_pagebox_slice() { + use core::num::NonZeroUsize; + + const NUM_ITEMS: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(8192) }; + static DROPPED: AtomicUsize = AtomicUsize::new(0); + + #[derive(Clone)] + struct Thing(Box); + + impl Drop for Thing { + fn drop(&mut self) { + DROPPED.fetch_add(1, Ordering::Relaxed); + } + } + + let _mem_lock = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + + let slice = PageBox::try_new_slice(Thing(Box::new(43)), NUM_ITEMS).unwrap(); + + // Check that contents match + for item in slice.iter() { + assert_eq!(*item.0, 43); + } + assert_eq!(slice.len(), NUM_ITEMS.get()); + + check_drop_page(slice); + // All the items in the slice must have dropped, plus the original + // value that items were cloned out of. 
+ assert_eq!(DROPPED.load(Ordering::Relaxed), NUM_ITEMS.get() + 1); + } +} diff --git a/stage2/src/mm/guestmem.rs b/stage2/src/mm/guestmem.rs new file mode 100644 index 000000000..a817f2d02 --- /dev/null +++ b/stage2/src/mm/guestmem.rs @@ -0,0 +1,429 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; + +use crate::address::{Address, VirtAddr}; +use crate::cpu::x86::smap::{clac, stac}; +use crate::error::SvsmError; +use crate::insn_decode::{InsnError, InsnMachineMem}; +use crate::mm::{USER_MEM_END, USER_MEM_START}; +use alloc::string::String; +use alloc::vec::Vec; +use core::arch::asm; +use core::ffi::c_char; +use core::mem::{size_of, MaybeUninit}; +use syscall::PATH_MAX; +use zerocopy::FromBytes; + +#[inline] +pub fn read_u8(v: VirtAddr) -> Result { + let mut rcx: u64; + let mut val: u64; + + unsafe { + asm!("1: movb ({0}), %al", + " xorq %rcx, %rcx", + "2:", + ".pushsection \"__exception_table\",\"a\"", + ".balign 16", + ".quad (1b)", + ".quad (2b)", + ".popsection", + in(reg) v.bits(), + out("rax") val, + out("rcx") rcx, + options(att_syntax, nostack)); + } + + let ret: u8 = (val & 0xff) as u8; + if rcx == 0 { + Ok(ret) + } else { + Err(SvsmError::InvalidAddress) + } +} + +/// Writes 1 byte at a virtual address. +/// +/// # Safety +/// +/// The caller must verify not to corrupt arbitrary memory, as this function +/// doesn't make any checks in that regard. +/// +/// # Returns +/// +/// Returns an error if the specified address is not mapped or is not mapped +/// with the appropriate write permissions. +#[inline] +pub unsafe fn write_u8(v: VirtAddr, val: u8) -> Result<(), SvsmError> { + let mut rcx: u64; + + unsafe { + asm!("1: movb %al, ({0})", + " xorq %rcx, %rcx", + "2:", + ".pushsection \"__exception_table\",\"a\"", + ".balign 16", + ".quad (1b)", + ".quad (2b)", + ".popsection", + in(reg) v.bits(), + in("rax") val as u64, + out("rcx") rcx, + options(att_syntax, nostack)); + } + + if rcx == 0 { + Ok(()) + } else { + Err(SvsmError::InvalidAddress) + } +} + +#[expect(dead_code)] +#[inline] +unsafe fn read_u16(v: VirtAddr) -> Result { + let mut rcx: u64; + let mut val: u64; + + asm!("1: movw ({0}), {1}", + " xorq %rcx, %rcx", + "2:", + ".pushsection \"__exception_table\",\"a\"", + ".balign 16", + ".quad (1b)", + ".quad (2b)", + ".popsection", + in(reg) v.bits(), + out(reg) val, + out("rcx") rcx, + options(att_syntax, nostack)); + + let ret: u16 = (val & 0xffff) as u16; + if rcx == 0 { + Ok(ret) + } else { + Err(SvsmError::InvalidAddress) + } +} + +#[expect(dead_code)] +#[inline] +unsafe fn read_u32(v: VirtAddr) -> Result { + let mut rcx: u64; + let mut val: u64; + + asm!("1: movl ({0}), {1}", + " xorq %rcx, %rcx", + "2:", + ".pushsection \"__exception_table\",\"a\"", + ".balign 16", + ".quad (1b)", + ".quad (2b)", + ".popsection", + in(reg) v.bits(), + out(reg) val, + out("rcx") rcx, + options(att_syntax, nostack)); + + let ret: u32 = (val & 0xffffffff) as u32; + if rcx == 0 { + Ok(ret) + } else { + Err(SvsmError::InvalidAddress) + } +} + +#[expect(dead_code)] +#[inline] +unsafe fn read_u64(v: VirtAddr) -> Result { + let mut rcx: u64; + let mut val: u64; + + asm!("1: movq ({0}), {1}", + " xorq %rcx, %rcx", + "2:", + ".pushsection \"__exception_table\",\"a\"", + ".balign 16", + ".quad (1b)", + ".quad (2b)", + ".popsection", + in(reg) v.bits(), + out(reg) val, + out("rcx") rcx, + options(att_syntax, nostack)); + + if rcx == 0 { + Ok(val) + } else { + Err(SvsmError::InvalidAddress) + } +} + 
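All of the accessors above lean on the __exception_table fixup: a faulting guest access resumes at the recovery label with a non-zero rcx, so the caller sees Err(SvsmError::InvalidAddress) instead of a crash. The standalone stand-in below models that contract in safe Rust, with a bounds check playing the role of the page fault plus fixup; the types are invented for the sketch.

#[derive(Debug, PartialEq)]
struct InvalidAddress;

struct FakeGuestMemory {
    base: usize,    // lowest "mapped" address
    bytes: Vec<u8>,
}

impl FakeGuestMemory {
    // Either returns the byte or an error; it can never fault.
    fn read_u8(&self, addr: usize) -> Result<u8, InvalidAddress> {
        let off = addr.checked_sub(self.base).ok_or(InvalidAddress)?;
        self.bytes.get(off).copied().ok_or(InvalidAddress)
    }
}

fn main() {
    let mem = FakeGuestMemory { base: 0x1000, bytes: vec![0xAA, 0xBB] };
    assert_eq!(mem.read_u8(0x1001), Ok(0xBB));
    assert_eq!(mem.read_u8(0x2000), Err(InvalidAddress)); // "unmapped" address
}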
+#[inline] +unsafe fn do_movsb(src: *const T, dst: *mut T) -> Result<(), SvsmError> { + let size: usize = size_of::(); + let mut rcx: u64; + + asm!("1:cld + rep movsb + 2: + .pushsection \"__exception_table\",\"a\" + .balign 16 + .quad (1b) + .quad (2b) + .popsection", + inout("rsi") src => _, + inout("rdi") dst => _, + inout("rcx") size => rcx, + options(att_syntax, nostack)); + + if rcx == 0 { + Ok(()) + } else { + Err(SvsmError::InvalidAddress) + } +} + +#[derive(Debug)] +pub struct GuestPtr { + ptr: *mut T, +} + +impl GuestPtr { + #[inline] + pub fn new(v: VirtAddr) -> Self { + Self { + ptr: v.as_mut_ptr::(), + } + } + + #[inline] + pub const fn from_ptr(p: *mut T) -> Self { + Self { ptr: p } + } + + /// # Safety + /// + /// The caller must verify not to read arbitrary memory, as this function + /// doesn't make any checks in that regard. + /// + /// # Returns + /// + /// Returns an error if the specified address is not mapped. + #[inline] + pub unsafe fn read(&self) -> Result { + let mut buf = MaybeUninit::::uninit(); + + unsafe { + do_movsb(self.ptr, buf.as_mut_ptr())?; + Ok(buf.assume_init()) + } + } + + /// # Safety + /// + /// The caller must verify not to corrupt arbitrary memory, as this function + /// doesn't make any checks in that regard. + /// + /// # Returns + /// + /// Returns an error if the specified address is not mapped or is not mapped + /// with the appropriate write permissions. + #[inline] + pub unsafe fn write(&self, buf: T) -> Result<(), SvsmError> { + unsafe { do_movsb(&buf, self.ptr) } + } + + /// # Safety + /// + /// The caller must verify not to corrupt arbitrary memory, as this function + /// doesn't make any checks in that regard. + /// + /// # Returns + /// + /// Returns an error if the specified address is not mapped or is not mapped + /// with the appropriate write permissions. + #[inline] + pub unsafe fn write_ref(&self, buf: &T) -> Result<(), SvsmError> { + unsafe { do_movsb(buf, self.ptr) } + } + + #[inline] + pub const fn cast(&self) -> GuestPtr { + GuestPtr::from_ptr(self.ptr.cast()) + } + + #[inline] + pub fn offset(&self, count: isize) -> Self { + GuestPtr::from_ptr(self.ptr.wrapping_offset(count)) + } +} + +impl InsnMachineMem for GuestPtr { + type Item = T; + + /// Safety: See the GuestPtr's read() method documentation for safety requirements. + unsafe fn mem_read(&self) -> Result { + self.read().map_err(|_| InsnError::MemRead) + } + + /// Safety: See the GuestPtr's write() method documentation for safety requirements. 
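Further down, UserPtr brackets every user-memory access with a UserAccessGuard whose constructor runs stac() and whose Drop runs clac(), so SMAP protection is only relaxed for the duration of the copy. The same RAII shape, sketched standalone with an atomic flag standing in for RFLAGS.AC:

use std::sync::atomic::{AtomicBool, Ordering};

// Stand-in for the CPU's SMAP state; false = user access blocked.
static USER_ACCESS_ALLOWED: AtomicBool = AtomicBool::new(false);

struct AccessGuard;

impl AccessGuard {
    fn new() -> Self {
        USER_ACCESS_ALLOWED.store(true, Ordering::SeqCst); // "stac"
        AccessGuard
    }
}

impl Drop for AccessGuard {
    fn drop(&mut self) {
        USER_ACCESS_ALLOWED.store(false, Ordering::SeqCst); // "clac"
    }
}

fn copy_from_user_stub() -> u8 {
    let _guard = AccessGuard::new();
    assert!(USER_ACCESS_ALLOWED.load(Ordering::SeqCst));
    42 // _guard drops when this returns, re-blocking user access
}

fn main() {
    assert!(!USER_ACCESS_ALLOWED.load(Ordering::SeqCst));
    assert_eq!(copy_from_user_stub(), 42);
    assert!(!USER_ACCESS_ALLOWED.load(Ordering::SeqCst));
}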
+ unsafe fn mem_write(&mut self, data: Self::Item) -> Result<(), InsnError> { + self.write(data).map_err(|_| InsnError::MemWrite) + } +} + +struct UserAccessGuard; + +impl UserAccessGuard { + pub fn new() -> Self { + stac(); + Self + } +} + +impl Drop for UserAccessGuard { + fn drop(&mut self) { + clac(); + } +} + +#[derive(Debug)] +pub struct UserPtr { + guest_ptr: GuestPtr, +} + +impl UserPtr { + #[inline] + pub fn new(v: VirtAddr) -> Self { + Self { + guest_ptr: GuestPtr::new(v), + } + } + + fn check_bounds(&self) -> bool { + let v = VirtAddr::from(self.guest_ptr.ptr); + + (USER_MEM_START..USER_MEM_END).contains(&v) + && (USER_MEM_START..USER_MEM_END).contains(&(v + size_of::())) + } + + #[inline] + pub fn read(&self) -> Result + where + T: FromBytes, + { + if !self.check_bounds() { + return Err(SvsmError::InvalidAddress); + } + let _guard = UserAccessGuard::new(); + unsafe { self.guest_ptr.read() } + } + + #[inline] + pub fn write(&self, buf: T) -> Result<(), SvsmError> { + self.write_ref(&buf) + } + + #[inline] + pub fn write_ref(&self, buf: &T) -> Result<(), SvsmError> { + if !self.check_bounds() { + return Err(SvsmError::InvalidAddress); + } + let _guard = UserAccessGuard::new(); + unsafe { self.guest_ptr.write_ref(buf) } + } + + #[inline] + pub const fn cast(&self) -> UserPtr { + UserPtr { + guest_ptr: self.guest_ptr.cast(), + } + } + + #[inline] + pub fn offset(&self, count: isize) -> UserPtr { + UserPtr { + guest_ptr: self.guest_ptr.offset(count), + } + } +} + +impl UserPtr { + /// Reads a null-terminated C string from the user space. + /// Allocates memory for the string and returns a `String`. + pub fn read_c_string(&self) -> Result { + let mut buffer = Vec::new(); + + for offset in 0..PATH_MAX { + let current_ptr = self.offset(offset as isize); + let char_result = current_ptr.read()?; + match char_result { + 0 => return String::from_utf8(buffer).map_err(|_| SvsmError::InvalidUtf8), + c => buffer.push(c as u8), + } + } + Err(SvsmError::InvalidBytes) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg_attr(miri, ignore = "inline assembly")] + fn test_read_u8_valid_address() { + // Create a region to read from + let test_buffer: [u8; 6] = [0; 6]; + let test_address = VirtAddr::from(test_buffer.as_ptr()); + + let result = read_u8(test_address).unwrap(); + + assert_eq!(result, test_buffer[0]); + } + + #[test] + #[cfg_attr(miri, ignore = "inline assembly")] + fn test_write_u8_valid_address() { + // Create a mutable region we can write into + let mut test_buffer: [u8; 6] = [0; 6]; + let test_address = VirtAddr::from(test_buffer.as_mut_ptr()); + let data_to_write = 0x42; + + // SAFETY: test_address points to the virtual address of test_buffer. 
+ unsafe { + write_u8(test_address, data_to_write).unwrap(); + } + + assert_eq!(test_buffer[0], data_to_write); + } + + #[test] + #[cfg_attr(miri, ignore = "inline assembly")] + fn test_read_15_bytes_valid_address() { + let test_buffer = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; + let test_addr = VirtAddr::from(test_buffer.as_ptr()); + let ptr: GuestPtr<[u8; 15]> = GuestPtr::new(test_addr); + // SAFETY: ptr points to test_buffer's virtual address + let result = unsafe { ptr.read().unwrap() }; + + assert_eq!(result, test_buffer); + } + + #[test] + #[cfg_attr(miri, ignore = "inline assembly")] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_read_invalid_address() { + let ptr: GuestPtr = GuestPtr::new(VirtAddr::new(0xDEAD_BEEF)); + // SAFETY: ptr points to an invalid virtual address (0xDEADBEEF is + // unmapped). ptr.read() will return an error but this is expected. + let err = unsafe { ptr.read() }; + assert!(err.is_err()); + } +} diff --git a/stage2/src/mm/mappings.rs b/stage2/src/mm/mappings.rs new file mode 100644 index 000000000..52075bdc7 --- /dev/null +++ b/stage2/src/mm/mappings.rs @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::VirtAddr; +use crate::error::SvsmError; +use crate::fs::FileHandle; +use crate::mm::vm::{Mapping, VMFileMapping, VMFileMappingFlags, VMalloc, VMR}; +use crate::task::current_task; + +use core::ops::Deref; + +extern crate alloc; +use alloc::sync::Arc; + +#[derive(Debug)] +pub struct VMMappingGuard<'a> { + vmr: &'a VMR, + start: VirtAddr, +} + +impl<'a> VMMappingGuard<'a> { + pub fn new(vmr: &'a VMR, start: VirtAddr) -> Self { + VMMappingGuard { vmr, start } + } +} + +impl Deref for VMMappingGuard<'_> { + type Target = VirtAddr; + + fn deref(&self) -> &VirtAddr { + &self.start + } +} + +impl Drop for VMMappingGuard<'_> { + fn drop(&mut self) { + self.vmr + .remove(self.start) + .expect("Fatal error: Failed to unmap region from MappingGuard"); + } +} + +pub fn create_file_mapping( + file: &FileHandle, + offset: usize, + size: usize, + flags: VMFileMappingFlags, +) -> Result, SvsmError> { + let file_mapping = VMFileMapping::new(file, offset, size, flags)?; + Ok(Arc::new(Mapping::new(file_mapping))) +} + +pub fn create_anon_mapping( + size: usize, + flags: VMFileMappingFlags, +) -> Result, SvsmError> { + let alloc = VMalloc::new(size, flags)?; + Ok(Arc::new(Mapping::new(alloc))) +} + +pub fn mmap_user( + addr: VirtAddr, + file: Option<&FileHandle>, + offset: usize, + size: usize, + flags: VMFileMappingFlags, +) -> Result { + current_task().mmap_user(addr, file, offset, size, flags) +} + +pub fn mmap_kernel( + addr: VirtAddr, + file: Option<&FileHandle>, + offset: usize, + size: usize, + flags: VMFileMappingFlags, +) -> Result { + current_task().mmap_kernel(addr, file, offset, size, flags) +} + +pub fn munmap_user(addr: VirtAddr) -> Result<(), SvsmError> { + current_task().munmap_user(addr) +} + +pub fn munmap_kernel(addr: VirtAddr) -> Result<(), SvsmError> { + current_task().munmap_kernel(addr) +} diff --git a/stage2/src/mm/memory.rs b/stage2/src/mm/memory.rs new file mode 100644 index 000000000..8e846bdb2 --- /dev/null +++ b/stage2/src/mm/memory.rs @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +extern crate alloc; + +use crate::address::{Address, PhysAddr}; +use crate::config::SvsmConfig; +use crate::cpu::percpu::PERCPU_VMSAS; +use 
crate::error::SvsmError; +use crate::locking::RWLock; +use crate::utils::MemoryRegion; +use alloc::vec::Vec; +use bootlib::kernel_launch::KernelLaunchInfo; + +use super::pagetable::LAUNCH_VMSA_ADDR; + +/// Global memory map containing various memory regions. +static MEMORY_MAP: RWLock>> = RWLock::new(Vec::new()); + +/// Initializes the global memory map based on the provided configuration +/// and kernel launch information. +/// +/// # Arguments +/// +/// * `config` - A reference to the [`SvsmConfig`] containing memory region +/// information. +/// * `launch_info` - A reference to the [`KernelLaunchInfo`] containing +/// information about the kernel region. +/// +/// # Returns +/// +/// Returns `Ok(())` if the memory map is successfully initialized, otherwise +/// returns an error of type `SvsmError`. +pub fn init_memory_map( + config: &SvsmConfig<'_>, + launch_info: &KernelLaunchInfo, +) -> Result<(), SvsmError> { + let mut regions = config.get_memory_regions()?; + let kernel_start = PhysAddr::from(launch_info.kernel_region_phys_start); + let kernel_end = PhysAddr::from(launch_info.kernel_region_phys_end); + let kernel_region = MemoryRegion::from_addresses(kernel_start, kernel_end); + + // Remove SVSM memory from guest memory map + let mut i = 0; + while i < regions.len() { + // Check if the region overlaps with SVSM memory. + let region = regions[i]; + if !region.overlap(&kernel_region) { + // Check the next region. + i += 1; + continue; + } + + // 1. Remove the region. + regions.remove(i); + + // 2. Insert a region up until the start of SVSM memory (if non-empty). + let region_before_start = region.start(); + let region_before_end = kernel_region.start(); + if region_before_start < region_before_end { + regions.insert( + i, + MemoryRegion::from_addresses(region_before_start, region_before_end), + ); + i += 1; + } + + // 3. Insert a region up after the end of SVSM memory (if non-empty). + let region_after_start = kernel_region.end(); + let region_after_end = region.end(); + if region_after_start < region_after_end { + regions.insert( + i, + MemoryRegion::from_addresses(region_after_start, region_after_end), + ); + i += 1; + } + } + + log::info!("Guest Memory Regions:"); + for r in regions.iter() { + log::info!(" {:018x}-{:018x}", r.start(), r.end()); + } + + let mut map = MEMORY_MAP.lock_write(); + *map = regions; + + Ok(()) +} + +pub fn write_guest_memory_map(config: &SvsmConfig<'_>) -> Result<(), SvsmError> { + // Supply the memory map to the guest if required by the configuration. + config.write_guest_memory_map(&MEMORY_MAP.lock_read()) +} + +/// Returns `true` if the provided physical address `paddr` is valid, i.e. +/// it is within the configured memory regions, otherwise returns `false`. +pub fn valid_phys_address(paddr: PhysAddr) -> bool { + let page_addr = paddr.page_align(); + + if PERCPU_VMSAS.exists(page_addr) { + return false; + } + if page_addr == LAUNCH_VMSA_ADDR { + return false; + } + + MEMORY_MAP + .lock_read() + .iter() + .any(|region| region.contains(paddr)) +} + +/// The starting address of the ISA range. +const ISA_RANGE_START: PhysAddr = PhysAddr::new(0xa0000); + +/// The ending address of the ISA range. +const ISA_RANGE_END: PhysAddr = PhysAddr::new(0x100000); + +/// Returns `true` if the provided physical address `paddr` is writable, +/// otherwise returns `false`. 
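The carving loop in init_memory_map() above replaces every region that overlaps the SVSM kernel range with the non-overlapping pieces before and after it. A self-contained model of that transformation, using half-open (start, end) tuples instead of MemoryRegion:

type Region = (u64, u64); // half-open [start, end)

fn carve(regions: &[Region], excl: Region) -> Vec<Region> {
    let mut out = Vec::new();
    for &(start, end) in regions {
        if end <= excl.0 || start >= excl.1 {
            out.push((start, end)); // no overlap: keep the region unchanged
            continue;
        }
        if start < excl.0 {
            out.push((start, excl.0)); // piece before the excluded range
        }
        if end > excl.1 {
            out.push((excl.1, end)); // piece after the excluded range
        }
    }
    out
}

fn main() {
    let regions = [(0x0, 0x8000_0000), (0x1_0000_0000, 0x1_4000_0000)];
    let kernel = (0x4000_0000, 0x5000_0000);
    assert_eq!(
        carve(&regions, kernel),
        vec![
            (0x0, 0x4000_0000),
            (0x5000_0000, 0x8000_0000),
            (0x1_0000_0000, 0x1_4000_0000),
        ]
    );
}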
+pub fn writable_phys_addr(paddr: PhysAddr) -> bool { + // The ISA range is not writable + if paddr >= ISA_RANGE_START && paddr < ISA_RANGE_END { + return false; + } + + valid_phys_address(paddr) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg_attr(test_in_svsm, ignore = "Offline testing")] + fn test_valid_phys_address() { + let start = PhysAddr::new(0x1000); + let end = PhysAddr::new(0x2000); + let region = MemoryRegion::from_addresses(start, end); + + MEMORY_MAP.lock_write().push(region); + + // Inside the region + assert!(valid_phys_address(PhysAddr::new(0x1500))); + // Outside the region + assert!(!valid_phys_address(PhysAddr::new(0x3000))); + } +} diff --git a/stage2/src/mm/mod.rs b/stage2/src/mm/mod.rs new file mode 100644 index 000000000..0f2f8e3e0 --- /dev/null +++ b/stage2/src/mm/mod.rs @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +pub mod address_space; +pub mod alloc; +pub mod guestmem; +pub mod mappings; +pub mod memory; +pub mod page_visibility; +mod pagebox; +pub mod pagetable; +pub mod ptguards; +pub mod validate; +pub mod virtualrange; +pub mod vm; + +pub use address_space::*; +pub use guestmem::GuestPtr; +pub use memory::{valid_phys_address, writable_phys_addr}; +pub use pagebox::*; +pub use ptguards::*; + +pub use pagetable::PageTablePart; + +pub use alloc::{allocate_file_page, PageRef}; + +pub use mappings::{mmap_kernel, mmap_user, munmap_kernel, munmap_user, VMMappingGuard}; diff --git a/stage2/src/mm/page_visibility.rs b/stage2/src/mm/page_visibility.rs new file mode 100644 index 000000000..f442cb3d9 --- /dev/null +++ b/stage2/src/mm/page_visibility.rs @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange (jlange@microsoft.com) + +use core::mem::MaybeUninit; +use core::ptr::NonNull; + +use crate::address::VirtAddr; +use crate::cpu::flush_tlb_global_sync; +use crate::cpu::mem::{copy_bytes, write_bytes}; +use crate::cpu::percpu::this_cpu; +use crate::error::SvsmError; +use crate::mm::validate::{ + valid_bitmap_clear_valid_4k, valid_bitmap_set_valid_4k, valid_bitmap_valid_addr, +}; +use crate::mm::{virt_to_phys, PageBox}; +use crate::platform::{PageStateChangeOp, PageValidateOp, SVSM_PLATFORM}; +use crate::protocols::errors::SvsmReqError; +use crate::types::{PageSize, PAGE_SIZE}; +use crate::utils::MemoryRegion; + +use zerocopy::{FromBytes, FromZeros}; + +/// Makes a virtual page shared by revoking its validation, updating the +/// page state, and modifying the page tables accordingly. +/// +/// # Arguments +/// +/// * `vaddr` - The virtual address of the page to be made shared. +/// +/// # Safety +/// +/// Converting the memory at `vaddr` must be safe within Rust's memory model. +/// Notably any objects at `vaddr` must tolerate unsynchronized writes of any +/// bit pattern. +unsafe fn make_page_shared(vaddr: VirtAddr) -> Result<(), SvsmError> { + // Revoke page validation before changing page state. + SVSM_PLATFORM.validate_virtual_page_range( + MemoryRegion::new(vaddr, PAGE_SIZE), + PageValidateOp::Invalidate, + )?; + let paddr = virt_to_phys(vaddr); + if valid_bitmap_valid_addr(paddr) { + valid_bitmap_clear_valid_4k(paddr); + } + + // Ask the hypervisor to make the page shared. + SVSM_PLATFORM.page_state_change( + MemoryRegion::new(paddr, PAGE_SIZE), + PageSize::Regular, + PageStateChangeOp::Shared, + )?; + + // Update the page tables to map the page as shared. 
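+    // The remap below must be followed by a global TLB flush so that no CPU
+    // keeps using a stale private (encrypted) translation for a page the host
+    // may now access.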
+ this_cpu() + .get_pgtable() + .set_shared_4k(vaddr) + .expect("Failed to remap shared page in page tables"); + flush_tlb_global_sync(); + + Ok(()) +} + +/// Makes a virtual page private by updating the page tables, modifying the +/// page state, and revalidating the page. +/// +/// # Arguments +/// +/// * `vaddr` - The virtual address of the page to be made private. +unsafe fn make_page_private(vaddr: VirtAddr) -> Result<(), SvsmError> { + // Update the page tables to map the page as private. + this_cpu().get_pgtable().set_encrypted_4k(vaddr)?; + flush_tlb_global_sync(); + + // Ask the hypervisor to make the page private. + let paddr = virt_to_phys(vaddr); + SVSM_PLATFORM.page_state_change( + MemoryRegion::new(paddr, PAGE_SIZE), + PageSize::Regular, + PageStateChangeOp::Private, + )?; + + // Validate the page now that it is private again. + SVSM_PLATFORM.validate_virtual_page_range( + MemoryRegion::new(vaddr, PAGE_SIZE), + PageValidateOp::Validate, + )?; + if valid_bitmap_valid_addr(paddr) { + valid_bitmap_set_valid_4k(paddr); + } + + Ok(()) +} + +/// SharedBox is a safe wrapper around memory pages shared with the host. +pub struct SharedBox { + ptr: NonNull, +} + +impl SharedBox { + /// Allocate some memory and share it with the host. + pub fn try_new_zeroed() -> Result { + let page_box = PageBox::>::try_new_zeroed()?; + let vaddr = page_box.vaddr(); + + let ptr = NonNull::from(PageBox::leak(page_box)).cast::(); + + for offset in (0..core::mem::size_of::()).step_by(PAGE_SIZE) { + unsafe { + make_page_shared(vaddr + offset)?; + } + } + + Ok(Self { ptr }) + } + + /// Returns the virtual address of the memory. + pub fn addr(&self) -> VirtAddr { + VirtAddr::from(self.ptr.as_ptr()) + } + + /// Read the currently stored value into `out`. + pub fn read_into(&self, out: &mut T) + where + T: FromBytes + Copy, + { + unsafe { + // SAFETY: `self.ptr` is valid. Any bitpattern is valid for `T`. + copy_bytes( + self.ptr.as_ptr() as usize, + out as *const T as usize, + size_of::(), + ); + } + } + + /// Share `value` with the host. + pub fn write_from(&mut self, value: &T) + where + T: Copy, + { + unsafe { + // SAFETY: `self.ptr` is valid.. + copy_bytes( + value as *const T as usize, + self.ptr.as_ptr() as usize, + size_of::(), + ); + } + } + + /// Leak the memory. + pub fn leak(self) -> NonNull { + let ptr = self.ptr; + core::mem::forget(self); + ptr + } +} + +impl SharedBox<[T; N]> { + /// Clear the first `n` elements. + pub fn nclear(&mut self, n: usize) -> Result<(), SvsmReqError> + where + T: FromZeros, + { + if n > N { + return Err(SvsmReqError::invalid_parameter()); + } + + unsafe { + // SAFETY: `self.ptr` is valid and we did a bounds check on `n`. + write_bytes(self.ptr.as_ptr() as usize, size_of::() * n, 0); + } + + Ok(()) + } + + /// Fill up the `outbuf` slice provided with bytes from data + pub fn copy_to_slice(&self, outbuf: &mut [T]) -> Result<(), SvsmReqError> + where + T: FromBytes + Copy, + { + if outbuf.len() > N { + return Err(SvsmReqError::invalid_parameter()); + } + + let size = core::mem::size_of_val(outbuf); + unsafe { + // SAFETY: `self.ptr` is valid. + copy_bytes( + self.ptr.as_ptr() as usize, + outbuf.as_mut_ptr() as usize, + size, + ); + } + + Ok(()) + } +} + +unsafe impl Send for SharedBox where T: Send {} +unsafe impl Sync for SharedBox where T: Sync {} + +impl Drop for SharedBox { + fn drop(&mut self) { + // Re-encrypt the pages. 
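+        // If any page cannot be converted back to private, the allocation is
+        // deliberately leaked below rather than returned to the allocator
+        // while still shared with the host.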
+ let res = (0..size_of::()) + .step_by(PAGE_SIZE) + .try_for_each(|offset| unsafe { make_page_private(self.addr() + offset) }); + + // If re-encrypting was successful free the memory otherwise leak it. + if res.is_ok() { + drop(unsafe { PageBox::from_raw(self.ptr.cast::>()) }); + } else { + log::error!("failed to set pages to encrypted. Memory leak!"); + } + } +} + +impl core::fmt::Debug for SharedBox { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("SharedBox").finish_non_exhaustive() + } +} diff --git a/stage2/src/mm/pagebox.rs b/stage2/src/mm/pagebox.rs new file mode 100644 index 000000000..8917afa58 --- /dev/null +++ b/stage2/src/mm/pagebox.rs @@ -0,0 +1,256 @@ +use zerocopy::FromZeros; + +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2024 SUSE +// +// Author: Carlos López +use super::alloc::{allocate_pages, free_page, get_order, AllocError, MAX_ORDER}; +use super::PAGE_SIZE; +use crate::address::VirtAddr; +use crate::error::SvsmError; +use core::borrow; +use core::marker::PhantomData; +use core::mem::{self, ManuallyDrop, MaybeUninit}; +use core::num::NonZeroUsize; +use core::ops::{Deref, DerefMut}; +use core::ptr::NonNull; + +/// An abstraction, similar to a `Box`, for types that need to be allocated +/// using page allocator directly. +/// +/// Constructing a [`PageBox`] is very similar to constructing a regular `Box`: +/// +/// ```no_run +/// # use svsm::mm::PageBox; +/// let p = PageBox::try_new([0u8; 4096])?; +/// # Ok::<(), svsm::error::SvsmError>(()) +/// ``` +/// +/// The type guarantees that the allocated memory will have a minimum alignment +/// of the page size, and that memory will be valid until it is dropped. +/// +/// The type does not support zero sized types nor unsized types. On the other +/// hand, it is able to check at compile time that the contained `T` can fit in +/// a page allocation with the required alignment. For example, the following +/// will not build because its size exceeds the maximum page order: +/// +/// ```compile_fail +/// # use svsm::mm::PageBox; +/// let p = PageBox::try_new([0u8; 0x80000])?; +/// # Ok::<(), svsm::error::SvsmError>(()) +/// ``` +#[derive(Debug)] +#[repr(transparent)] +pub struct PageBox { + ptr: NonNull, + _phantom: PhantomData, +} + +impl PageBox { + /// Allocates enough pages to hold a `T`, initializing them with the given value. + pub fn try_new(x: T) -> Result { + let mut pages = PageBox::::try_new_uninit()?; + // SAFETY: the pointer returned by MaybeUninit::as_mut_ptr() must be + // valid as part of its invariants. We can assume memory is + // initialized after writing to it. + unsafe { + MaybeUninit::as_mut_ptr(&mut pages).write(x); + Ok(pages.assume_init()) + } + } + + /// Allocates enough pages to hold a `T`, and zeroes them out. + pub fn try_new_zeroed() -> Result, SvsmError> + where + T: FromZeros, + { + let mut pages = Self::try_new_uninit()?; + unsafe { MaybeUninit::as_mut_ptr(&mut pages).write_bytes(0, 1) }; + Ok(unsafe { + // SAFETY: We know that all zeros is a valid representation because + // `T` implements `FromZeros`. + pages.assume_init() + }) + } + + /// Gets the page order required for an allocation to hold a `T`. It also + /// checks that the size and alignment requirements of the type can be + /// serviced. + const fn get_order() -> usize { + check_size_requirements::(); + let order = get_order(mem::size_of::()); + assert!(order < MAX_ORDER); + order + } + + /// Allocates enough pages to hold a `T`, but does not initialize them. 
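+    // Hedged usage sketch, mirroring `try_new()` above (`Foo` stands in for
+    // any concrete type):
+    //
+    //     let mut page = PageBox::<Foo>::try_new_uninit()?;
+    //     // SAFETY: the value is fully initialized before `assume_init()`.
+    //     let page = unsafe {
+    //         MaybeUninit::as_mut_ptr(&mut page).write(Foo::default());
+    //         page.assume_init()
+    //     };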
+ pub fn try_new_uninit() -> Result>, SvsmError> { + let order = const { Self::get_order() }; + let addr = NonNull::new(allocate_pages(order)?.as_mut_ptr()).unwrap(); + unsafe { Ok(PageBox::from_raw(addr)) } + } +} + +impl PageBox { + /// Create a [`PageBox`] from a previous allocation of the same type. + /// + /// # Safety + /// + /// The provided pointer must come from a previous use of [`PageBox`] + /// (likely through [`leak()`](PageBox::leak), and must not be aliased. + #[inline] + pub const unsafe fn from_raw(ptr: NonNull) -> Self { + Self { + ptr, + _phantom: PhantomData, + } + } + + /// Consumes and leaks the `PageBox`, returning a mutable reference. The + /// contents will never be freed unless the mutable reference is + /// converted back to a `PageBox` via [`from_raw()`](PageBox::from_raw). + pub fn leak<'a>(b: Self) -> &'a mut T { + unsafe { ManuallyDrop::new(b).ptr.as_mut() } + } + + /// Returns the virtual address of this allocation. + #[inline] + pub fn vaddr(&self) -> VirtAddr { + VirtAddr::from(self.ptr.as_ptr().cast::()) + } +} + +impl PageBox> { + /// Transforms a [`PageBox>`] into a [`PageBox`]. + /// + /// # Safety + /// + /// See the safety requirements for [`MaybeUninit::assume_init()`]. + pub unsafe fn assume_init(self) -> PageBox { + let leaked = PageBox::leak(self).assume_init_mut(); + let raw = NonNull::from(leaked); + PageBox::from_raw(raw) + } +} + +impl PageBox<[T]> { + /// Allocates a dynamically-sized slice of `len` items of type `T`, and + /// populates it with the given value. The slice cannot be resized. + pub fn try_new_slice(val: T, len: NonZeroUsize) -> Result { + // Create and init slice + let mut slice = Self::try_new_uninit_slice(len)?; + for item in slice.iter_mut() { + item.write(val.clone()); + } + // SAFETY: we initialized the contents + unsafe { Ok(slice.assume_init_slice()) } + } + + /// Allocates a dynamically-sized slice of `len` uninitialized items of + /// type `T`. The slice cannot be resized. + pub fn try_new_uninit_slice(len: NonZeroUsize) -> Result]>, SvsmError> { + const { check_size_requirements::>() }; + let order = len + .get() + .checked_mul(mem::size_of::>()) + .map(get_order) + .filter(|order| *order < MAX_ORDER) + .ok_or(AllocError::OutOfMemory)?; + let raw = NonNull::new(allocate_pages(order)?.as_mut_ptr()).unwrap(); + let ptr = NonNull::slice_from_raw_parts(raw, len.get()); + Ok(PageBox { + ptr, + _phantom: PhantomData, + }) + } +} + +impl PageBox<[MaybeUninit]> { + /// Transforms a [`PageBox<[MaybeUninit]>`] into a [`PageBox<[T]>`]. + /// + /// # Safety + /// + /// See the safety requirements for [`MaybeUninit::assume_init()`]. + pub unsafe fn assume_init_slice(self) -> PageBox<[T]> { + // Leak the slice so we can transmute its type. Then transform + // `NonNull<[MaybeUninit]>` into `NonNull<[T]>`. + let leaked = NonNull::from(PageBox::leak(self)); + let inited = NonNull::slice_from_raw_parts(leaked.cast(), leaked.len()); + // We obtained this pointer from a previously leaked allocation, so + // this is safe. + PageBox::from_raw(inited) + } +} + +impl Deref for PageBox { + type Target = T; + + #[inline] + fn deref(&self) -> &T { + // SAFETY: this is part of the invariants of this type, as it must + // hold a pointer to valid memory for the given `T`. + unsafe { self.ptr.as_ref() } + } +} + +impl DerefMut for PageBox { + #[inline] + fn deref_mut(&mut self) -> &mut T { + // SAFETY: this is part of the invariants of this type, as it must + // hold a pointer to valid memory for the given `T`. 
+ unsafe { self.ptr.as_mut() } + } +} + +impl borrow::Borrow for PageBox { + #[inline] + fn borrow(&self) -> &T { + self + } +} + +impl borrow::BorrowMut for PageBox { + #[inline] + fn borrow_mut(&mut self) -> &mut T { + self + } +} + +impl AsRef for PageBox { + #[inline] + fn as_ref(&self) -> &T { + self + } +} + +impl AsMut for PageBox { + #[inline] + fn as_mut(&mut self) -> &mut T { + self + } +} + +impl Drop for PageBox { + fn drop(&mut self) { + let ptr = self.ptr.as_ptr(); + unsafe { ptr.drop_in_place() }; + free_page(self.vaddr()); + } +} + +/// `TryBox` is `Send` if `T` is `Send` because the data it +/// references via its internal pointer is unaliased. +unsafe impl Send for PageBox {} + +/// `TryBox` is `Sync` if `T` is `Sync` because the data it +/// references via its internal pointer is unaliased. +unsafe impl Sync for PageBox {} + +/// Check the size requrements for a type to be allocated through `PageBox`. +const fn check_size_requirements() { + // We cannot guarantee a better alignment than a page in the general case + // and we do not handle zero-sized types. + assert!(mem::size_of::() > 0); + assert!(mem::align_of::() <= PAGE_SIZE); +} diff --git a/stage2/src/mm/pagetable.rs b/stage2/src/mm/pagetable.rs new file mode 100644 index 000000000..6f3f13f89 --- /dev/null +++ b/stage2/src/mm/pagetable.rs @@ -0,0 +1,1796 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::cpu::control_regs::{write_cr3, CR0Flags, CR4Flags}; +use crate::cpu::efer::EFERFlags; +use crate::cpu::flush_tlb_global_sync; +use crate::cpu::idt::common::PageFaultError; +use crate::cpu::registers::RFlags; +use crate::error::SvsmError; +use crate::mm::{ + phys_to_virt, virt_to_phys, PageBox, PGTABLE_LVL3_IDX_PTE_SELFMAP, PGTABLE_LVL3_IDX_SHARED, + SVSM_PTE_BASE, +}; +use crate::platform::SvsmPlatform; +use crate::types::{PageSize, PAGE_SIZE, PAGE_SIZE_1G, PAGE_SIZE_2M}; +use crate::utils::immut_after_init::{ImmutAfterInitCell, ImmutAfterInitResult}; +use crate::utils::MemoryRegion; +use crate::BIT_MASK; +use bitflags::bitflags; +use core::cmp; +use core::ops::{Index, IndexMut}; +use core::ptr::NonNull; + +extern crate alloc; +use alloc::boxed::Box; + +/// Number of entries in a page table (4KB/8B). +const ENTRY_COUNT: usize = 512; + +/// Mask for private page table entry. +static PRIVATE_PTE_MASK: ImmutAfterInitCell = ImmutAfterInitCell::new(0); + +/// Mask for shared page table entry. +static SHARED_PTE_MASK: ImmutAfterInitCell = ImmutAfterInitCell::new(0); + +/// Maximum physical address supported by the system. +static MAX_PHYS_ADDR: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +/// Maximum physical address bits supported by the system. +static PHYS_ADDR_SIZE: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +/// Physical address for the Launch VMSA (Virtual Machine Saving Area). +pub const LAUNCH_VMSA_ADDR: PhysAddr = PhysAddr::new(0xFFFFFFFFF000); + +/// Feature mask for page table entry flags. +static FEATURE_MASK: ImmutAfterInitCell = + ImmutAfterInitCell::new(PTEntryFlags::empty()); + +/// Re-initializes early paging settings. +pub fn paging_init_early(platform: &dyn SvsmPlatform) -> ImmutAfterInitResult<()> { + init_encrypt_mask(platform)?; + + let mut feature_mask = PTEntryFlags::all(); + feature_mask.remove(PTEntryFlags::GLOBAL); + FEATURE_MASK.reinit(&feature_mask) +} + +/// Initializes paging settings. 
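+// Unlike `paging_init_early()` above, the full `PTEntryFlags` feature mask is
+// installed here, so GLOBAL mappings are permitted once the final page tables
+// are in use.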
+pub fn paging_init(platform: &dyn SvsmPlatform) -> ImmutAfterInitResult<()> { + init_encrypt_mask(platform)?; + + let feature_mask = PTEntryFlags::all(); + FEATURE_MASK.reinit(&feature_mask) +} + +/// Initializes the encrypt mask. +fn init_encrypt_mask(platform: &dyn SvsmPlatform) -> ImmutAfterInitResult<()> { + let masks = platform.get_page_encryption_masks(); + + PRIVATE_PTE_MASK.reinit(&masks.private_pte_mask)?; + SHARED_PTE_MASK.reinit(&masks.shared_pte_mask)?; + + let guest_phys_addr_size = (masks.phys_addr_sizes >> 16) & 0xff; + let host_phys_addr_size = masks.phys_addr_sizes & 0xff; + let phys_addr_size = if guest_phys_addr_size == 0 { + // When [GuestPhysAddrSize] is zero, refer to the PhysAddrSize field + // for the maximum guest physical address size. + // - APM3, E.4.7 Function 8000_0008h - Processor Capacity Parameters and Extended Feature Identification + host_phys_addr_size + } else { + guest_phys_addr_size + }; + + PHYS_ADDR_SIZE.reinit(&phys_addr_size)?; + + // If the C-bit is a physical address bit however, the guest physical + // address space is effectively reduced by 1 bit. + // - APM2, 15.34.6 Page Table Support + let effective_phys_addr_size = cmp::min(masks.addr_mask_width, phys_addr_size); + + let max_addr = 1 << effective_phys_addr_size; + MAX_PHYS_ADDR.reinit(&max_addr) +} + +/// Returns the private encrypt mask value. +fn private_pte_mask() -> usize { + *PRIVATE_PTE_MASK +} + +/// Returns the shared encrypt mask value. +fn shared_pte_mask() -> usize { + *SHARED_PTE_MASK +} + +/// Returns the exclusive end of the physical address space. +pub fn max_phys_addr() -> PhysAddr { + PhysAddr::from(*MAX_PHYS_ADDR) +} + +/// Returns the supported flags considering the feature mask. +fn supported_flags(flags: PTEntryFlags) -> PTEntryFlags { + flags & *FEATURE_MASK +} + +/// Set address as shared via mask. +fn make_shared_address(paddr: PhysAddr) -> PhysAddr { + PhysAddr::from(paddr.bits() & !private_pte_mask() | shared_pte_mask()) +} + +/// Set address as private via mask. +fn make_private_address(paddr: PhysAddr) -> PhysAddr { + PhysAddr::from(paddr.bits() & !shared_pte_mask() | private_pte_mask()) +} + +fn strip_confidentiality_bits(paddr: PhysAddr) -> PhysAddr { + PhysAddr::from(paddr.bits() & !(shared_pte_mask() | private_pte_mask())) +} + +bitflags! { + #[derive(Copy, Clone, Debug, Default)] + pub struct PTEntryFlags: u64 { + const PRESENT = 1 << 0; + const WRITABLE = 1 << 1; + const USER = 1 << 2; + const ACCESSED = 1 << 5; + const DIRTY = 1 << 6; + const HUGE = 1 << 7; + const GLOBAL = 1 << 8; + const NX = 1 << 63; + } +} + +impl PTEntryFlags { + pub fn exec() -> Self { + Self::PRESENT | Self::GLOBAL | Self::ACCESSED | Self::DIRTY + } + + pub fn data() -> Self { + Self::PRESENT | Self::GLOBAL | Self::WRITABLE | Self::NX | Self::ACCESSED | Self::DIRTY + } + + pub fn data_ro() -> Self { + Self::PRESENT | Self::GLOBAL | Self::NX | Self::ACCESSED | Self::DIRTY + } + + pub fn task_exec() -> Self { + Self::PRESENT | Self::ACCESSED | Self::DIRTY + } + + pub fn task_data() -> Self { + Self::PRESENT | Self::WRITABLE | Self::NX | Self::ACCESSED | Self::DIRTY + } + + pub fn task_data_ro() -> Self { + Self::PRESENT | Self::NX | Self::ACCESSED | Self::DIRTY + } +} + +/// Represents paging mode. 
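+// How control-register state selects a mode (see `PagingMode::new()` below):
+//
+//   CR0.PG = 0                               -> NoPaging
+//   CR0.PG = 1, EFER.LMA = 1, CR4.LA57 = 1   -> PML5
+//   CR0.PG = 1, EFER.LMA = 1, CR4.LA57 = 0   -> PML4
+//   CR0.PG = 1, EFER.LMA = 0, CR4.PAE  = 1   -> PAE
+//   CR0.PG = 1, EFER.LMA = 0, CR4.PAE  = 0   -> NonPAE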
+#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum PagingMode { + // Paging mode is disabled + NoPaging, + // 32bit legacy paging mode + NonPAE, + // 32bit PAE paging mode + PAE, + // 4 level paging mode + PML4, + // 5 level paging mode + PML5, +} + +impl PagingMode { + pub fn new(efer: EFERFlags, cr0: CR0Flags, cr4: CR4Flags) -> Self { + if !cr0.contains(CR0Flags::PG) { + // Paging is disabled + PagingMode::NoPaging + } else if efer.contains(EFERFlags::LMA) { + // Long mode is activated + if cr4.contains(CR4Flags::LA57) { + PagingMode::PML5 + } else { + PagingMode::PML4 + } + } else if cr4.contains(CR4Flags::PAE) { + // PAE mode + PagingMode::PAE + } else { + // Non PAE mode + PagingMode::NonPAE + } + } +} + +/// Represents a page table entry. +#[repr(C)] +#[derive(Copy, Clone, Debug, Default)] +pub struct PTEntry(PhysAddr); + +impl PTEntry { + /// Check if the page table entry is clear (null). + pub fn is_clear(&self) -> bool { + self.0.is_null() + } + + /// Clear the page table entry. + pub fn clear(&mut self) { + self.0 = PhysAddr::null(); + } + + /// Check if the page table entry is present. + pub fn present(&self) -> bool { + self.flags().contains(PTEntryFlags::PRESENT) + } + + /// Check if the page table entry is huge. + pub fn huge(&self) -> bool { + self.flags().contains(PTEntryFlags::HUGE) + } + + /// Check if the page table entry is writable. + pub fn writable(&self) -> bool { + self.flags().contains(PTEntryFlags::WRITABLE) + } + + /// Check if the page table entry is NX (no-execute). + pub fn nx(&self) -> bool { + self.flags().contains(PTEntryFlags::NX) + } + + /// Check if the page table entry is user-accessible. + pub fn user(&self) -> bool { + self.flags().contains(PTEntryFlags::USER) + } + + /// Check if the page table entry has reserved bits set. + pub fn has_reserved_bits(&self, pm: PagingMode, level: usize) -> bool { + let reserved_mask = match pm { + PagingMode::NoPaging => unreachable!("NoPaging does not have page table"), + PagingMode::NonPAE => { + match level { + // No reserved bits in 4k PTE. + 0 => 0, + 1 => { + if self.huge() { + // Bit21 is reserved in 4M PDE. + BIT_MASK!(21, 21) + } else { + // No reserved bits in PDE. + 0 + } + } + _ => unreachable!("Invalid NonPAE page table level"), + } + } + PagingMode::PAE => { + // Bit62 ~ MAXPHYSADDR are reserved for each + // level in PAE page table. + BIT_MASK!(62, *PHYS_ADDR_SIZE) + | match level { + // No additional reserved bits in 4k PTE. + 0 => 0, + 1 => { + if self.huge() { + // Bit20 ~ Bit13 are reserved in 2M PDE. + BIT_MASK!(20, 13) + } else { + // No additional reserved bits in PDE. + 0 + } + } + // Bit63 and Bit8 ~ Bit5 are reserved in PDPTE. + 2 => BIT_MASK!(63, 63) | BIT_MASK!(8, 5), + _ => unreachable!("Invalid PAE page table level"), + } + } + PagingMode::PML4 | PagingMode::PML5 => { + // Bit51 ~ MAXPHYSADDR are reserved for each level + // in PML4 and PML5 page table. + let common = if *PHYS_ADDR_SIZE > 51 { + 0 + } else { + // Remove the encryption mask bit as this bit is not reserved + BIT_MASK!(51, *PHYS_ADDR_SIZE) + & !((shared_pte_mask() | private_pte_mask()) as u64) + }; + + common + | match level { + // No additional reserved bits in 4k PTE. + 0 => 0, + 1 => { + if self.huge() { + // Bit20 ~ Bit13 are reserved in 2M PDE. + BIT_MASK!(20, 13) + } else { + // No additional reserved bits in PDE. + 0 + } + } + 2 => { + if self.huge() { + // Bit29 ~ Bit13 are reserved in 1G PDPTE. + BIT_MASK!(29, 13) + } else { + // No additional reserved bits in PDPTE. 
+ 0 + } + } + // Bit8 ~ Bit7 are reserved in PML4E. + 3 => BIT_MASK!(8, 7), + 4 => { + if pm == PagingMode::PML4 { + unreachable!("Invalid PML4 page table level"); + } else { + // Bit8 ~ Bit7 are reserved in PML5E. + BIT_MASK!(8, 7) + } + } + _ => unreachable!("Invalid PML4/PML5 page table level"), + } + } + }; + + self.raw() & reserved_mask != 0 + } + + /// Get the raw bits (`u64`) of the page table entry. + pub fn raw(&self) -> u64 { + self.0.bits() as u64 + } + + /// Get the flags of the page table entry. + pub fn flags(&self) -> PTEntryFlags { + PTEntryFlags::from_bits_truncate(self.0.bits() as u64) + } + + /// Set the page table entry with the specified address and flags. + pub fn set(&mut self, addr: PhysAddr, flags: PTEntryFlags) { + let addr = addr.bits() as u64; + assert_eq!(addr & !0x000f_ffff_ffff_f000, 0); + self.0 = PhysAddr::from(addr | supported_flags(flags).bits()); + } + + /// Get the address from the page table entry, excluding the C bit. + pub fn address(&self) -> PhysAddr { + let addr = PhysAddr::from(self.0.bits() & 0x000f_ffff_ffff_f000); + strip_confidentiality_bits(addr) + } + + /// Read a page table entry from the specified virtual address. + /// + /// # Safety + /// + /// Reads from an arbitrary virtual address, making this essentially a + /// raw pointer read. The caller must be certain to calculate the correct + /// address. + pub unsafe fn read_pte(vaddr: VirtAddr) -> Self { + *vaddr.as_ptr::() + } +} + +/// A pagetable page with multiple entries. +#[repr(C)] +#[derive(Debug)] +pub struct PTPage { + entries: [PTEntry; ENTRY_COUNT], +} + +impl PTPage { + /// Allocates a zeroed pagetable page and returns a mutable reference to + /// it, plus its physical address. + /// + /// # Errors + /// + /// Returns [`SvsmError`] if the page cannot be allocated. + fn alloc() -> Result<(&'static mut Self, PhysAddr), SvsmError> { + let page = PageBox::try_new(PTPage::default())?; + let paddr = virt_to_phys(page.vaddr()); + Ok((PageBox::leak(page), paddr)) + } + + /// Frees a pagetable page. + /// + /// # Safety + /// + /// The given reference must correspond to a valid previously allocated + /// page table page. + unsafe fn free(page: &'static Self) { + let _ = PageBox::from_raw(NonNull::from(page)); + } + + /// Converts a pagetable entry to a mutable reference to a [`PTPage`], + /// if the entry is present and not huge. + fn from_entry(entry: PTEntry) -> Option<&'static mut Self> { + let flags = entry.flags(); + if !flags.contains(PTEntryFlags::PRESENT) || flags.contains(PTEntryFlags::HUGE) { + return None; + } + + let address = phys_to_virt(entry.address()); + Some(unsafe { &mut *address.as_mut_ptr::() }) + } +} + +impl Default for PTPage { + fn default() -> Self { + let entries = [PTEntry::default(); ENTRY_COUNT]; + PTPage { entries } + } +} + +/// Can be used to access page table entries by index. +impl Index for PTPage { + type Output = PTEntry; + + fn index(&self, index: usize) -> &PTEntry { + &self.entries[index] + } +} + +/// Can be used to modify page table entries by index. +impl IndexMut for PTPage { + fn index_mut(&mut self, index: usize) -> &mut PTEntry { + &mut self.entries[index] + } +} + +/// Mapping levels of page table entries. 
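+// Each variant carries the entry at the level where the page-table walk
+// stopped: Level0 is a 4KB PTE, Level1 a PDE (2MB page when HUGE is set),
+// Level2 a PDPE (1GB page when HUGE is set), and Level3 a PML4E.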
+#[derive(Debug)] +pub enum Mapping<'a> { + Level3(&'a mut PTEntry), + Level2(&'a mut PTEntry), + Level1(&'a mut PTEntry), + Level0(&'a mut PTEntry), +} + +/// A physical address within a page frame +#[derive(Debug)] +pub enum PageFrame { + Size4K(PhysAddr), + Size2M(PhysAddr), + Size1G(PhysAddr), +} + +impl PageFrame { + pub fn address(&self) -> PhysAddr { + match *self { + Self::Size4K(pa) => pa, + Self::Size2M(pa) => pa, + Self::Size1G(pa) => pa, + } + } + + fn size(&self) -> usize { + match self { + Self::Size4K(_) => PAGE_SIZE, + Self::Size2M(_) => PAGE_SIZE_2M, + Self::Size1G(_) => PAGE_SIZE_1G, + } + } + + pub fn start(&self) -> PhysAddr { + let end = self.address().bits() & !(self.size() - 1); + end.into() + } + + pub fn end(&self) -> PhysAddr { + self.start() + self.size() + } +} + +/// Page table structure containing a root page with multiple entries. +#[repr(C)] +#[derive(Default, Debug)] +pub struct PageTable { + root: PTPage, +} + +impl PageTable { + /// Load the current page table into the CR3 register. + pub fn load(&self) { + write_cr3(self.cr3_value()); + } + + /// Get the CR3 register value for the current page table. + pub fn cr3_value(&self) -> PhysAddr { + let pgtable = VirtAddr::from(self as *const Self); + virt_to_phys(pgtable) + } + + /// Allocate a new page table root. + /// + /// # Errors + /// Returns [`SvsmError`] if the page cannot be allocated. + pub fn allocate_new() -> Result, SvsmError> { + let mut pgtable = PageBox::try_new(PageTable::default())?; + let paddr = virt_to_phys(pgtable.vaddr()); + + // Set the self-map entry. + let entry = &mut pgtable.root[PGTABLE_LVL3_IDX_PTE_SELFMAP]; + let flags = PTEntryFlags::PRESENT + | PTEntryFlags::WRITABLE + | PTEntryFlags::ACCESSED + | PTEntryFlags::DIRTY + | PTEntryFlags::NX; + entry.set(make_private_address(paddr), flags); + + Ok(pgtable) + } + + /// Clone the shared part of the page table; excluding the private + /// parts. + /// + /// # Errors + /// Returns [`SvsmError`] if the page cannot be allocated. + pub fn clone_shared(&self) -> Result, SvsmError> { + let mut pgtable = Self::allocate_new()?; + pgtable.root.entries[PGTABLE_LVL3_IDX_SHARED] = self.root.entries[PGTABLE_LVL3_IDX_SHARED]; + Ok(pgtable) + } + + /// Copy an entry `entry` from another [`PageTable`]. + pub fn copy_entry(&mut self, other: &Self, entry: usize) { + self.root.entries[entry] = other.root.entries[entry]; + } + + /// Computes the index within a page table at the given level for a + /// virtual address `vaddr`. + /// + /// # Parameters + /// - `vaddr`: The virtual address to compute the index for. + /// + /// # Returns + /// The index within the page table. + pub fn index(vaddr: VirtAddr) -> usize { + vaddr.to_pgtbl_idx::() + //vaddr.bits() >> (12 + L * 9) & 0x1ff + } + + /// Walks a page table at level 0 to find a mapping. + /// + /// # Parameters + /// - `page`: A mutable reference to the root page table. + /// - `vaddr`: The virtual address to find a mapping for. + /// + /// # Returns + /// A `Mapping` representing the found mapping. + fn walk_addr_lvl0(page: &mut PTPage, vaddr: VirtAddr) -> Mapping<'_> { + let idx = Self::index::<0>(vaddr); + Mapping::Level0(&mut page[idx]) + } + + /// Walks a page table at level 1 to find a mapping. + /// + /// # Parameters + /// - `page`: A mutable reference to the root page table. + /// - `vaddr`: The virtual address to find a mapping for. + /// + /// # Returns + /// A `Mapping` representing the found mapping. 
+ fn walk_addr_lvl1(page: &mut PTPage, vaddr: VirtAddr) -> Mapping<'_> { + let idx = Self::index::<1>(vaddr); + let entry = page[idx]; + match PTPage::from_entry(entry) { + Some(page) => Self::walk_addr_lvl0(page, vaddr), + None => Mapping::Level1(&mut page[idx]), + } + } + + /// Walks a page table at level 2 to find a mapping. + /// + /// # Parameters + /// - `page`: A mutable reference to the root page table. + /// - `vaddr`: The virtual address to find a mapping for. + /// + /// # Returns + /// A `Mapping` representing the found mapping. + fn walk_addr_lvl2(page: &mut PTPage, vaddr: VirtAddr) -> Mapping<'_> { + let idx = Self::index::<2>(vaddr); + let entry = page[idx]; + match PTPage::from_entry(entry) { + Some(page) => Self::walk_addr_lvl1(page, vaddr), + None => Mapping::Level2(&mut page[idx]), + } + } + + /// Walks the page table to find a mapping for a given virtual address. + /// + /// # Parameters + /// - `page`: A mutable reference to the root page table. + /// - `vaddr`: The virtual address to find a mapping for. + /// + /// # Returns + /// A `Mapping` representing the found mapping. + fn walk_addr_lvl3(page: &mut PTPage, vaddr: VirtAddr) -> Mapping<'_> { + let idx = Self::index::<3>(vaddr); + let entry = page[idx]; + match PTPage::from_entry(entry) { + Some(page) => Self::walk_addr_lvl2(page, vaddr), + None => Mapping::Level3(&mut page[idx]), + } + } + + /// Walk the virtual address and return the corresponding mapping. + /// + /// # Parameters + /// - `vaddr`: The virtual address to find a mapping for. + /// + /// # Returns + /// A `Mapping` representing the found mapping. + fn walk_addr(&mut self, vaddr: VirtAddr) -> Mapping<'_> { + Self::walk_addr_lvl3(&mut self.root, vaddr) + } + + /// Calculate the virtual address of a PTE in the self-map, which maps a + /// specified virtual address. + /// + /// # Parameters + /// - `vaddr': The virtual address whose PTE should be located. + /// + /// # Returns + /// The virtual address of the PTE. + fn get_pte_address(vaddr: VirtAddr) -> VirtAddr { + SVSM_PTE_BASE + ((usize::from(vaddr) & 0x0000_FFFF_FFFF_F000) >> 9) + } + + /// Perform a virtual to physical translation using the self-map. + /// + /// # Parameters + /// - `vaddr': The virtual address to translate. + /// + /// # Returns + /// Some(PageFrame) if the virtual address is valid. + /// None if the virtual address is not valid. + pub fn virt_to_frame(vaddr: VirtAddr) -> Option { + // Calculate the virtual addresses of each level of the paging + // hierarchy in the self-map. + let pte_addr = Self::get_pte_address(vaddr); + let pde_addr = Self::get_pte_address(pte_addr); + let pdpe_addr = Self::get_pte_address(pde_addr); + let pml4e_addr = Self::get_pte_address(pdpe_addr); + + // Check each entry in the paging hierarchy to determine whether this + // address is mapped. Because the hierarchy is read from the top + // down using self-map addresses that were calculated correctly, + // the reads are safe to perform. + let pml4e = unsafe { PTEntry::read_pte(pml4e_addr) }; + if !pml4e.present() { + return None; + } + + // There is no need to check for a large page in the PML4E because + // the architecture does not support the large bit at the top-level + // entry. If a large page is detected at a lower level of the + // hierarchy, the low bits from the virtual address must be combined + // with the physical address from the PDE/PDPE. 
+ let pdpe = unsafe { PTEntry::read_pte(pdpe_addr) }; + if !pdpe.present() { + return None; + } + if pdpe.huge() { + let pa = pdpe.address() + (usize::from(vaddr) & 0x3FFF_FFFF); + return Some(PageFrame::Size1G(pa)); + } + + let pde = unsafe { PTEntry::read_pte(pde_addr) }; + if !pde.present() { + return None; + } + if pde.huge() { + let pa = pde.address() + (usize::from(vaddr) & 0x001F_FFFF); + return Some(PageFrame::Size2M(pa)); + } + + let pte = unsafe { PTEntry::read_pte(pte_addr) }; + if pte.present() { + let pa = pte.address() + (usize::from(vaddr) & 0xFFF); + Some(PageFrame::Size4K(pa)) + } else { + None + } + } + + fn alloc_pte_lvl3(entry: &mut PTEntry, vaddr: VirtAddr, size: PageSize) -> Mapping<'_> { + let flags = entry.flags(); + + if flags.contains(PTEntryFlags::PRESENT) { + return Mapping::Level3(entry); + } + + let Ok((page, paddr)) = PTPage::alloc() else { + return Mapping::Level3(entry); + }; + + let flags = PTEntryFlags::PRESENT + | PTEntryFlags::WRITABLE + | PTEntryFlags::USER + | PTEntryFlags::ACCESSED; + entry.set(make_private_address(paddr), flags); + + let idx = Self::index::<2>(vaddr); + Self::alloc_pte_lvl2(&mut page[idx], vaddr, size) + } + + fn alloc_pte_lvl2(entry: &mut PTEntry, vaddr: VirtAddr, size: PageSize) -> Mapping<'_> { + let flags = entry.flags(); + + if flags.contains(PTEntryFlags::PRESENT) { + return Mapping::Level2(entry); + } + + let Ok((page, paddr)) = PTPage::alloc() else { + return Mapping::Level2(entry); + }; + + let flags = PTEntryFlags::PRESENT + | PTEntryFlags::WRITABLE + | PTEntryFlags::USER + | PTEntryFlags::ACCESSED; + entry.set(make_private_address(paddr), flags); + + let idx = Self::index::<1>(vaddr); + Self::alloc_pte_lvl1(&mut page[idx], vaddr, size) + } + + fn alloc_pte_lvl1(entry: &mut PTEntry, vaddr: VirtAddr, size: PageSize) -> Mapping<'_> { + let flags = entry.flags(); + + if size == PageSize::Huge || flags.contains(PTEntryFlags::PRESENT) { + return Mapping::Level1(entry); + } + + let Ok((page, paddr)) = PTPage::alloc() else { + return Mapping::Level1(entry); + }; + + let flags = PTEntryFlags::PRESENT + | PTEntryFlags::WRITABLE + | PTEntryFlags::USER + | PTEntryFlags::ACCESSED; + entry.set(make_private_address(paddr), flags); + + let idx = Self::index::<0>(vaddr); + Mapping::Level0(&mut page[idx]) + } + + /// Allocates a 4KB page table entry for a given virtual address. + /// + /// # Parameters + /// - `vaddr`: The virtual address for which to allocate the PTE. + /// + /// # Returns + /// A `Mapping` representing the allocated or existing PTE for the address. + fn alloc_pte_4k(&mut self, vaddr: VirtAddr) -> Mapping<'_> { + let m = self.walk_addr(vaddr); + + match m { + Mapping::Level0(entry) => Mapping::Level0(entry), + Mapping::Level1(entry) => Self::alloc_pte_lvl1(entry, vaddr, PageSize::Regular), + Mapping::Level2(entry) => Self::alloc_pte_lvl2(entry, vaddr, PageSize::Regular), + Mapping::Level3(entry) => Self::alloc_pte_lvl3(entry, vaddr, PageSize::Regular), + } + } + + /// Allocates a 2MB page table entry for a given virtual address. + /// + /// # Parameters + /// - `vaddr`: The virtual address for which to allocate the PTE. + /// + /// # Returns + /// A `Mapping` representing the allocated or existing PTE for the address. 
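+    // Note: in contrast to `alloc_pte_4k()`, the walk stops at level 1 so the
+    // returned entry can be installed directly as a huge (2MB) mapping.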
+ fn alloc_pte_2m(&mut self, vaddr: VirtAddr) -> Mapping<'_> { + let m = self.walk_addr(vaddr); + + match m { + Mapping::Level0(entry) => Mapping::Level0(entry), + Mapping::Level1(entry) => Mapping::Level1(entry), + Mapping::Level2(entry) => Self::alloc_pte_lvl2(entry, vaddr, PageSize::Huge), + Mapping::Level3(entry) => Self::alloc_pte_lvl3(entry, vaddr, PageSize::Huge), + } + } + + /// Splits a 2MB page into 4KB pages. + /// + /// # Parameters + /// - `entry`: The 2M page table entry to split. + /// + /// # Returns + /// A result indicating success or an error [`SvsmError`] in failure. + fn do_split_4k(entry: &mut PTEntry) -> Result<(), SvsmError> { + let (page, paddr) = PTPage::alloc()?; + let mut flags = entry.flags(); + + assert!(flags.contains(PTEntryFlags::HUGE)); + + let addr_2m = PhysAddr::from(entry.address().bits() & 0x000f_ffff_fff0_0000); + + flags.remove(PTEntryFlags::HUGE); + + // Prepare PTE leaf page + for (i, e) in page.entries.iter_mut().enumerate() { + let addr_4k = addr_2m + (i * PAGE_SIZE); + e.clear(); + e.set(make_private_address(addr_4k), flags); + } + + entry.set(make_private_address(paddr), flags); + + flush_tlb_global_sync(); + + Ok(()) + } + + /// Splits a page into 4KB pages if it is part of a larger mapping. + /// + /// # Parameters + /// - `mapping`: The mapping to split. + /// + /// # Returns + /// A result indicating success or an error [`SvsmError`]. + fn split_4k(mapping: Mapping<'_>) -> Result<(), SvsmError> { + match mapping { + Mapping::Level0(_entry) => Ok(()), + Mapping::Level1(entry) => Self::do_split_4k(entry), + Mapping::Level2(_entry) => Err(SvsmError::Mem), + Mapping::Level3(_entry) => Err(SvsmError::Mem), + } + } + + fn make_pte_shared(entry: &mut PTEntry) { + let flags = entry.flags(); + let addr = entry.address(); + + // entry.address() returned with c-bit clear already + entry.set(make_shared_address(addr), flags); + } + + fn make_pte_private(entry: &mut PTEntry) { + let flags = entry.flags(); + let addr = entry.address(); + + // entry.address() returned with c-bit clear already + entry.set(make_private_address(addr), flags); + } + + /// Sets the shared state for a 4KB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address of the page. + /// + /// # Returns + /// A result indicating success or an error [`SvsmError`] if the + /// operation fails. + pub fn set_shared_4k(&mut self, vaddr: VirtAddr) -> Result<(), SvsmError> { + let mapping = self.walk_addr(vaddr); + Self::split_4k(mapping)?; + + if let Mapping::Level0(entry) = self.walk_addr(vaddr) { + Self::make_pte_shared(entry); + Ok(()) + } else { + Err(SvsmError::Mem) + } + } + + /// Sets the encryption state for a 4KB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address of the page. + /// + /// # Returns + /// A result indicating success or an error [`SvsmError`]. + pub fn set_encrypted_4k(&mut self, vaddr: VirtAddr) -> Result<(), SvsmError> { + let mapping = self.walk_addr(vaddr); + Self::split_4k(mapping)?; + + if let Mapping::Level0(entry) = self.walk_addr(vaddr) { + Self::make_pte_private(entry); + Ok(()) + } else { + Err(SvsmError::Mem) + } + } + + /// Gets the physical address for a mapped `vaddr` or `None` if + /// no such mapping exists. + pub fn check_mapping(&mut self, vaddr: VirtAddr) -> Option { + match self.walk_addr(vaddr) { + Mapping::Level0(entry) => Some(entry.address()), + Mapping::Level1(entry) => Some(entry.address()), + _ => None, + } + } + + /// Maps a 2MB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address to map. 
+ /// - `paddr`: The physical address to map to. + /// - `flags`: The flags to apply to the mapping. + /// + /// # Returns + /// A result indicating success or failure ([`SvsmError`]). + /// + /// # Panics + /// Panics if either `vaddr` or `paddr` is not aligned to a 2MB boundary. + pub fn map_2m( + &mut self, + vaddr: VirtAddr, + paddr: PhysAddr, + flags: PTEntryFlags, + ) -> Result<(), SvsmError> { + assert!(vaddr.is_aligned(PAGE_SIZE_2M)); + assert!(paddr.is_aligned(PAGE_SIZE_2M)); + + let mapping = self.alloc_pte_2m(vaddr); + + if let Mapping::Level1(entry) = mapping { + entry.set(make_private_address(paddr), flags | PTEntryFlags::HUGE); + Ok(()) + } else { + Err(SvsmError::Mem) + } + } + + /// Unmaps a 2MB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address of the mapping to unmap. + /// + /// # Panics + /// Panics if `vaddr` is not aligned to a 2MB boundary. + pub fn unmap_2m(&mut self, vaddr: VirtAddr) { + assert!(vaddr.is_aligned(PAGE_SIZE_2M)); + + let mapping = self.walk_addr(vaddr); + + match mapping { + Mapping::Level0(_) => unreachable!(), + Mapping::Level1(entry) => entry.clear(), + Mapping::Level2(entry) => assert!(!entry.present()), + Mapping::Level3(entry) => assert!(!entry.present()), + } + } + + /// Maps a 4KB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address to map. + /// - `paddr`: The physical address to map to. + /// - `flags`: The flags to apply to the mapping. + /// + /// # Returns + /// A result indicating success or failure ([`SvsmError`]). + pub fn map_4k( + &mut self, + vaddr: VirtAddr, + paddr: PhysAddr, + flags: PTEntryFlags, + ) -> Result<(), SvsmError> { + let mapping = self.alloc_pte_4k(vaddr); + + if let Mapping::Level0(entry) = mapping { + entry.set(make_private_address(paddr), flags); + Ok(()) + } else { + Err(SvsmError::Mem) + } + } + + /// Unmaps a 4KB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address of the mapping to unmap. + pub fn unmap_4k(&mut self, vaddr: VirtAddr) { + let mapping = self.walk_addr(vaddr); + + match mapping { + Mapping::Level0(entry) => entry.clear(), + Mapping::Level1(entry) => assert!(!entry.present()), + Mapping::Level2(entry) => assert!(!entry.present()), + Mapping::Level3(entry) => assert!(!entry.present()), + } + } + + /// Retrieves the physical address of a mapping. + /// + /// # Parameters + /// - `vaddr`: The virtual address to query. + /// + /// # Returns + /// The physical address of the mapping if present; otherwise, an error + /// ([`SvsmError`]). + pub fn phys_addr(&mut self, vaddr: VirtAddr) -> Result { + let mapping = self.walk_addr(vaddr); + + match mapping { + Mapping::Level0(entry) => { + let offset = vaddr.page_offset(); + if !entry.flags().contains(PTEntryFlags::PRESENT) { + return Err(SvsmError::Mem); + } + Ok(entry.address() + offset) + } + Mapping::Level1(entry) => { + let offset = vaddr.bits() & (PAGE_SIZE_2M - 1); + if !entry.flags().contains(PTEntryFlags::PRESENT) + || !entry.flags().contains(PTEntryFlags::HUGE) + { + return Err(SvsmError::Mem); + } + + Ok(entry.address() + offset) + } + Mapping::Level2(_entry) => Err(SvsmError::Mem), + Mapping::Level3(_entry) => Err(SvsmError::Mem), + } + } + + /// Maps a region of memory using 4KB pages. + /// + /// # Parameters + /// - `vregion`: The virtual memory region to map. + /// - `phys`: The starting physical address to map to. + /// - `flags`: The flags to apply to the mapping. + /// + /// # Returns + /// A result indicating success or failure ([`SvsmError`]). 
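+    // Hedged usage sketch (`vaddr`, `paddr` and `len` are placeholders):
+    //
+    //     let vregion = MemoryRegion::new(vaddr, len);
+    //     pgtable.map_region_4k(vregion, paddr, PTEntryFlags::data())?;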
+ pub fn map_region_4k( + &mut self, + vregion: MemoryRegion, + phys: PhysAddr, + flags: PTEntryFlags, + ) -> Result<(), SvsmError> { + for addr in vregion.iter_pages(PageSize::Regular) { + let offset = addr - vregion.start(); + self.map_4k(addr, phys + offset, flags)?; + } + Ok(()) + } + + /// Unmaps a region of memory using 4KB pages. + /// + /// # Parameters + /// - `vregion`: The virtual memory region to unmap. + pub fn unmap_region_4k(&mut self, vregion: MemoryRegion) { + for addr in vregion.iter_pages(PageSize::Regular) { + self.unmap_4k(addr); + } + } + + /// Maps a region of memory using 2MB pages. + /// + /// # Parameters + /// - `vregion`: The virtual memory region to map. + /// - `phys`: The starting physical address to map to. + /// - `flags`: The flags to apply to the mapping. + /// + /// # Returns + /// A result indicating success or failure ([`SvsmError`]). + pub fn map_region_2m( + &mut self, + vregion: MemoryRegion, + phys: PhysAddr, + flags: PTEntryFlags, + ) -> Result<(), SvsmError> { + for addr in vregion.iter_pages(PageSize::Huge) { + let offset = addr - vregion.start(); + self.map_2m(addr, phys + offset, flags)?; + } + Ok(()) + } + + /// Unmaps a region `vregion` of 2MB pages. The region must be + /// 2MB-aligned and correspond to a set of huge mappings. + pub fn unmap_region_2m(&mut self, vregion: MemoryRegion) { + for addr in vregion.iter_pages(PageSize::Huge) { + self.unmap_2m(addr); + } + } + + /// Maps a memory region to physical memory with specified flags. + /// + /// # Parameters + /// - `region`: The virtual memory region to map. + /// - `phys`: The starting physical address to map to. + /// - `flags`: The flags to apply to the page table entries. + /// + /// # Returns + /// A result indicating success (`Ok`) or failure (`Err`). + pub fn map_region( + &mut self, + region: MemoryRegion, + phys: PhysAddr, + flags: PTEntryFlags, + ) -> Result<(), SvsmError> { + let mut vaddr = region.start(); + let end = region.end(); + let mut paddr = phys; + + while vaddr < end { + if vaddr.is_aligned(PAGE_SIZE_2M) + && paddr.is_aligned(PAGE_SIZE_2M) + && vaddr + PAGE_SIZE_2M <= end + && self.map_2m(vaddr, paddr, flags).is_ok() + { + vaddr = vaddr + PAGE_SIZE_2M; + paddr = paddr + PAGE_SIZE_2M; + continue; + } + + self.map_4k(vaddr, paddr, flags)?; + vaddr = vaddr + PAGE_SIZE; + paddr = paddr + PAGE_SIZE; + } + + Ok(()) + } + + /// Unmaps the virtual memory region `vregion`. + pub fn unmap_region(&mut self, vregion: MemoryRegion) { + let mut vaddr = vregion.start(); + let end = vregion.end(); + + while vaddr < end { + let mapping = self.walk_addr(vaddr); + + match mapping { + Mapping::Level0(entry) => { + entry.clear(); + vaddr = vaddr + PAGE_SIZE; + } + Mapping::Level1(entry) => { + entry.clear(); + vaddr = vaddr + PAGE_SIZE_2M; + } + _ => { + log::error!("Can't unmap - address not mapped {:#x}", vaddr); + } + } + } + } + + /// Populates this paghe table with the contents of the given subtree + /// in `part`. + pub fn populate_pgtbl_part(&mut self, part: &PageTablePart) { + if let Some(paddr) = part.address() { + let idx = part.index(); + let flags = PTEntryFlags::PRESENT + | PTEntryFlags::WRITABLE + | PTEntryFlags::USER + | PTEntryFlags::ACCESSED; + let entry = &mut self.root[idx]; + entry.set(make_private_address(paddr), flags); + } + } +} + +/// Represents a sub-tree of a page-table which can be mapped at a top-level index +#[derive(Default, Debug)] +struct RawPageTablePart { + page: PTPage, +} + +impl RawPageTablePart { + /// Frees a level 1 page table. 
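+    // Only the intermediate paging-structure pages are freed by the helpers
+    // below; the physical frames mapped by leaf entries are left untouched.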
+ fn free_lvl1(page: &PTPage) { + for entry in page.entries.iter() { + if let Some(page) = PTPage::from_entry(*entry) { + // SAFETY: the page comes from an entry in the page table, + // which we allocated using `PTPage::alloc()`, so this is + // safe. + unsafe { PTPage::free(page) }; + } + } + } + + /// Frees a level 2 page table, including all level 1 tables beneath it. + fn free_lvl2(page: &PTPage) { + for entry in page.entries.iter() { + if let Some(l1_page) = PTPage::from_entry(*entry) { + Self::free_lvl1(l1_page); + // SAFETY: the page comes from an entry in the page table, + // which we allocated using `PTPage::alloc()`, so this is + // safe. + unsafe { PTPage::free(l1_page) }; + } + } + } + + /// Frees the resources associated with this page table part. + fn free(&self) { + RawPageTablePart::free_lvl2(&self.page); + } + + /// Returns the physical address of this page table part. + fn address(&self) -> PhysAddr { + virt_to_phys(VirtAddr::from(self as *const RawPageTablePart)) + } + + /// Walks the page table at level 3 to find the mapping for a given + /// virtual address. + /// + /// # Parameters + /// - `vaddr`: The virtual address to find the mapping for. + /// + /// # Returns + /// The [`Mapping`] for the given virtual address. + fn walk_addr(&mut self, vaddr: VirtAddr) -> Mapping<'_> { + PageTable::walk_addr_lvl2(&mut self.page, vaddr) + } + + /// Allocates a 4KB page table entry for a given virtual address. + /// + /// # Parameters + /// - `vaddr`: The virtual address for which to allocate the PTE. + /// + /// # Returns + /// The [`Mapping`] representing the allocated or existing PTE for the address. + /// + /// # Panics + /// Panics if a level 3 mapping is attempted in a [`RawPageTablePart`]. + fn alloc_pte_4k(&mut self, vaddr: VirtAddr) -> Mapping<'_> { + let m = self.walk_addr(vaddr); + + match m { + Mapping::Level0(entry) => Mapping::Level0(entry), + Mapping::Level1(entry) => PageTable::alloc_pte_lvl1(entry, vaddr, PageSize::Regular), + Mapping::Level2(entry) => PageTable::alloc_pte_lvl2(entry, vaddr, PageSize::Regular), + Mapping::Level3(_) => panic!("PT level 3 not possible in PageTablePart"), + } + } + + /// Allocates a 2MB page table entry for a given virtual address. + /// + /// # Parameters + /// - `vaddr`: The virtual address for which to allocate the PTE. + /// + /// # Returns + /// The [`Mapping`] representing the allocated or existing PTE for the + /// address. + fn alloc_pte_2m(&mut self, vaddr: VirtAddr) -> Mapping<'_> { + let m = self.walk_addr(vaddr); + + match m { + Mapping::Level0(entry) => Mapping::Level0(entry), + Mapping::Level1(entry) => Mapping::Level1(entry), + Mapping::Level2(entry) => PageTable::alloc_pte_lvl2(entry, vaddr, PageSize::Huge), + Mapping::Level3(entry) => PageTable::alloc_pte_lvl3(entry, vaddr, PageSize::Huge), + } + } + + /// Maps a 4KB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address to map. + /// - `paddr`: The physical address to map to. + /// - `flags`: The flags to apply to the mapping. + /// - `shared`: Indicates whether the mapping is shared. + /// + /// # Returns + /// A result indicating success (`Ok`) or failure (`Err`). 
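+    // The `shared` flag only selects the encryption bits applied to `paddr`
+    // (via `make_shared_address()` or `make_private_address()`); the walk and
+    // the entry flags are otherwise the same as for a private mapping.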
+ fn map_4k( + &mut self, + vaddr: VirtAddr, + paddr: PhysAddr, + flags: PTEntryFlags, + shared: bool, + ) -> Result<(), SvsmError> { + let mapping = self.alloc_pte_4k(vaddr); + + let addr = if !shared { + make_private_address(paddr) + } else { + make_shared_address(paddr) + }; + + if let Mapping::Level0(entry) = mapping { + entry.set(addr, flags); + Ok(()) + } else { + Err(SvsmError::Mem) + } + } + + /// Unmaps a 4KB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address of the mapping to unmap. + /// + /// # Returns + /// An optional [`PTEntry`] representing the unmapped page table entry. + fn unmap_4k(&mut self, vaddr: VirtAddr) -> Option { + let mapping = self.walk_addr(vaddr); + + match mapping { + Mapping::Level0(entry) => { + let e = *entry; + entry.clear(); + Some(e) + } + Mapping::Level1(entry) => { + assert!(!entry.present()); + None + } + Mapping::Level2(entry) => { + assert!(!entry.present()); + None + } + Mapping::Level3(entry) => { + assert!(!entry.present()); + None + } + } + } + + /// Maps a 2MB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address to map. + /// - `paddr`: The physical address to map to. + /// - `flags`: The flags to apply to the mapping. + /// - `shared`: Indicates whether the mapping is shared + /// + /// # Returns + /// A result indicating success (`Ok`) or failure (`Err`). + /// + /// # Panics + /// + /// Panics if `vaddr` or `paddr` are not 2MB-aligned + fn map_2m( + &mut self, + vaddr: VirtAddr, + paddr: PhysAddr, + flags: PTEntryFlags, + shared: bool, + ) -> Result<(), SvsmError> { + assert!(vaddr.is_aligned(PAGE_SIZE_2M)); + assert!(paddr.is_aligned(PAGE_SIZE_2M)); + + let mapping = self.alloc_pte_2m(vaddr); + let addr = if !shared { + make_private_address(paddr) + } else { + make_shared_address(paddr) + }; + + if let Mapping::Level1(entry) = mapping { + entry.set(addr, flags | PTEntryFlags::HUGE); + Ok(()) + } else { + Err(SvsmError::Mem) + } + } + + /// Unmaps a 2MB page. + /// + /// # Parameters + /// - `vaddr`: The virtual address of the mapping to unmap. + /// + /// # Returns + /// An optional [`PTEntry`] representing the unmapped page table entry. + /// + /// # Panics + /// + /// Panics if `vaddr` is not memory aligned. + fn unmap_2m(&mut self, vaddr: VirtAddr) -> Option { + assert!(vaddr.is_aligned(PAGE_SIZE_2M)); + + let mapping = self.walk_addr(vaddr); + + match mapping { + Mapping::Level0(_) => None, + Mapping::Level1(entry) => { + entry.clear(); + Some(*entry) + } + Mapping::Level2(entry) => { + assert!(!entry.present()); + None + } + Mapping::Level3(entry) => { + assert!(!entry.present()); + None + } + } + } +} + +impl Drop for RawPageTablePart { + fn drop(&mut self) { + self.free(); + } +} + +/// Sub-tree of a page table that can be populated at the top-level +/// used for virtual memory management +#[derive(Debug)] +pub struct PageTablePart { + /// The root of the page-table sub-tree + raw: Option>, + /// The top-level index this PageTablePart is populated at + idx: usize, +} + +impl PageTablePart { + /// Create a new PageTablePart and allocate a root page for the page-table sub-tree. 
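+    // Note: the root page of the sub-tree is allocated lazily, on the first
+    // call to `alloc()` or the first mapping request; `new()` itself only
+    // records the top-level index.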
+ /// + /// # Arguments + /// + /// - `start`: Virtual start address this PageTablePart maps + /// + /// # Returns + /// + /// A new instance of PageTablePart + pub fn new(start: VirtAddr) -> Self { + PageTablePart { + raw: None, + idx: PageTable::index::<3>(start), + } + } + + pub fn alloc(&mut self) { + self.get_or_init_mut(); + } + + fn get_or_init_mut(&mut self) -> &mut RawPageTablePart { + self.raw.get_or_insert_with(Box::default) + } + + fn get_mut(&mut self) -> Option<&mut RawPageTablePart> { + self.raw.as_deref_mut() + } + + fn get(&self) -> Option<&RawPageTablePart> { + self.raw.as_deref() + } + + /// Request PageTable index to populate this instance to + /// + /// # Returns + /// + /// Index of the top-level PageTable this sub-tree is populated to + pub fn index(&self) -> usize { + self.idx + } + + /// Request physical base address of the page-table sub-tree. This is + /// needed to populate the PageTablePart. + /// + /// # Returns + /// + /// Physical base address of the page-table sub-tree + pub fn address(&self) -> Option { + self.get().map(|p| p.address()) + } + + /// Map a 4KiB page in the page table sub-tree + /// + /// # Arguments + /// + /// * `vaddr` - Virtual address to create the mapping. Must be aligned to 4KiB. + /// * `paddr` - Physical address to map. Must be aligned to 4KiB. + /// * `flags` - PTEntryFlags used for the mapping + /// * `shared` - Defines whether the page is mapped shared or private + /// + /// # Returns + /// + /// OK(()) on Success, Err(SvsmError::Mem) on error. + /// + /// This function can fail when there not enough memory to allocate pages for the mapping. + /// + /// # Panics + /// + /// This method panics when either `vaddr` or `paddr` are not aligned to 4KiB. + pub fn map_4k( + &mut self, + vaddr: VirtAddr, + paddr: PhysAddr, + flags: PTEntryFlags, + shared: bool, + ) -> Result<(), SvsmError> { + assert!(PageTable::index::<3>(vaddr) == self.idx); + + self.get_or_init_mut().map_4k(vaddr, paddr, flags, shared) + } + + /// Unmaps a 4KiB page from the page table sub-tree + /// + /// # Arguments + /// + /// * `vaddr` - The virtual address to unmap. Must be aligned to 4KiB. + /// + /// # Returns + /// + /// Returns a copy of the PTEntry that mapped the virtual address, if any. + /// + /// # Panics + /// + /// This method panics when `vaddr` is not aligned to 4KiB. + pub fn unmap_4k(&mut self, vaddr: VirtAddr) -> Option { + assert!(PageTable::index::<3>(vaddr) == self.idx); + + self.get_mut().and_then(|r| r.unmap_4k(vaddr)) + } + + /// Map a 2MiB page in the page table sub-tree + /// + /// # Arguments + /// + /// * `vaddr` - Virtual address to create the mapping. Must be aligned to 2MiB. + /// * `paddr` - Physical address to map. Must be aligned to 2MiB. + /// * `flags` - PTEntryFlags used for the mapping + /// * `shared` - Defines whether the page is mapped shared or private + /// + /// # Returns + /// + /// OK(()) on Success, Err(SvsmError::Mem) on error. + /// + /// This function can fail when there not enough memory to allocate pages for the mapping. + /// + /// # Panics + /// + /// This method panics when either `vaddr` or `paddr` are not aligned to 2MiB. + pub fn map_2m( + &mut self, + vaddr: VirtAddr, + paddr: PhysAddr, + flags: PTEntryFlags, + shared: bool, + ) -> Result<(), SvsmError> { + assert!(PageTable::index::<3>(vaddr) == self.idx); + + self.get_or_init_mut().map_2m(vaddr, paddr, flags, shared) + } + + /// Unmaps a 2MiB page from the page table sub-tree + /// + /// # Arguments + /// + /// * `vaddr` - The virtual address to unmap. 
+    ///
+    /// # Returns
+    ///
+    /// Returns a copy of the PTEntry that mapped the virtual address, if any.
+    ///
+    /// # Panics
+    ///
+    /// This method panics when `vaddr` is not aligned to 2MiB.
+    pub fn unmap_2m(&mut self, vaddr: VirtAddr) -> Option<PTEntry> {
+        assert!(PageTable::index::<3>(vaddr) == self.idx);
+
+        self.get_mut().and_then(|r| r.unmap_2m(vaddr))
+    }
+}
+
+bitflags! {
+    /// Flags to represent how memory is accessed, e.g. write data to the
+    /// memory or fetch code from the memory.
+    #[derive(Clone, Copy, Debug)]
+    pub struct MemAccessMode: u32 {
+        const WRITE = 1 << 0;
+        const FETCH = 1 << 1;
+    }
+}
+
+/// Attributes used to determine whether a memory access (write/fetch) is
+/// permitted by a translation. They include the paging-mode modifiers in
+/// CR0, CR4 and EFER, EFLAGS.AC, and whether the access is a supervisor-mode
+/// or user-mode access.
+#[derive(Clone, Copy, Debug)]
+pub struct PTWalkAttr {
+    cr0: CR0Flags,
+    cr4: CR4Flags,
+    efer: EFERFlags,
+    flags: RFlags,
+    user_mode_access: bool,
+    pm: PagingMode,
+}
+
+impl PTWalkAttr {
+    /// Creates a new `PTWalkAttr` instance with the specified attributes.
+    ///
+    /// # Arguments
+    ///
+    /// * `cr0`, `cr4`, and `efer` - Represent the control register
+    ///   flags for CR0, CR4, and EFER respectively.
+    /// * `flags` - Represents the CPU flags (RFLAGS).
+    /// * `user_mode_access` - Indicates whether the access is in user mode.
+    ///
+    /// Returns a new `PTWalkAttr` instance.
+    pub fn new(
+        cr0: CR0Flags,
+        cr4: CR4Flags,
+        efer: EFERFlags,
+        flags: RFlags,
+        user_mode_access: bool,
+    ) -> Self {
+        Self {
+            cr0,
+            cr4,
+            efer,
+            flags,
+            user_mode_access,
+            pm: PagingMode::new(efer, cr0, cr4),
+        }
+    }
+
+    /// Checks the access rights for a page table entry.
+    ///
+    /// # Arguments
+    ///
+    /// * `entry` - The page table entry to check.
+    /// * `mem_am` - Indicates how to access the memory.
+    /// * `level` - The page-table level of `entry`, with level 0 being the
+    ///   last level of the page table.
+    /// * `pteflags` - The PTE flags to indicate if the corresponding page
+    ///   table entry allows the access rights.
+    ///
+    /// # Returns
+    ///
+    /// Returns `Ok((entry, leaf))` if the access rights are valid, where
+    /// `entry` is the modified page table entry and `leaf` is a boolean
+    /// indicating whether the entry is a leaf node, or `Err(PageFaultError)`
+    /// to indicate the page fault error code if the access rights are invalid.
+    pub fn check_access_rights(
+        &self,
+        entry: PTEntry,
+        mem_am: MemAccessMode,
+        level: usize,
+        pteflags: &mut PTEntryFlags,
+    ) -> Result<(PTEntry, bool), PageFaultError> {
+        let pf_err = self.default_pf_err(mem_am) | PageFaultError::P;
+
+        if !entry.present() {
+            // Entry is not present.
+            return Err(pf_err & !PageFaultError::P);
+        }
+
+        if entry.has_reserved_bits(self.pm, level) {
+            // Reserved bits have been set.
+            return Err(pf_err | PageFaultError::R);
+        }
+
+        // SDM 4.6.1 Determination of Access Rights:
+        // If the U/S flag (bit 2) is 0 in at least one of the
+        // paging-structure entries, the address is a supervisor-mode
+        // address. Otherwise, the address is a user-mode address.
+        // So by default, assume the address is a user-mode address.
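+        // Example: for a user-mode data write to succeed, `pteflags` must
+        // still contain USER and WRITABLE after every level has been
+        // accumulated; otherwise the leaf check below reports a fault using
+        // the error code built by `default_pf_err()`.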
+ if !entry.user() { + *pteflags &= !PTEntryFlags::USER; + } + + // SDM 4.6.1 Determination of Access Rights: + // R/W flag (bit 1) is 1 in every paging-structure entry controlling + // the translation and with a protection key for which write access is + // permitted; data may not be written to any supervisor-mode + // address with a translation for which the R/W flag is 0 in any + // paging-structure entry controlling the translation. + // The same for user mode address + if !entry.writable() { + *pteflags &= !PTEntryFlags::WRITABLE; + } + + // SDM 4.6.1 Determination of Access Rights: + // For non 32-bit paging modes with IA32_EFER.NXE = 1, instructions + // may be fetched from any supervisormode address with a translation + // for which the XD flag (bit 63) is 0 in every paging-structure entry + // controlling the translation; instructions may not be fetched from + // any supervisor-mode address with a translation for which the XD flag + // is 1 in any paging-structure entry controlling the translation + if self.efer.contains(EFERFlags::NXE) && entry.nx() { + *pteflags |= PTEntryFlags::NX; + } else if !self.efer.contains(EFERFlags::NXE) && entry.nx() { + // XD bit must be 0 if efer.NXE = 0 + return Err(pf_err | PageFaultError::R); + } + + let leaf = if level == 0 || entry.huge() { + // User mode cannot access any supervisor mode addresses + if self.user_mode_access && !pteflags.contains(PTEntryFlags::USER) { + return Err(pf_err); + } + + // Always check for reading. For the case of supervisor mode read user + // mode addresses, do special checking. For other cases, read is allowed. + if !self.user_mode_access && pteflags.contains(PTEntryFlags::USER) { + // Read not allowed with SMAP = 1 && flags.ac = 0 + if self.cr4.contains(CR4Flags::SMAP) && !self.flags.contains(RFlags::AC) { + return Err(pf_err); + } + } + + if mem_am.contains(MemAccessMode::WRITE) { + if !self.user_mode_access && pteflags.contains(PTEntryFlags::USER) { + // Check supervisor mode write user mode addresses + if !self.cr0.contains(CR0Flags::WP) { + // Check write with CR0.WP = 0 + if self.cr4.contains(CR4Flags::SMAP) && !self.flags.contains(RFlags::AC) { + // Write not allowed with SMAP = 1 && flags.ac = 0 + return Err(pf_err); + } + } else { + // Check write with CR0.WP = 1 + if !self.cr4.contains(CR4Flags::SMAP) { + // SMAP = 0 + if !pteflags.contains(PTEntryFlags::WRITABLE) { + // Write not allowed R/W = 0 + return Err(pf_err); + } + } else { + // SMAP = 1 + if !self.flags.contains(RFlags::AC) + || !pteflags.contains(PTEntryFlags::WRITABLE) + { + // Write not allowed with flags.AC = 0 || R/W = 0 + return Err(pf_err); + } + } + } + } else if !self.user_mode_access && !pteflags.contains(PTEntryFlags::USER) { + // Check supervisor mode write supervisor mode addresses + if self.cr0.contains(CR0Flags::WP) && !pteflags.contains(PTEntryFlags::WRITABLE) + { + // Write not allowed with CR0.WP = 1 && R/W = 0 + return Err(pf_err); + } + } else if self.user_mode_access && pteflags.contains(PTEntryFlags::USER) { + // Check user mode write user mode addresses + if !pteflags.contains(PTEntryFlags::WRITABLE) { + // Write not allowed R/W = 0 + return Err(pf_err); + } + } + // User mode write supervisor mode addresses is checked already + } + + if mem_am.contains(MemAccessMode::FETCH) { + // For instruction fetch, the rule is the same except for the case of + // supervisor mode fetch user mode addresses + if !self.user_mode_access && pteflags.contains(PTEntryFlags::USER) { + // Fetch not allowed with SMEP = 1 + if 
self.cr4.contains(CR4Flags::SMEP) { + return Err(pf_err); + } + } + + // For non-32bit paging mode, fetch not allowed with efer.NXE = 1 && XD = 1 + if self.cr4.contains(CR4Flags::PAE) + && self.efer.contains(EFERFlags::NXE) + && pteflags.contains(PTEntryFlags::NX) + { + return Err(pf_err); + } + } + true + } else { + false + }; + + Ok((entry, leaf)) + } + + fn default_pf_err(&self, mem_am: MemAccessMode) -> PageFaultError { + let mut err = PageFaultError::empty(); + + if mem_am.contains(MemAccessMode::WRITE) { + err |= PageFaultError::W; + } + + if mem_am.contains(MemAccessMode::FETCH) { + err |= PageFaultError::I; + } + + if self.user_mode_access { + err |= PageFaultError::U; + } + + err + } +} diff --git a/stage2/src/mm/ptguards.rs b/stage2/src/mm/ptguards.rs new file mode 100644 index 000000000..abaa3cc4e --- /dev/null +++ b/stage2/src/mm/ptguards.rs @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::pagetable::PTEntryFlags; +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::cpu::percpu::this_cpu; +use crate::cpu::tlb::flush_address_percpu; +use crate::error::SvsmError; +use crate::insn_decode::{InsnError, InsnMachineMem}; +use crate::mm::virtualrange::VRangeAlloc; +use crate::types::{PageSize, PAGE_SIZE, PAGE_SIZE_2M}; +use crate::utils::MemoryRegion; +use core::marker::PhantomData; + +/// Guard for a per-CPU page mapping to ensure adequate cleanup if drop. +#[derive(Debug)] +#[must_use = "if unused the mapping will immediately be unmapped"] +pub struct PerCPUPageMappingGuard { + mapping: VRangeAlloc, +} + +impl PerCPUPageMappingGuard { + /// Creates a new [`PerCPUPageMappingGuard`] for the specified physical + /// address range and alignment. + /// + /// # Arguments + /// + /// * `paddr_start` - The starting physical address of the range. + /// * `paddr_end` - The ending physical address of the range. + /// * `alignment` - The desired alignment for the mapping. + /// + /// # Returns + /// + /// A `Result` containing the [`PerCPUPageMappingGuard`] if successful, + /// or an `SvsmError` if an error occurs. + /// + /// # Panics + /// + /// Panics if either `paddr_start`, the size, or `paddr_end`, are not + /// aligned. + pub fn create( + paddr_start: PhysAddr, + paddr_end: PhysAddr, + alignment: usize, + ) -> Result { + let align_mask = (PAGE_SIZE << alignment) - 1; + let size = paddr_end - paddr_start; + assert!((size & align_mask) == 0); + assert!((paddr_start.bits() & align_mask) == 0); + assert!((paddr_end.bits() & align_mask) == 0); + + let flags = PTEntryFlags::data(); + let huge = ((paddr_start.bits() & (PAGE_SIZE_2M - 1)) == 0) + && ((paddr_end.bits() & (PAGE_SIZE_2M - 1)) == 0); + + let mapping = if huge { + let range = VRangeAlloc::new_2m(size, 0)?; + this_cpu() + .get_pgtable() + .map_region_2m(range.region(), paddr_start, flags)?; + range + } else { + let range = VRangeAlloc::new_4k(size, 0)?; + this_cpu() + .get_pgtable() + .map_region_4k(range.region(), paddr_start, flags)?; + range + }; + + Ok(Self { mapping }) + } + + /// Creates a new [`PerCPUPageMappingGuard`] for a 4KB page at the + /// specified physical address, or an `SvsmError` if an error occurs. + pub fn create_4k(paddr: PhysAddr) -> Result { + Self::create(paddr, paddr + PAGE_SIZE, 0) + } + + /// Returns the virtual address associated with the guard. 
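+    ///
+    /// A sketch of the typical pattern (the physical address is illustrative):
+    ///
+    /// ```ignore
+    /// let guard = PerCPUPageMappingGuard::create_4k(PhysAddr::from(0x10_0000usize))?;
+    /// let vaddr = guard.virt_addr();
+    /// // The mapping is torn down and the per-CPU TLB entries are flushed
+    /// // when `guard` goes out of scope.
+    /// ```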
+    pub fn virt_addr(&self) -> VirtAddr {
+        self.mapping.region().start()
+    }
+
+    /// Creates a virtually contiguous mapping for the given 4k physical
+    /// pages, which need not be contiguous in physical memory.
+    ///
+    /// # Arguments
+    ///
+    /// * `pages`: A slice of tuples, each containing the `PhysAddr` of a
+    ///   4k page to map and its shareability.
+    ///
+    /// # Returns
+    ///
+    /// This function returns a `Result` that contains a `PerCPUPageMappingGuard`
+    /// object on success. The `PerCPUPageMappingGuard` object represents the page
+    /// mapping that was created. If an error occurs while creating the page
+    /// mapping, it returns a `SvsmError`.
+    pub fn create_4k_pages(pages: &[(PhysAddr, bool)]) -> Result<Self, SvsmError> {
+        let mapping = VRangeAlloc::new_4k(pages.len() * PAGE_SIZE, 0)?;
+        let flags = PTEntryFlags::data();
+
+        let mut pgtable = this_cpu().get_pgtable();
+        for (vaddr, (paddr, shared)) in mapping
+            .region()
+            .iter_pages(PageSize::Regular)
+            .zip(pages.iter().copied())
+        {
+            assert!(paddr.is_page_aligned());
+            pgtable.map_4k(vaddr, paddr, flags)?;
+            if shared {
+                pgtable.set_shared_4k(vaddr)?;
+            }
+        }
+
+        Ok(Self { mapping })
+    }
+}
+
+impl Drop for PerCPUPageMappingGuard {
+    fn drop(&mut self) {
+        let region = self.mapping.region();
+        let size = if self.mapping.huge() {
+            this_cpu().get_pgtable().unmap_region_2m(region);
+            PageSize::Huge
+        } else {
+            this_cpu().get_pgtable().unmap_region_4k(region);
+            PageSize::Regular
+        };
+        // This iterative flush is acceptable for same-CPU mappings because no
+        // broadcast is involved for each iteration.
+        for page in region.iter_pages(size) {
+            flush_address_percpu(page);
+        }
+    }
+}
+
+/// Represents a guard for a specific memory range mapping, which will
+/// unmap the specific memory range after being dropped.
+#[derive(Debug)]
+pub struct MemMappingGuard<T> {
+    // The guard holding the temporary mapping for a specific memory range.
+    guard: PerCPUPageMappingGuard,
+    // The starting offset of the memory range.
+    start_off: usize,
+
+    phantom: PhantomData<T>,
+}
+
+impl<T: Copy> MemMappingGuard<T> {
+    /// Creates a new `MemMappingGuard` with the given `PerCPUPageMappingGuard`
+    /// and starting offset.
+    ///
+    /// # Arguments
+    ///
+    /// * `guard` - The `PerCPUPageMappingGuard` to associate with the `MemMappingGuard`.
+    /// * `start_off` - The starting offset for the memory mapping.
+    ///
+    /// # Returns
+    ///
+    /// Returns `Ok(Self)` on success, or `Err(SvsmError::Mem)` if `start_off`
+    /// lies outside the guarded region.
+    pub fn new(guard: PerCPUPageMappingGuard, start_off: usize) -> Result<Self, SvsmError> {
+        if start_off >= guard.mapping.region().len() {
+            Err(SvsmError::Mem)
+        } else {
+            Ok(Self {
+                guard,
+                start_off,
+                phantom: PhantomData,
+            })
+        }
+    }
+
+    /// Reads data from a virtual address region specified by an offset
+    ///
+    /// # Safety
+    ///
+    /// The caller must not read from arbitrary memory regions. The region to read must
+    /// be checked to guarantee the memory is mapped by the guard and is valid for reading.
+    ///
+    /// # Arguments
+    ///
+    /// * `offset`: The offset (in units of `size_of::<T>()`) from the start of the virtual address
+    ///   region to read from.
+    ///
+    /// # Returns
+    ///
+    /// This function returns a `Result` that indicates the success or failure of the operation.
+    /// If the read operation is successful, it returns `Ok(T)` which contains the read-back data.
+    /// If the virtual address region cannot be retrieved, it returns `Err(SvsmError::Mem)`.
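+    ///
+    /// A usage sketch, assuming `page_guard` is a `PerCPUPageMappingGuard`
+    /// created as above (the element type and offset are illustrative):
+    ///
+    /// ```ignore
+    /// let guard: MemMappingGuard<u32> = MemMappingGuard::new(page_guard, 0)?;
+    /// // SAFETY: offset 0 lies within the region mapped by `page_guard`.
+    /// let value = unsafe { guard.read(0)? };
+    /// ```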
+ pub unsafe fn read(&self, offset: usize) -> Result { + let size = core::mem::size_of::(); + self.virt_addr_region(offset * size, size) + .map_or(Err(SvsmError::Mem), |region| { + Ok(*(region.start().as_ptr::())) + }) + } + + /// Writes data from a provided data into a virtual address region specified by an offset. + /// + /// # Safety + /// + /// The caller must verify not to write to arbitrary memory regions. The memory region to write + /// must be checked to guarantee the memory is mapped by the guard and is valid for writing. + /// + /// # Arguments + /// + /// * `offset`: The offset (in unit of `size_of::()`) from the start of the virtual address + /// region to write to. + /// * `data`: Data to write. + /// + /// # Returns + /// + /// This function returns a `Result` that indicates the success or failure of the operation. + /// If the write operation is successful, it returns `Ok(())`. If the virtual address region + /// cannot be retrieved or if the buffer size is larger than the region size, it returns + /// `Err(SvsmError::Mem)`. + pub unsafe fn write(&self, offset: usize, data: T) -> Result<(), SvsmError> { + let size = core::mem::size_of::(); + self.virt_addr_region(offset * size, size) + .map_or(Err(SvsmError::Mem), |region| { + *(region.start().as_mut_ptr::()) = data; + Ok(()) + }) + } + + fn virt_addr_region(&self, offset: usize, len: usize) -> Option> { + if len != 0 { + MemoryRegion::checked_new( + self.guard + .virt_addr() + .checked_add(self.start_off + offset)?, + len, + ) + .filter(|v| self.guard.mapping.region().contains_region(v)) + } else { + None + } + } +} + +impl InsnMachineMem for MemMappingGuard { + type Item = T; + + /// Safety: See the MemMappingGuard's read() method documentation for safety requirements. + unsafe fn mem_read(&self) -> Result { + self.read(0).map_err(|_| InsnError::MemRead) + } + + /// Safety: See the MemMappingGuard's write() method documentation for safety requirements. + unsafe fn mem_write(&mut self, data: Self::Item) -> Result<(), InsnError> { + self.write(0, data).map_err(|_| InsnError::MemWrite) + } +} diff --git a/stage2/src/mm/validate.rs b/stage2/src/mm/validate.rs new file mode 100644 index 000000000..8458e30e4 --- /dev/null +++ b/stage2/src/mm/validate.rs @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr}; +use crate::error::SvsmError; +use crate::locking::SpinLock; +use crate::mm::{virt_to_phys, PageBox}; +use crate::types::{PAGE_SIZE, PAGE_SIZE_2M}; +use crate::utils::MemoryRegion; +use core::mem::MaybeUninit; +use core::num::NonZeroUsize; +use core::ptr::NonNull; + +static VALID_BITMAP: SpinLock> = SpinLock::new(None); + +fn bitmap_elems(region: MemoryRegion) -> NonZeroUsize { + NonZeroUsize::new( + region + .len() + .div_ceil(PAGE_SIZE) + .div_ceil(u64::BITS as usize), + ) + .unwrap() +} + +/// # Safety +/// +/// The caller must ensure that the given bitmap pointer is valid. 
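+///
+/// For example, a 2 MiB region covers 512 4 KiB pages, so [`bitmap_elems`]
+/// yields 512 / 64 = 8 `u64` words and `raw` must point to at least that
+/// many valid words.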
+pub unsafe fn init_valid_bitmap_ptr(region: MemoryRegion, raw: NonNull) { + let len = bitmap_elems(region); + let ptr = NonNull::slice_from_raw_parts(raw, len.get()); + let bitmap = unsafe { PageBox::from_raw(ptr) }; + *VALID_BITMAP.lock() = Some(ValidBitmap::new(region, bitmap)); +} + +pub fn init_valid_bitmap_alloc(region: MemoryRegion) -> Result<(), SvsmError> { + let len = bitmap_elems(region); + let bitmap = PageBox::try_new_slice(0u64, len)?; + *VALID_BITMAP.lock() = Some(ValidBitmap::new(region, bitmap)); + + Ok(()) +} + +pub fn migrate_valid_bitmap() -> Result<(), SvsmError> { + let region = VALID_BITMAP.lock().as_ref().unwrap().region; + let len = bitmap_elems(region); + let bitmap = PageBox::try_new_uninit_slice(len)?; + + // lock again here because allocator path also takes VALID_BITMAP.lock() + VALID_BITMAP.lock().as_mut().unwrap().migrate(bitmap); + Ok(()) +} + +pub fn validated_phys_addr(paddr: PhysAddr) -> bool { + VALID_BITMAP + .lock() + .as_ref() + .map(|vb| vb.is_valid_4k(paddr)) + .unwrap_or(false) +} + +pub fn valid_bitmap_set_valid_4k(paddr: PhysAddr) { + if let Some(vb) = VALID_BITMAP.lock().as_mut() { + vb.set_valid_4k(paddr); + } +} + +pub fn valid_bitmap_clear_valid_4k(paddr: PhysAddr) { + if let Some(vb) = VALID_BITMAP.lock().as_mut() { + vb.clear_valid_4k(paddr); + } +} + +pub fn valid_bitmap_set_valid_2m(paddr: PhysAddr) { + if let Some(vb) = VALID_BITMAP.lock().as_mut() { + vb.set_valid_2m(paddr); + } +} + +pub fn valid_bitmap_clear_valid_2m(paddr: PhysAddr) { + if let Some(vb) = VALID_BITMAP.lock().as_mut() { + vb.clear_valid_2m(paddr); + } +} + +pub fn valid_bitmap_set_valid_range(paddr_begin: PhysAddr, paddr_end: PhysAddr) { + if let Some(vb) = VALID_BITMAP.lock().as_mut() { + vb.set_valid_range(paddr_begin, paddr_end); + } +} + +pub fn valid_bitmap_clear_valid_range(paddr_begin: PhysAddr, paddr_end: PhysAddr) { + if let Some(vb) = VALID_BITMAP.lock().as_mut() { + vb.clear_valid_range(paddr_begin, paddr_end); + } +} + +pub fn valid_bitmap_addr() -> PhysAddr { + VALID_BITMAP.lock().as_ref().unwrap().bitmap_addr() +} + +pub fn valid_bitmap_valid_addr(paddr: PhysAddr) -> bool { + VALID_BITMAP + .lock() + .as_ref() + .map(|vb| vb.check_addr(paddr)) + .unwrap_or(false) +} + +#[derive(Debug)] +struct ValidBitmap { + region: MemoryRegion, + bitmap: PageBox<[u64]>, +} + +impl ValidBitmap { + const fn new(region: MemoryRegion, bitmap: PageBox<[u64]>) -> Self { + Self { region, bitmap } + } + + fn check_addr(&self, paddr: PhysAddr) -> bool { + self.region.contains(paddr) + } + + fn bitmap_addr(&self) -> PhysAddr { + virt_to_phys(self.bitmap.vaddr()) + } + + #[inline(always)] + fn index(&self, paddr: PhysAddr) -> (usize, usize) { + let page_offset = (paddr - self.region.start()) / PAGE_SIZE; + let index = page_offset / 64; + let bit = page_offset % 64; + + (index, bit) + } + + fn migrate(&mut self, mut new: PageBox<[MaybeUninit]>) { + for (dst, src) in new + .iter_mut() + .zip(self.bitmap.iter().copied().chain(core::iter::repeat(0))) + { + dst.write(src); + } + // SAFETY: we initialized the contents of the whole slice + self.bitmap = unsafe { new.assume_init_slice() }; + } + + fn set_valid_4k(&mut self, paddr: PhysAddr) { + let (index, bit) = self.index(paddr); + + assert!(paddr.is_page_aligned()); + assert!(self.check_addr(paddr)); + + self.bitmap[index] |= 1u64 << bit; + } + + fn clear_valid_4k(&mut self, paddr: PhysAddr) { + let (index, bit) = self.index(paddr); + + assert!(paddr.is_page_aligned()); + assert!(self.check_addr(paddr)); + + self.bitmap[index] &= 
!(1u64 << bit); + } + + fn set_2m(&mut self, paddr: PhysAddr, val: u64) { + const NR_INDEX: usize = PAGE_SIZE_2M / (PAGE_SIZE * 64); + let (index, _) = self.index(paddr); + + assert!(paddr.is_aligned(PAGE_SIZE_2M)); + assert!(self.check_addr(paddr)); + + self.bitmap[index..index + NR_INDEX].fill(val); + } + + fn set_valid_2m(&mut self, paddr: PhysAddr) { + self.set_2m(paddr, !0u64); + } + + fn clear_valid_2m(&mut self, paddr: PhysAddr) { + self.set_2m(paddr, 0u64); + } + + fn modify_bitmap_word(&mut self, index: usize, mask: u64, new_val: u64) { + let val = &mut self.bitmap[index]; + *val = (*val & !mask) | (new_val & mask); + } + + fn set_range(&mut self, paddr_begin: PhysAddr, paddr_end: PhysAddr, new_val: bool) { + // All ones. + let mask = !0u64; + // All ones if val == true, zero otherwise. + let new_val = 0u64.wrapping_sub(new_val as u64); + + let (index_head, bit_head_begin) = self.index(paddr_begin); + let (index_tail, bit_tail_end) = self.index(paddr_end); + if index_head != index_tail { + let mask_head = mask >> bit_head_begin << bit_head_begin; + self.modify_bitmap_word(index_head, mask_head, new_val); + + self.bitmap[index_head + 1..index_tail].fill(new_val); + + if bit_tail_end != 0 { + let mask_tail = mask << (64 - bit_tail_end) >> (64 - bit_tail_end); + self.modify_bitmap_word(index_tail, mask_tail, new_val); + } + } else { + let mask = mask >> bit_head_begin << bit_head_begin; + let mask = mask << (64 - bit_tail_end) >> (64 - bit_tail_end); + self.modify_bitmap_word(index_head, mask, new_val); + } + } + + fn set_valid_range(&mut self, paddr_begin: PhysAddr, paddr_end: PhysAddr) { + self.set_range(paddr_begin, paddr_end, true); + } + + fn clear_valid_range(&mut self, paddr_begin: PhysAddr, paddr_end: PhysAddr) { + self.set_range(paddr_begin, paddr_end, false); + } + + fn is_valid_4k(&self, paddr: PhysAddr) -> bool { + let (index, bit) = self.index(paddr); + + assert!(self.check_addr(paddr)); + + let mask: u64 = 1u64 << bit; + self.bitmap[index] & mask == mask + } +} diff --git a/stage2/src/mm/virtualrange.rs b/stage2/src/mm/virtualrange.rs new file mode 100644 index 000000000..3fb6ee297 --- /dev/null +++ b/stage2/src/mm/virtualrange.rs @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +use crate::address::VirtAddr; +use crate::cpu::percpu::this_cpu; +use crate::error::SvsmError; +use crate::types::{PAGE_SHIFT, PAGE_SHIFT_2M, PAGE_SIZE, PAGE_SIZE_2M}; +use crate::utils::bitmap_allocator::{BitmapAllocator, BitmapAllocator1024}; +use crate::utils::MemoryRegion; +use core::fmt::Debug; + +use super::{ + SVSM_PERCPU_TEMP_BASE_2M, SVSM_PERCPU_TEMP_BASE_4K, SVSM_PERCPU_TEMP_END_2M, + SVSM_PERCPU_TEMP_END_4K, +}; + +pub const VIRT_ALIGN_4K: usize = PAGE_SHIFT - 12; +pub const VIRT_ALIGN_2M: usize = PAGE_SHIFT_2M - 12; + +#[derive(Debug, Default)] +pub struct VirtualRange { + start_virt: VirtAddr, + page_count: usize, + page_shift: usize, + bits: BitmapAllocator1024, +} + +impl VirtualRange { + pub const CAPACITY: usize = BitmapAllocator1024::CAPACITY; + + pub const fn new() -> VirtualRange { + VirtualRange { + start_virt: VirtAddr::null(), + page_count: 0, + page_shift: PAGE_SHIFT, + bits: BitmapAllocator1024::new(), + } + } + + pub fn init(&mut self, start_virt: VirtAddr, page_count: usize, page_shift: usize) { + self.start_virt = start_virt; + self.page_count = page_count; + self.page_shift = page_shift; + self.bits.set(0, page_count, false); + } + + pub fn alloc(&mut self, page_count: usize, 
alignment: usize) -> Result { + // Always reserve an extra page to leave a guard between virtual memory allocations + match self.bits.alloc(page_count + 1, alignment) { + Some(offset) => Ok(self.start_virt + (offset << self.page_shift)), + None => Err(SvsmError::Mem), + } + } + + pub fn free(&mut self, vaddr: VirtAddr, page_count: usize) { + let offset = (vaddr - self.start_virt) >> self.page_shift; + // Add 1 to the page count for the VM guard + self.bits.free(offset, page_count + 1); + } + + pub fn used_pages(&self) -> usize { + self.bits.used() + } +} + +pub fn virt_log_usage() { + let page_count4k = (SVSM_PERCPU_TEMP_END_4K - SVSM_PERCPU_TEMP_BASE_4K) / PAGE_SIZE; + let page_count2m = (SVSM_PERCPU_TEMP_END_2M - SVSM_PERCPU_TEMP_BASE_2M) / PAGE_SIZE_2M; + let unused_cap_4k = BitmapAllocator1024::CAPACITY - page_count4k; + let unused_cap_2m = BitmapAllocator1024::CAPACITY - page_count2m; + + log::info!( + "[CPU {}] Virtual memory pages used: {} * 4K, {} * 2M", + this_cpu().get_apic_id(), + this_cpu().vrange_4k.borrow().used_pages() - unused_cap_4k, + this_cpu().vrange_2m.borrow().used_pages() - unused_cap_2m + ); +} + +#[derive(Debug)] +pub struct VRangeAlloc { + region: MemoryRegion, + huge: bool, +} + +impl VRangeAlloc { + /// Returns a virtual memory region in the 4K virtual range. + pub fn new_4k(size: usize, align: usize) -> Result { + // Each bit in our bitmap represents a 4K page + if (size & (PAGE_SIZE - 1)) != 0 { + return Err(SvsmError::Mem); + } + let page_count = size >> PAGE_SHIFT; + let addr = this_cpu().vrange_4k.borrow_mut().alloc(page_count, align)?; + let region = MemoryRegion::new(addr, size); + Ok(Self { + region, + huge: false, + }) + } + + /// Returns a virtual memory region in the 2M virtual range. + pub fn new_2m(size: usize, align: usize) -> Result { + // Each bit in our bitmap represents a 2M page + if (size & (PAGE_SIZE_2M - 1)) != 0 { + return Err(SvsmError::Mem); + } + let page_count = size >> PAGE_SHIFT_2M; + let addr = this_cpu().vrange_2m.borrow_mut().alloc(page_count, align)?; + let region = MemoryRegion::new(addr, size); + Ok(Self { region, huge: true }) + } + + /// Returns the virtual memory region that this allocation spans. + pub const fn region(&self) -> MemoryRegion { + self.region + } + + /// Returns true if the allocation was made from the huge (2M) virtual range. + pub const fn huge(&self) -> bool { + self.huge + } +} + +impl Drop for VRangeAlloc { + fn drop(&mut self) { + let region = self.region(); + if self.huge { + this_cpu() + .vrange_2m + .borrow_mut() + .free(region.start(), region.len() >> PAGE_SHIFT_2M); + } else { + this_cpu() + .vrange_4k + .borrow_mut() + .free(region.start(), region.len() >> PAGE_SHIFT); + } + } +} + +#[cfg(test)] +mod tests { + use super::VirtualRange; + use crate::address::VirtAddr; + use crate::types::{PAGE_SHIFT, PAGE_SHIFT_2M, PAGE_SIZE, PAGE_SIZE_2M}; + + #[test] + fn test_alloc_no_overlap_4k() { + let mut range = VirtualRange::new(); + range.init(VirtAddr::new(0x1000000), 1024, PAGE_SHIFT); + + // Test that we get two virtual addresses that do + // not overlap when using 4k pages. 
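+        // Each alloc() reserves page_count + 1 bits, the extra bit being the
+        // guard page, so the second allocation must start at least 13 pages
+        // after the first.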
+ let v1 = range.alloc(12, 0); + let v2 = range.alloc(12, 0); + let v1 = u64::from(v1.unwrap()); + let v2 = u64::from(v2.unwrap()); + + assert!(v1 < v2); + assert!((v1 + (12 * PAGE_SIZE as u64)) < v2); + } + + #[test] + fn test_alloc_no_overlap_2m() { + let mut range = VirtualRange::new(); + range.init(VirtAddr::new(0x1000000), 1024, PAGE_SHIFT_2M); + + // Test that we get two virtual addresses that do + // not overlap when using 2M pages. + let v1 = range.alloc(12, 0); + let v2 = range.alloc(12, 0); + let v1 = u64::from(v1.unwrap()); + let v2 = u64::from(v2.unwrap()); + + assert!(v1 < v2); + assert!((v1 + (12 * PAGE_SIZE_2M as u64)) < v2); + } + + #[test] + fn test_free_4k() { + let mut range = VirtualRange::new(); + range.init(VirtAddr::new(0x1000000), 1024, PAGE_SHIFT); + + // This checks that freeing an allocated range giving the size + // of the virtual region in bytes does indeed free the correct amount + // of pages for 4K ranges. + let v1 = range.alloc(26, 0).unwrap(); + // Page count will be 1 higher due to guard page. + assert_eq!(range.used_pages(), 27); + + // If the page size calculation is wrong then there will be a mismatch between + // the requested and freed page count. + range.free(v1, 12); + assert_eq!(range.used_pages(), 14); + range.free(VirtAddr::new(u64::from(v1) as usize + (13 * PAGE_SIZE)), 13); + assert_eq!(range.used_pages(), 0); + } + + #[test] + fn test_free_2m() { + let mut range = VirtualRange::new(); + range.init(VirtAddr::new(0x1000000), 1024, PAGE_SHIFT_2M); + + // This checks that freeing an allocated range giving the size + // of the virtual region in bytes does indeed free the correct amount + // of pages for 4K ranges. + let v1 = range.alloc(26, 0).unwrap(); + // Page count will be 1 higher due to guard page. + assert_eq!(range.used_pages(), 27); + + // If the page size calculation is wrong then there will be a mismatch between + // the requested and freed page count. + range.free(v1, 12); + assert_eq!(range.used_pages(), 14); + range.free( + VirtAddr::new(u64::from(v1) as usize + (13 * PAGE_SIZE_2M)), + 13, + ); + assert_eq!(range.used_pages(), 0); + } +} diff --git a/stage2/src/mm/vm/mapping/api.rs b/stage2/src/mm/vm/mapping/api.rs new file mode 100644 index 000000000..c902b0d14 --- /dev/null +++ b/stage2/src/mm/vm/mapping/api.rs @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{PhysAddr, VirtAddr}; +use crate::error::SvsmError; +use crate::locking::{RWLock, ReadLockGuard, WriteLockGuard}; +use crate::mm::pagetable::PTEntryFlags; +use crate::mm::vm::VMR; +use crate::types::{PageSize, PAGE_SHIFT}; + +use intrusive_collections::rbtree::Link; +use intrusive_collections::{intrusive_adapter, KeyAdapter}; + +use core::ops::Range; + +extern crate alloc; +use alloc::boxed::Box; +use alloc::sync::Arc; + +/// Information required to resolve a page fault within a virtual mapping +#[derive(Debug, Copy, Clone)] +pub struct VMPageFaultResolution { + /// The physical address of a page that must be mapped to the page fault + /// virtual address to resolve the page fault. + pub paddr: PhysAddr, + + /// The flags to use to map the virtual memory page. + pub flags: PTEntryFlags, +} + +pub trait VirtualMapping: core::fmt::Debug { + /// Request the size of the virtual memory mapping + /// + /// # Returns + /// + /// Mapping size. 
Will always be a multiple of `VirtualMapping::page_size()` + fn mapping_size(&self) -> usize; + + /// Indicates whether the mapping has any associated data. + /// + /// # Returns + /// + /// `true' if there is associated physical data, or `false' if there is + /// none. + fn has_data(&self) -> bool { + // Defaults to true + true + } + + /// Request physical address to map for a given offset + /// + /// # Arguments + /// + /// * `offset` - Offset into the virtual memory mapping + /// + /// # Returns + /// + /// Physical address to map for the given offset, if any. None is also a + /// valid return value and does not indicate an error. + fn map(&self, offset: usize) -> Option; + + /// Inform the virtual memory mapping about an offset being unmapped. + /// Implementing `unmap()` is optional. + /// + /// # Arguments + /// + /// * `_offset` + fn unmap(&self, _offset: usize) { + // Provide default in case there is nothing to do + } + + /// Request the PTEntryFlags used for this virtual memory mapping. + /// + /// # Arguments + /// + /// * 'offset' -> The offset in bytes into the `VirtualMapping`. The flags + /// returned from this function relate to the page at the + /// given offset + /// + /// # Returns + /// + /// A combination of: + + /// * PTEntryFlags::WRITABLE + /// * PTEntryFlags::NX, + /// * PTEntryFlags::ACCESSED + /// * PTEntryFlags::DIRTY + fn pt_flags(&self, offset: usize) -> PTEntryFlags; + + /// Request the page size used for mappings + /// + /// # Returns + /// + /// Either PAGE_SIZE or PAGE_SIZE_2M + fn page_size(&self) -> PageSize { + // Default to system page-size + PageSize::Regular + } + + /// Request whether the mapping is shared or private. Defaults to private + /// unless overwritten by the specific type. + /// + /// # Returns + /// + /// * `True` - When mapping is shared + /// * `False` - When mapping is private + fn shared(&self) -> bool { + // Shared with the HV - defaults not No + false + } + + /// Handle a page fault that occurred on a virtual memory address within + /// this mapping. + /// + /// # Arguments + /// + /// * 'vmr' - Virtual memory range that contains the mapping. This + /// [`VirtualMapping`] can use this to insert/remove regions + /// as necessary to handle the page fault. + /// + /// * `offset` - Offset into the virtual mapping that was the subject of + /// the page fault. + /// + /// * 'write' - `true` if the fault was due to a write to the memory + /// location, or 'false' if the fault was due to a read. + fn handle_page_fault( + &mut self, + _vmr: &VMR, + _offset: usize, + _write: bool, + ) -> Result { + Err(SvsmError::Mem) + } +} + +#[derive(Debug)] +pub struct Mapping { + mapping: RWLock>, +} + +unsafe impl Send for Mapping {} +unsafe impl Sync for Mapping {} + +impl Mapping { + pub fn new(mapping: T) -> Self + where + T: VirtualMapping + 'static, + { + Mapping { + mapping: RWLock::new(Box::new(mapping)), + } + } + + pub fn get(&self) -> ReadLockGuard<'_, Box> { + self.mapping.lock_read() + } + + pub fn get_mut(&self) -> WriteLockGuard<'_, Box> { + self.mapping.lock_write() + } +} + +/// A single mapping of virtual memory in a virtual memory range +#[derive(Debug)] +pub struct VMM { + /// Link for storing this instance in an RBTree + link: Link, + + /// The virtual memory range covered by this mapping + /// It is stored in a RefCell to check borrowing rules at runtime. + /// This is safe as any modification to `range` is protected by a lock in + /// the parent data structure. 
This is required because changes here also + /// need changes in the parent data structure. + range: Range, + + /// Pointer to the actual mapping + /// It is protected by an RWLock to serialize concurent accesses. + mapping: Arc, +} + +intrusive_adapter!(pub VMMAdapter = Box: VMM { link: Link }); + +impl<'a> KeyAdapter<'a> for VMMAdapter { + type Key = usize; + fn get_key(&self, node: &'a VMM) -> Self::Key { + node.range.start + } +} + +impl VMM { + /// Create a new VMM instance with at a given address and backing struct + /// + /// # Arguments + /// + /// * `start_pfn` - Virtual start pfn to store in the mapping + /// * `mapping` - `Arc` pointer to the backing struct + /// + /// # Returns + /// + /// New instance of VMM + pub fn new(start_pfn: usize, mapping: Arc) -> Self { + let size = mapping.get().mapping_size() >> PAGE_SHIFT; + VMM { + link: Link::new(), + range: Range { + start: start_pfn, + end: start_pfn + size, + }, + mapping, + } + } + + /// Request the mapped range as page frame numbers + /// + /// # Returns + /// + /// The start and end (non-inclusive) virtual address for this virtual + /// mapping, right-shifted by `PAGE_SHIFT`. + pub fn range_pfn(&self) -> (usize, usize) { + (self.range.start, self.range.end) + } + + /// Request the mapped range + /// + /// # Returns + /// + /// The start and end virtual address for this virtual mapping. + pub fn range(&self) -> (VirtAddr, VirtAddr) { + ( + VirtAddr::from(self.range.start << PAGE_SHIFT), + VirtAddr::from(self.range.end << PAGE_SHIFT), + ) + } + + pub fn get_mapping(&self) -> ReadLockGuard<'_, Box> { + self.mapping.get() + } + + pub fn get_mapping_mut(&self) -> WriteLockGuard<'_, Box> { + self.mapping.get_mut() + } + + pub fn get_mapping_clone(&self) -> Arc { + self.mapping.clone() + } +} diff --git a/stage2/src/mm/vm/mapping/file_mapping.rs b/stage2/src/mm/vm/mapping/file_mapping.rs new file mode 100644 index 000000000..f6eeb760d --- /dev/null +++ b/stage2/src/mm/vm/mapping/file_mapping.rs @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 SUSE LLC +// +// Author: Roy Hopkins + +extern crate alloc; + +use alloc::vec::Vec; + +use bitflags::bitflags; + +use super::{VMPageFaultResolution, VirtualMapping}; +use crate::address::PhysAddr; +use crate::error::SvsmError; +use crate::fs::FileHandle; +use crate::mm::vm::VMR; +use crate::mm::PageRef; +use crate::mm::{pagetable::PTEntryFlags, PAGE_SIZE}; +use crate::types::PAGE_SHIFT; +use crate::utils::align_up; + +bitflags! { + #[derive(Debug, PartialEq, Copy, Clone)] + pub struct VMFileMappingFlags : u32 { + /// Read-only access to the file + const Read = 1 << 0; + // Read/Write access to a copy of the files pages + const Write = 1 << 1; + // Read-only access that allows execution + const Execute = 1 << 2; + // Map private copies of file pages + const Private = 1 << 3; + // Map at a fixed address + const Fixed = 1 << 4; + } +} + +/// Map view of a ramfs file into virtual memory +#[derive(Debug)] +pub struct VMFileMapping { + /// The size of the mapping in bytes + size: usize, + + /// The flags to apply to the virtual mapping + flags: VMFileMappingFlags, + + /// A vec containing references to mapped pages within the file + pages: Vec, +} + +impl VMFileMapping { + /// Create a new ['VMFileMapping'] for a file. The file provides the backing + /// pages for the file contents. + /// + /// # Arguments + /// + /// * 'file' - The file to create the mapping for. This instance keeps a + /// reference to the file until it is dropped. 
+ /// + /// * 'offset' - The offset from the start of the file to map. This must be + /// align to PAGE_SIZE. + /// + /// * 'size' - The number of bytes to map starting from the offset. This + /// must be a multiple of PAGE_SIZE. + /// + /// # Returns + /// + /// Initialized mapping on success, Err(SvsmError::Mem) on error + pub fn new( + file: &FileHandle, + offset: usize, + size: usize, + flags: VMFileMappingFlags, + ) -> Result { + let page_size = align_up(size, PAGE_SIZE); + let file_size = align_up(file.size(), PAGE_SIZE); + if (offset & (PAGE_SIZE - 1)) != 0 { + return Err(SvsmError::Mem); + } + if (page_size + offset) > file_size { + return Err(SvsmError::Mem); + } + + // Take references to the file pages + let count = page_size >> PAGE_SHIFT; + let mut pages = Vec::::new(); + for page_index in 0..count { + let page_ref = file + .mapping(offset + page_index * PAGE_SIZE) + .ok_or(SvsmError::Mem)?; + if flags.contains(VMFileMappingFlags::Private) { + pages.push(page_ref.try_copy_page()?); + } else { + pages.push(page_ref); + } + } + Ok(Self { + size: page_size, + flags, + pages, + }) + } +} + +#[cfg(not(test))] +#[cfg(test)] +fn copy_page( + _vmr: &VMR, + file: &FileHandle, + offset: usize, + paddr_dst: PhysAddr, + page_size: PageSize, +) -> Result<(), SvsmError> { + let page_size = usize::from(page_size); + // In the test environment the physical address is actually the virtual + // address. We can take advantage of this to copy the file contents into the + // mock physical address without worrying about VMRs and page tables. + let slice = unsafe { from_raw_parts_mut(paddr_dst.bits() as *mut u8, page_size) }; + file.seek(offset); + file.read(slice)?; + Ok(()) +} + +impl VirtualMapping for VMFileMapping { + fn mapping_size(&self) -> usize { + self.size + } + + fn map(&self, offset: usize) -> Option { + let page_index = offset / PAGE_SIZE; + if page_index >= self.pages.len() { + return None; + } + Some(self.pages[page_index].phys_addr()) + } + + fn pt_flags(&self, _offset: usize) -> PTEntryFlags { + let mut flags = PTEntryFlags::empty(); + + if self.flags.contains(VMFileMappingFlags::Write) { + flags |= PTEntryFlags::WRITABLE; + } + + if !self.flags.contains(VMFileMappingFlags::Execute) { + flags |= PTEntryFlags::NX; + } + + flags + } + + fn handle_page_fault( + &mut self, + _vmr: &VMR, + _offset: usize, + _write: bool, + ) -> Result { + Err(SvsmError::Mem) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + fs::{create, open, unlink, TestFileSystemGuard}, + mm::alloc::{TestRootMem, DEFAULT_TEST_MEMORY_SIZE}, + types::PAGE_SIZE, + }; + + fn create_512b_test_file() -> (FileHandle, &'static str) { + let fh = create("test1").unwrap(); + let buf = [0xffu8; 512]; + fh.write(&buf).expect("File write failed"); + (fh, "test1") + } + + fn create_16k_test_file() -> (FileHandle, &'static str) { + let fh = create("test1").unwrap(); + let mut buf = [0xffu8; PAGE_SIZE * 4]; + buf[PAGE_SIZE] = 1; + buf[PAGE_SIZE * 2] = 2; + buf[PAGE_SIZE * 3] = 3; + fh.write(&buf).expect("File write failed"); + (fh, "test1") + } + + fn create_5000b_test_file() -> (FileHandle, &'static str) { + let fh = create("test1").unwrap(); + let buf = [0xffu8; 5000]; + fh.write(&buf).expect("File write failed"); + (fh, "test1") + } + + #[test] + fn test_create_mapping() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + let (fh, name) = create_512b_test_file(); + let vm = VMFileMapping::new(&fh, 0, 512, VMFileMappingFlags::Read) + .expect("Failed 
to create new VMFileMapping"); + assert_eq!(vm.mapping_size(), PAGE_SIZE); + assert!(vm.flags.contains(VMFileMappingFlags::Read)); + assert_eq!(vm.pages.len(), 1); + unlink(name).unwrap(); + } + + #[test] + fn test_create_unaligned_offset() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + // Not page aligned + let offset = PAGE_SIZE + 0x60; + + let (fh, name) = create_16k_test_file(); + let fh2 = open(name).unwrap(); + let vm = VMFileMapping::new(&fh, offset, fh2.size() - offset, VMFileMappingFlags::Read); + assert!(vm.is_err()); + unlink(name).unwrap(); + } + + #[test] + fn test_create_size_too_large() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + let (fh, name) = create_16k_test_file(); + let fh2 = open(name).unwrap(); + let vm = VMFileMapping::new(&fh, 0, fh2.size() + 1, VMFileMappingFlags::Read); + assert!(vm.is_err()); + unlink(name).unwrap(); + } + + #[test] + fn test_create_offset_overflow() { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + let (fh, name) = create_16k_test_file(); + let fh2 = open(name).unwrap(); + let vm = VMFileMapping::new(&fh, PAGE_SIZE, fh2.size(), VMFileMappingFlags::Read); + assert!(vm.is_err()); + unlink(name).unwrap(); + } + + fn test_map_first_page(flags: VMFileMappingFlags) { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + let (fh, name) = create_512b_test_file(); + let vm = + VMFileMapping::new(&fh, 0, 512, flags).expect("Failed to create new VMFileMapping"); + + let res = vm + .map(0) + .expect("Mapping of first VMFileMapping page failed"); + + let fh2 = open(name).unwrap(); + assert_eq!( + fh2.mapping(0) + .expect("Failed to get file page mapping") + .phys_addr(), + res + ); + unlink(name).unwrap(); + } + + fn test_map_multiple_pages(flags: VMFileMappingFlags) { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + let (fh, name) = create_16k_test_file(); + let fh2 = open(name).unwrap(); + let vm = VMFileMapping::new(&fh, 0, fh2.size(), flags) + .expect("Failed to create new VMFileMapping"); + + for i in 0..4 { + let res = vm + .map(i * PAGE_SIZE) + .expect("Mapping of VMFileMapping page failed"); + + assert_eq!( + fh2.mapping(i * PAGE_SIZE) + .expect("Failed to get file page mapping") + .phys_addr(), + res + ); + } + unlink(name).unwrap(); + } + + fn test_map_unaligned_file_size(flags: VMFileMappingFlags) { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + let (fh, name) = create_5000b_test_file(); + let fh2 = open(name).unwrap(); + let vm = VMFileMapping::new(&fh, 0, fh2.size(), flags) + .expect("Failed to create new VMFileMapping"); + + assert_eq!(vm.mapping_size(), PAGE_SIZE * 2); + assert_eq!(vm.pages.len(), 2); + + for i in 0..2 { + let res = vm + .map(i * PAGE_SIZE) + .expect("Mapping of first VMFileMapping page failed"); + + assert_eq!( + fh2.mapping(i * PAGE_SIZE) + .expect("Failed to get file page mapping") + .phys_addr(), + res + ); + } + unlink(name).unwrap(); + } + + fn test_map_non_zero_offset(flags: VMFileMappingFlags) { + let _test_mem = TestRootMem::setup(DEFAULT_TEST_MEMORY_SIZE); + let _test_fs = TestFileSystemGuard::setup(); + + let (fh, name) = create_16k_test_file(); + let fh2 = open(name).unwrap(); + let vm = VMFileMapping::new(&fh, 2 * 
PAGE_SIZE, PAGE_SIZE, flags) + .expect("Failed to create new VMFileMapping"); + + assert_eq!(vm.mapping_size(), PAGE_SIZE); + assert_eq!(vm.pages.len(), 1); + + let res = vm + .map(0) + .expect("Mapping of first VMFileMapping page failed"); + + assert_eq!( + fh2.mapping(2 * PAGE_SIZE) + .expect("Failed to get file page mapping") + .phys_addr(), + res + ); + unlink(name).unwrap(); + } + + #[test] + fn test_map_first_page_readonly() { + test_map_first_page(VMFileMappingFlags::Read) + } + + #[test] + fn test_map_multiple_pages_readonly() { + test_map_multiple_pages(VMFileMappingFlags::Read) + } + + #[test] + fn test_map_unaligned_file_size_readonly() { + test_map_unaligned_file_size(VMFileMappingFlags::Read) + } + + #[test] + fn test_map_non_zero_offset_readonly() { + test_map_non_zero_offset(VMFileMappingFlags::Read) + } + + #[test] + fn test_map_first_page_readwrite() { + test_map_first_page(VMFileMappingFlags::Write) + } + + #[test] + fn test_map_multiple_pages_readwrite() { + test_map_multiple_pages(VMFileMappingFlags::Write) + } + + #[test] + fn test_map_unaligned_file_size_readwrite() { + test_map_unaligned_file_size(VMFileMappingFlags::Write) + } + + #[test] + fn test_map_non_zero_offset_readwrite() { + test_map_non_zero_offset(VMFileMappingFlags::Write) + } +} diff --git a/stage2/src/mm/vm/mapping/kernel_stack.rs b/stage2/src/mm/vm/mapping/kernel_stack.rs new file mode 100644 index 000000000..bf7da012c --- /dev/null +++ b/stage2/src/mm/vm/mapping/kernel_stack.rs @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::VirtualMapping; +use crate::address::{PhysAddr, VirtAddr}; +use crate::error::SvsmError; +use crate::mm::address_space::STACK_SIZE; +use crate::mm::pagetable::PTEntryFlags; +use crate::types::{PAGE_SHIFT, PAGE_SIZE}; +use crate::utils::{page_align_up, MemoryRegion}; + +use super::rawalloc::RawAllocMapping; +use super::Mapping; + +/// Mapping to be used as a kernel stack. This maps a stack including guard +/// pages at the top and bottom. +#[derive(Default, Debug)] +pub struct VMKernelStack { + /// Allocation for stack pages + alloc: RawAllocMapping, + /// Number of guard pages to reserve address space for + guard_pages: usize, +} + +impl VMKernelStack { + /// Returns the virtual address for the top of this kernel stack + /// + /// # Arguments + /// + /// * `base` - Virtual base address this stack is mapped at (including + /// guard pages). + /// + /// # Returns + /// + /// Virtual address to program into the hardware stack register + pub fn top_of_stack(&self, base: VirtAddr) -> VirtAddr { + let guard_size = self.guard_pages * PAGE_SIZE; + base + guard_size + self.alloc.mapping_size() + } + + /// Returns the stack bounds of this kernel stack + /// + /// # Arguments + /// + /// * `base` - Virtual base address this stack is mapped at (including + /// guard pages). + /// + /// # Returns + /// + /// A [`MemoryRegion`] object containing the bottom and top addresses for + /// the stack + pub fn bounds(&self, base: VirtAddr) -> MemoryRegion { + let mapping_size = self.alloc.mapping_size(); + let guard_size = self.guard_pages * PAGE_SIZE; + MemoryRegion::new(base + guard_size, mapping_size) + } + + /// Create a new [`VMKernelStack`] with a given size. This function will + /// already allocate the backing pages for the stack. 
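+    ///
+    /// A minimal sketch (the requested size is illustrative):
+    ///
+    /// ```ignore
+    /// let stack = VMKernelStack::new_size(4 * PAGE_SIZE)?;
+    /// // Guard pages only reserve address space, but they make the mapped
+    /// // size larger than the requested stack size.
+    /// assert!(stack.mapping_size() > 4 * PAGE_SIZE);
+    /// ```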
+ /// + /// # Arguments + /// + /// * `size` - Size of the kernel stack, without guard pages + /// + /// # Returns + /// + /// Initialized stack on success, Err(SvsmError::Mem) on error + pub fn new_size(size: usize) -> Result { + // Make sure size is page-aligned + let size = page_align_up(size); + // At least two guard-pages needed + let total_size = (size + 2 * PAGE_SIZE).next_power_of_two(); + let guard_pages = ((total_size - size) >> PAGE_SHIFT) / 2; + let mut stack = VMKernelStack { + alloc: RawAllocMapping::new(size), + guard_pages, + }; + stack.alloc_pages()?; + + Ok(stack) + } + + /// Create a new [`VMKernelStack`] with the default size. This function + /// will already allocate the backing pages for the stack. + /// + /// # Returns + /// + /// Initialized stack on success, Err(SvsmError::Mem) on error + pub fn new() -> Result { + VMKernelStack::new_size(STACK_SIZE) + } + + /// Create a new [`VMKernelStack`] with the default size, packed into a + /// [`Mapping`]. This function / will already allocate the backing pages for + /// the stack. + /// + /// # Returns + /// + /// Initialized Mapping to stack on success, Err(SvsmError::Mem) on error + pub fn new_mapping() -> Result { + Ok(Mapping::new(Self::new()?)) + } + + fn alloc_pages(&mut self) -> Result<(), SvsmError> { + self.alloc.alloc_pages() + } +} + +impl VirtualMapping for VMKernelStack { + fn mapping_size(&self) -> usize { + self.alloc.mapping_size() + ((self.guard_pages * 2) << PAGE_SHIFT) + } + + fn map(&self, offset: usize) -> Option { + let pfn = offset >> PAGE_SHIFT; + let guard_offset = self.guard_pages << PAGE_SHIFT; + + if pfn >= self.guard_pages { + self.alloc.map(offset - guard_offset) + } else { + None + } + } + + fn unmap(&self, offset: usize) { + let pfn = offset >> PAGE_SHIFT; + + if pfn >= self.guard_pages { + self.alloc.unmap(pfn - self.guard_pages); + } + } + + fn pt_flags(&self, _offset: usize) -> PTEntryFlags { + PTEntryFlags::WRITABLE | PTEntryFlags::NX | PTEntryFlags::ACCESSED | PTEntryFlags::DIRTY + } +} diff --git a/stage2/src/mm/vm/mapping/mod.rs b/stage2/src/mm/vm/mapping/mod.rs new file mode 100644 index 000000000..982f3e01e --- /dev/null +++ b/stage2/src/mm/vm/mapping/mod.rs @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +pub mod api; +pub mod file_mapping; +pub mod kernel_stack; +pub mod phys_mem; +pub mod rawalloc; +pub mod reserved; +pub mod vmalloc; + +pub use api::{Mapping, VMMAdapter, VMPageFaultResolution, VirtualMapping, VMM}; +pub use file_mapping::{VMFileMapping, VMFileMappingFlags}; +pub use kernel_stack::VMKernelStack; +pub use phys_mem::VMPhysMem; +pub use rawalloc::RawAllocMapping; +pub use reserved::VMReserved; +pub use vmalloc::VMalloc; diff --git a/stage2/src/mm/vm/mapping/phys_mem.rs b/stage2/src/mm/vm/mapping/phys_mem.rs new file mode 100644 index 000000000..a86017413 --- /dev/null +++ b/stage2/src/mm/vm/mapping/phys_mem.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr}; +use crate::mm::pagetable::PTEntryFlags; + +use super::{Mapping, VirtualMapping}; + +/// Map physically contiguous memory +#[derive(Default, Debug, Clone, Copy)] +pub struct VMPhysMem { + /// Physical base address to map + base: PhysAddr, + /// Number of bytes to map + size: usize, + /// Whether mapping is writable + writable: bool, +} + +impl VMPhysMem { + /// Initialize new instance of 
[`VMPhysMem`] + /// + /// # Arguments + /// + /// * `base` - Physical base address to map + /// * `size` - Number of bytes to map + /// * `writable` - Whether mapping is writable + /// + /// # Returns + /// + /// New instance of [`VMPhysMem`] + pub fn new(base: PhysAddr, size: usize, writable: bool) -> Self { + VMPhysMem { + base, + size, + writable, + } + } + + /// Initialize new [`Mapping`] with [`VMPhysMem`] + /// + /// # Arguments + /// + /// * `base` - Physical base address to map + /// * `size` - Number of bytes to map + /// * `writable` - Whether mapping is writable + /// + /// # Returns + /// + /// New [`Mapping`] containing [`VMPhysMem`] + pub fn new_mapping(base: PhysAddr, size: usize, writable: bool) -> Mapping { + Mapping::new(Self::new(base, size, writable)) + } +} + +impl VirtualMapping for VMPhysMem { + fn mapping_size(&self) -> usize { + self.size + } + + fn map(&self, offset: usize) -> Option { + if offset < self.size { + Some((self.base + offset).page_align()) + } else { + None + } + } + + fn pt_flags(&self, _offset: usize) -> PTEntryFlags { + PTEntryFlags::NX + | PTEntryFlags::ACCESSED + | if self.writable { + PTEntryFlags::WRITABLE | PTEntryFlags::DIRTY + } else { + PTEntryFlags::empty() + } + } +} diff --git a/stage2/src/mm/vm/mapping/rawalloc.rs b/stage2/src/mm/vm/mapping/rawalloc.rs new file mode 100644 index 000000000..3246be531 --- /dev/null +++ b/stage2/src/mm/vm/mapping/rawalloc.rs @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use core::iter; + +use crate::address::PhysAddr; +use crate::error::SvsmError; +use crate::mm::alloc::PageRef; +use crate::types::{PAGE_SHIFT, PAGE_SIZE}; +use crate::utils::align_up; + +extern crate alloc; +use alloc::vec::Vec; + +/// Contains base functionality for all [`VirtualMapping`](super::api::VirtualMapping) +/// types which use self-allocated PageFile pages. +#[derive(Default, Debug)] +pub struct RawAllocMapping { + /// A vec containing references to PageFile allocations + pages: Vec>, + + /// Number of pages required in `pages` + count: usize, +} + +impl RawAllocMapping { + /// Creates a new instance of RawAllocMapping + /// + /// # Arguments + /// + /// * `size` - Size of the mapping in bytes + /// + /// # Returns + /// + /// New instance of RawAllocMapping. Still needs to call `alloc_pages()` on it before it can be used. 
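+    ///
+    /// A brief usage sketch (the size is illustrative):
+    ///
+    /// ```ignore
+    /// let mut raw = RawAllocMapping::new(3 * PAGE_SIZE);
+    /// raw.alloc_pages()?;
+    /// assert!(raw.present(0));
+    /// assert_eq!(raw.mapping_size(), 3 * PAGE_SIZE);
+    /// ```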
+ pub fn new(size: usize) -> Self { + let count = align_up(size, PAGE_SIZE) >> PAGE_SHIFT; + let pages: Vec> = iter::repeat(None).take(count).collect(); + RawAllocMapping { pages, count } + } + + /// Allocates a single backing page of type PageFile if the page has not already + /// been allocated + /// + /// # Argument + /// + /// * 'offset' - The offset in bytes from the start of the mapping + /// + /// # Returns + /// + /// `Ok(())` if the page has been allocated, `Err(SvsmError::Mem)` otherwise + pub fn alloc_page(&mut self, offset: usize) -> Result<(), SvsmError> { + let index = offset >> PAGE_SHIFT; + if index < self.count { + let entry = self.pages.get_mut(index).ok_or(SvsmError::Mem)?; + entry.get_or_insert(PageRef::new()?); + } + Ok(()) + } + + /// Allocates a full set of backing pages of type PageFile + /// + /// # Returns + /// + /// `Ok(())` when all pages could be allocated, `Err(SvsmError::Mem)` otherwise + pub fn alloc_pages(&mut self) -> Result<(), SvsmError> { + for index in 0..self.count { + self.alloc_page(index * PAGE_SIZE)?; + } + Ok(()) + } + + /// Request size of the mapping in bytes + /// + /// # Returns + /// + /// The size of the mapping in bytes as `usize`. + pub fn mapping_size(&self) -> usize { + self.count * PAGE_SIZE + } + + /// Request physical address to map for a given offset + /// + /// # Arguments + /// + /// * `offset` - Byte offset into the memory mapping + /// + /// # Returns + /// + /// Physical address to map for the given offset. + pub fn map(&self, offset: usize) -> Option { + let pfn = offset >> PAGE_SHIFT; + self.pages + .get(pfn) + .and_then(|r| r.as_ref().map(|r| r.phys_addr())) + } + + /// Unmap call-back - currently nothing to do in this function + /// + /// # Arguments + /// + /// * `_offset` - Byte offset into the mapping + pub fn unmap(&self, _offset: usize) { + // Nothing to do for now + } + + /// Check if a page has been allocated + /// + /// # Arguments + /// + /// * 'offset' - Byte offset into the mapping + /// + /// # Returns + /// + /// 'true' if the page containing the offset has been allocated + /// otherwise 'false'. + pub fn present(&self, offset: usize) -> bool { + let pfn = offset >> PAGE_SHIFT; + self.pages.get(pfn).and_then(|r| r.as_ref()).is_some() + } +} diff --git a/stage2/src/mm/vm/mapping/reserved.rs b/stage2/src/mm/vm/mapping/reserved.rs new file mode 100644 index 000000000..5bd9b298b --- /dev/null +++ b/stage2/src/mm/vm/mapping/reserved.rs @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::PhysAddr; +use crate::mm::pagetable::PTEntryFlags; + +use super::{Mapping, VirtualMapping}; + +/// Reserve a region of address space so that no other mapping will be +/// established there. The map function for this type will always return +/// `None`. +#[derive(Default, Debug, Clone, Copy)] +pub struct VMReserved { + /// Size in bytes to reserve. 
Must be aligned to PAGE_SIZE + size: usize, +} + +impl VMReserved { + /// Create new instance of VMReserved + /// + /// # Arguments + /// + /// * `size` - Number of bytes to reserve + /// + /// # Returns + /// + /// New instance of VMReserved + pub fn new(size: usize) -> Self { + VMReserved { size } + } + + /// Create new [`Mapping`] of [`VMReserved`] + /// + /// # Arguments + /// + /// * `size` - Number of bytes to reserve + /// + /// # Returns + /// + /// New Mapping of VMReserved + pub fn new_mapping(size: usize) -> Mapping { + Mapping::new(Self::new(size)) + } +} + +impl VirtualMapping for VMReserved { + fn mapping_size(&self) -> usize { + self.size + } + + fn has_data(&self) -> bool { + false + } + + fn map(&self, _offset: usize) -> Option { + None + } + + fn pt_flags(&self, _offset: usize) -> PTEntryFlags { + PTEntryFlags::NX | PTEntryFlags::ACCESSED | PTEntryFlags::WRITABLE | PTEntryFlags::DIRTY + } +} diff --git a/stage2/src/mm/vm/mapping/vmalloc.rs b/stage2/src/mm/vm/mapping/vmalloc.rs new file mode 100644 index 000000000..d4b03316c --- /dev/null +++ b/stage2/src/mm/vm/mapping/vmalloc.rs @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::PhysAddr; +use crate::error::SvsmError; +use crate::mm::pagetable::PTEntryFlags; + +use super::rawalloc::RawAllocMapping; +use super::{Mapping, VMFileMappingFlags, VirtualMapping}; + +/// Virtual mapping backed by allocated pages. This can be used for memory +/// allocation if there is no need for the memory to be physically contiguous. +/// +/// This is a wrapper around RawAllocMapping. +#[derive(Default, Debug)] +pub struct VMalloc { + /// [`RawAllocMapping`] used for memory allocation + alloc: RawAllocMapping, + /// Page-table flags to map pages + flags: PTEntryFlags, +} + +impl VMalloc { + /// Create a new instance and allocate backing memory + /// + /// # Arguments + /// + /// * `size` - Size of the mapping. Must be aligned to PAGE_SIZE + /// + /// # Returns + /// + /// New instance on success, Err(SvsmError::Mem) on error + pub fn new(size: usize, flags: VMFileMappingFlags) -> Result { + let mut vmalloc = VMalloc { + alloc: RawAllocMapping::new(size), + flags: PTEntryFlags::ACCESSED | PTEntryFlags::DIRTY, + }; + + if flags.contains(VMFileMappingFlags::Write) { + vmalloc.flags |= PTEntryFlags::WRITABLE; + } + + if !flags.contains(VMFileMappingFlags::Execute) { + vmalloc.flags |= PTEntryFlags::NX; + } + + vmalloc.alloc_pages()?; + Ok(vmalloc) + } + + /// Create a new [`Mapping`] of [`VMalloc`] and allocate backing memory + /// + /// # Arguments + /// + /// * `size` - Size of the mapping. 
Must be aligned to PAGE_SIZE + /// + /// # Returns + /// + /// New [`Mapping`] on success, Err(SvsmError::Mem) on error + pub fn new_mapping(size: usize, flags: VMFileMappingFlags) -> Result { + Ok(Mapping::new(Self::new(size, flags)?)) + } + + fn alloc_pages(&mut self) -> Result<(), SvsmError> { + self.alloc.alloc_pages() + } +} + +impl VirtualMapping for VMalloc { + fn mapping_size(&self) -> usize { + self.alloc.mapping_size() + } + + fn map(&self, offset: usize) -> Option { + self.alloc.map(offset) + } + + fn unmap(&self, offset: usize) { + self.alloc.unmap(offset); + } + + fn pt_flags(&self, _offset: usize) -> PTEntryFlags { + self.flags + } +} diff --git a/stage2/src/mm/vm/mod.rs b/stage2/src/mm/vm/mod.rs new file mode 100644 index 000000000..78815370a --- /dev/null +++ b/stage2/src/mm/vm/mod.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +mod mapping; +mod range; + +pub use mapping::{ + Mapping, RawAllocMapping, VMFileMapping, VMFileMappingFlags, VMKernelStack, VMMAdapter, + VMPhysMem, VMReserved, VMalloc, VirtualMapping, VMM, +}; +pub use range::{VMRMapping, VMR, VMR_GRANULE}; diff --git a/stage2/src/mm/vm/range.rs b/stage2/src/mm/vm/range.rs new file mode 100644 index 000000000..ddd31d87d --- /dev/null +++ b/stage2/src/mm/vm/range.rs @@ -0,0 +1,519 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, VirtAddr}; +use crate::cpu::{flush_tlb_global_percpu, flush_tlb_global_sync}; +use crate::error::SvsmError; +use crate::locking::RWLock; +use crate::mm::pagetable::{PTEntryFlags, PageTable, PageTablePart}; +use crate::types::{PageSize, PAGE_SHIFT, PAGE_SIZE}; +use crate::utils::{align_down, align_up}; + +use core::cmp::max; + +use intrusive_collections::rbtree::{CursorMut, RBTree}; +use intrusive_collections::Bound; + +use super::{Mapping, VMMAdapter, VMM}; + +extern crate alloc; +use alloc::boxed::Box; +use alloc::sync::Arc; +use alloc::vec::Vec; + +/// Granularity of ranges mapped by [`struct VMR`]. The mapped region of a +/// [`struct VMR`] is always a multiple of this constant. +/// One [`VMR_GRANULE`] covers one top-level page-table entry on x86-64 with +/// 4-level paging. +pub const VMR_GRANULE: usize = PAGE_SIZE * 512 * 512 * 512; + +/// Virtual Memory Region +/// +/// This struct manages the mappings in a region of the virtual address space. +/// The region size is a multiple of 512GiB so that every region will fully +/// allocate one or more top-level page-table entries on x86-64. For the same +/// reason the start address must also be aligned to 512GB. +#[derive(Debug)] +pub struct VMR { + /// Start address of this range as virtual PFN (VirtAddr >> PAGE_SHIFT). + /// Virtual address must be aligned to [`VMR_GRANULE`] (512GB on x86-64). + start_pfn: usize, + + /// End address of this range as virtual PFN (VirtAddr >> PAGE_SHIFT) + /// Virtual address must be aligned to [`VMR_GRANULE`] (512GB on x86-64). + end_pfn: usize, + + /// RBTree containing all [`struct VMM`] instances with valid mappings in + /// the covered virtual address region. The [`struct VMM`]s are sorted by + /// their start address and stored in an RBTree for faster lookup. + tree: RWLock>, + + /// [`struct PageTableParts`] needed to map this VMR into a page-table. + /// There is one [`struct PageTablePart`] per [`VMR_GRANULE`] covered by + /// the region. 
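+    /// With the 4-level paging [`VMR_GRANULE`] defined above, this works
+    /// out to one [`PageTablePart`] per 512 GiB of covered address space.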
+ pgtbl_parts: RWLock>, + + /// [`PTEntryFlags`] global to all mappings in this region. This is a + /// combination of [`PTEntryFlags::GLOBAL`] and [`PTEntryFlags::USER`]. + pt_flags: PTEntryFlags, + + /// Indicates that this [`struct VMR`] is visible only on a single CPU + /// and therefore TLB flushes do not require broadcast. + per_cpu: bool, +} + +impl VMR { + /// Creates a new [`struct VMR`] + /// + /// # Arguments + /// + /// * `start` - Virtual start address for the memory region. Must be aligned to [`VMR_GRANULE`] + /// * `end` - Virtual end address (non-inclusive) for the memory region. + /// Must be bigger than `start` and aligned to [`VMR_GRANULE`]. + /// * `flags` - Global [`PTEntryFlags`] to use for this [`struct VMR`]. + /// + /// # Returns + /// + /// A new instance of [`struct VMR`]. + pub fn new(start: VirtAddr, end: VirtAddr, flags: PTEntryFlags) -> Self { + // Global and User are per VMR flags + VMR { + start_pfn: start.pfn(), + end_pfn: end.pfn(), + tree: RWLock::new(RBTree::new(VMMAdapter::new())), + pgtbl_parts: RWLock::new(Vec::new()), + pt_flags: flags, + per_cpu: false, + } + } + + /// Marks a [`struct VMR`] as being associated with only a single CPU + /// so that TLB flushes do not require broadcast. + pub fn set_per_cpu(&mut self, per_cpu: bool) { + self.per_cpu = per_cpu; + } + + /// Allocated all [`PageTablePart`]s needed to map this region + /// + /// # Returns + /// + /// `Ok(())` on success, Err(SvsmError::Mem) on allocation error + fn alloc_page_tables(&self, lazy: bool) -> Result<(), SvsmError> { + let start = VirtAddr::from(self.start_pfn << PAGE_SHIFT); + let end = VirtAddr::from(self.end_pfn << PAGE_SHIFT); + let count = end.to_pgtbl_idx::<3>() - start.to_pgtbl_idx::<3>(); + let mut vec = self.pgtbl_parts.lock_write(); + + for idx in 0..count { + let mut part = PageTablePart::new(start + (idx * VMR_GRANULE)); + if !lazy { + part.alloc(); + } + vec.push(part); + } + + Ok(()) + } + + /// Populate [`PageTablePart`]s of the [`VMR`] into a page-table + /// + /// # Arguments + /// + /// * `pgtbl` - A [`PageTable`] pointing to the target page-table + pub fn populate(&self, pgtbl: &mut PageTable) { + let parts = self.pgtbl_parts.lock_read(); + + for part in parts.iter() { + pgtbl.populate_pgtbl_part(part); + } + } + + pub fn populate_addr(&self, pgtbl: &mut PageTable, vaddr: VirtAddr) { + let start = VirtAddr::from(self.start_pfn << PAGE_SHIFT); + let end = VirtAddr::from(self.end_pfn << PAGE_SHIFT); + assert!(vaddr >= start && vaddr < end); + + let idx = vaddr.to_pgtbl_idx::<3>() - start.to_pgtbl_idx::<3>(); + let parts = self.pgtbl_parts.lock_read(); + pgtbl.populate_pgtbl_part(&parts[idx]); + } + + /// Initialize this [`VMR`] by checking the `start` and `end` values and + /// allocating the [`PageTablePart`]s required for the mappings. + /// + /// # Arguments + /// + /// * `lazy` - When `true`, use lazy allocation of [`PageTablePart`] pages. 
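+    ///   When `false`, the backing pages for every [`PageTablePart`] are
+    ///   allocated immediately.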
+ /// + /// # Returns + /// + /// `Ok(())` on success, Err(SvsmError::Mem) on allocation error + fn initialize_common(&self, lazy: bool) -> Result<(), SvsmError> { + let start = VirtAddr::from(self.start_pfn << PAGE_SHIFT); + let end = VirtAddr::from(self.end_pfn << PAGE_SHIFT); + assert!(start < end && start.is_aligned(VMR_GRANULE) && end.is_aligned(VMR_GRANULE)); + + self.alloc_page_tables(lazy) + } + + /// Initialize this [`VMR`] by calling `VMR::initialize_common` with `lazy = false` + /// + /// # Returns + /// + /// `Ok(())` on success, Err(SvsmError::Mem) on allocation error + pub fn initialize(&self) -> Result<(), SvsmError> { + self.initialize_common(false) + } + + /// Initialize this [`VMR`] by calling `VMR::initialize_common` with `lazy = true` + /// + /// # Returns + /// + /// `Ok(())` on success, Err(SvsmError::Mem) on allocation error + pub fn initialize_lazy(&self) -> Result<(), SvsmError> { + self.initialize_common(true) + } + + /// Returns the virtual start and end addresses for this region + /// + /// # Returns + /// + /// Tuple containing `start` and `end` virtual address of the memory region + fn virt_range(&self) -> (VirtAddr, VirtAddr) { + ( + VirtAddr::from(self.start_pfn << PAGE_SHIFT), + VirtAddr::from(self.end_pfn << PAGE_SHIFT), + ) + } + + /// Map a [`VMM`] into the [`PageTablePart`]s of this region + /// + /// # Arguments + /// + /// - `vmm` - Reference to a [`VMM`] instance to map into the page-table + /// + /// # Returns + /// + /// `Ok(())` on success, Err(SvsmError::Mem) on allocation error + fn map_vmm(&self, vmm: &VMM) -> Result<(), SvsmError> { + let (rstart, _) = self.virt_range(); + let (vmm_start, vmm_end) = vmm.range(); + let mut pgtbl_parts = self.pgtbl_parts.lock_write(); + let mapping = vmm.get_mapping(); + let mut offset: usize = 0; + let page_size = mapping.page_size(); + let shared = mapping.shared(); + + // Exit early if the mapping has no data. + if !mapping.has_data() { + return Ok(()); + } + + while vmm_start + offset < vmm_end { + let idx = PageTable::index::<3>(VirtAddr::from(vmm_start - rstart)); + if let Some(paddr) = mapping.map(offset) { + let pt_flags = self.pt_flags | mapping.pt_flags(offset) | PTEntryFlags::PRESENT; + match page_size { + PageSize::Regular => { + pgtbl_parts[idx].map_4k(vmm_start + offset, paddr, pt_flags, shared)? + } + PageSize::Huge => { + pgtbl_parts[idx].map_2m(vmm_start + offset, paddr, pt_flags, shared)? 
+ } + } + } + offset += usize::from(page_size); + } + + Ok(()) + } + + /// Unmap a [`VMM`] from the [`PageTablePart`]s of this region + /// + /// # Arguments + /// + /// - `vmm` - Reference to a [`VMM`] instance to unmap from the page-table + fn unmap_vmm(&self, vmm: &VMM) { + let (rstart, _) = self.virt_range(); + let (vmm_start, vmm_end) = vmm.range(); + let mut pgtbl_parts = self.pgtbl_parts.lock_write(); + let mapping = vmm.get_mapping(); + let page_size = mapping.page_size(); + let mut offset: usize = 0; + + while vmm_start + offset < vmm_end { + let idx = PageTable::index::<3>(VirtAddr::from(vmm_start - rstart)); + let result = match page_size { + PageSize::Regular => pgtbl_parts[idx].unmap_4k(vmm_start + offset), + PageSize::Huge => pgtbl_parts[idx].unmap_2m(vmm_start + offset), + }; + + if result.is_some() { + mapping.unmap(offset); + } + + offset += usize::from(page_size); + } + } + + fn do_insert( + &self, + mapping: Arc, + start_pfn: usize, + cursor: &mut CursorMut<'_, VMMAdapter>, + ) -> Result<(), SvsmError> { + let vmm = Box::new(VMM::new(start_pfn, mapping)); + if let Err(e) = self.map_vmm(&vmm) { + self.unmap_vmm(&vmm); + Err(e) + } else { + cursor.insert_before(vmm); + Ok(()) + } + } + + /// Inserts [`VMM`] at a specified virtual base address. This method + /// checks that the [`VMM`] does not overlap with any other region. + /// + /// # Arguments + /// + /// * `vaddr` - Virtual base address to map the [`VMM`] at + /// * `mapping` - `Rc` pointer to the VMM to insert + /// + /// # Returns + /// + /// Base address where the [`VMM`] was inserted on success or SvsmError::Mem on error + pub fn insert_at(&self, vaddr: VirtAddr, mapping: Arc) -> Result { + // mapping-size needs to be page-aligned + let size = mapping.get().mapping_size() >> PAGE_SHIFT; + let start_pfn = vaddr.pfn(); + let mut tree = self.tree.lock_write(); + let mut cursor = tree.upper_bound_mut(Bound::Included(&start_pfn)); + let mut start = self.start_pfn; + let mut end = self.end_pfn; + + if cursor.is_null() { + cursor = tree.front_mut(); + } else { + let (_, node_end) = cursor.get().unwrap().range_pfn(); + start = node_end; + cursor.move_next(); + } + + if let Some(node) = cursor.get() { + let (node_start, _) = node.range_pfn(); + end = node_start; + } + + let end_pfn = start_pfn + size; + + if start_pfn >= start && end_pfn <= end { + self.do_insert(mapping, start_pfn, &mut cursor)?; + Ok(vaddr) + } else { + Err(SvsmError::Mem) + } + } + + /// Inserts [`VMM`] with the specified alignment. This method walks the + /// RBTree to search for a suitable region. 
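+    /// The search starts at the aligned `hint` address and walks the gaps
+    /// between existing mappings until a hole large enough for the mapping
+    /// is found.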
+ /// + /// # Arguments + /// + /// * `mapping` - `Rc` pointer to the VMM to insert + /// * `align` - Alignment to use for tha mapping + /// + /// # Returns + /// + /// Base address where the [`VMM`] was inserted on success or SvsmError::Mem on error + pub fn insert_aligned( + &self, + hint: VirtAddr, + mapping: Arc, + align: usize, + ) -> Result { + assert!(align.is_power_of_two()); + + let size = mapping + .get() + .mapping_size() + .checked_next_power_of_two() + .unwrap_or(0) + >> PAGE_SHIFT; + let align = align >> PAGE_SHIFT; + + let start_pfn = max(self.start_pfn, hint.pfn()); + + let mut start = align_up(start_pfn, align); + let mut end = start; + + if size == 0 || start_pfn >= self.end_pfn { + return Err(SvsmError::Mem); + } + + let mut tree = self.tree.lock_write(); + let mut cursor = tree.upper_bound_mut(Bound::Included(&start_pfn)); + if cursor.is_null() { + cursor = tree.front_mut(); + } + + while let Some(node) = cursor.get() { + let (node_start, node_end) = node.range_pfn(); + end = node_start; + if end > start && end - start >= size { + break; + } + + start = max(start, align_up(node_end, align)); + cursor.move_next(); + } + + if cursor.is_null() { + end = align_down(self.end_pfn, align); + } + + if end > start && end - start >= size { + self.do_insert(mapping, start, &mut cursor)?; + Ok(VirtAddr::from(start << PAGE_SHIFT)) + } else { + Err(SvsmError::Mem) + } + } + + /// Inserts [`VMM`] into the virtual memory region. This method takes the + /// next power-of-two larger of the mapping size and uses that as the + /// alignment for the mappings base address. The search for the base + /// address starts at `addr`. With that it calls [`VMR::insert_aligned`]. + /// + /// # Arguments + /// + /// * `addr` - The virtual address at which the search for a mapping area + /// starts + /// * `mapping` - `Arc` pointer to the VMM to insert + /// + /// # Returns + /// + /// Base address where the [`VMM`] was inserted on success or SvsmError::Mem on error + pub fn insert_hint( + &self, + addr: VirtAddr, + mapping: Arc, + ) -> Result { + let align = mapping.get().mapping_size().next_power_of_two(); + self.insert_aligned(addr, mapping, align) + } + + /// Inserts [`VMM`] into the virtual memory region. It searches from the + /// beginning of the [`VMR`] region for a suitable slot. + /// + /// # Arguments + /// + /// * `mapping` - `Rc` pointer to the VMM to insert + /// + /// # Returns + /// + /// Base address where the [`VMM`] was inserted on success or SvsmError::Mem on error + pub fn insert(&self, mapping: Arc) -> Result { + self.insert_hint(VirtAddr::new(0), mapping) + } + + /// Removes the mapping from a given base address from the RBTree + /// + /// # Arguments + /// + /// * `base` - Virtual base address of the [`VMM`] to remove + /// + /// # Returns + /// + /// The removed mapping on success, SvsmError::Mem on error + pub fn remove(&self, base: VirtAddr) -> Result, SvsmError> { + let mut tree = self.tree.lock_write(); + let addr = base.pfn(); + + let mut cursor = tree.find_mut(&addr); + if let Some(node) = cursor.get() { + self.unmap_vmm(node); + if self.per_cpu { + flush_tlb_global_percpu(); + } else { + flush_tlb_global_sync(); + } + } + cursor.remove().ok_or(SvsmError::Mem) + } + + /// Dump all [`VMM`] mappings in the RBTree. This function is included for + /// debugging purposes. And should not be called in production code. 
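+    ///
+    /// A usage sketch (illustrative only; it assumes a suitably aligned
+    /// address range and an `Arc<Mapping>` created elsewhere):
+    ///
+    /// ```ignore
+    /// let vmr = VMR::new(start, end, PTEntryFlags::empty());
+    /// vmr.initialize()?;
+    /// let va = vmr.insert(mapping)?;
+    /// vmr.dump_ranges();
+    /// ```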
+ pub fn dump_ranges(&self) { + let tree = self.tree.lock_read(); + for elem in tree.iter() { + let (start_pfn, end_pfn) = elem.range_pfn(); + log::info!( + "VMRange {:#018x}-{:#018x}", + start_pfn << PAGE_SHIFT, + end_pfn << PAGE_SHIFT + ); + } + } + + /// Notify the range that a page fault has occurred. This should be called from + /// the page fault handler. The mappings withing this virtual memory region are + /// examined and if they overlap with the page fault address then + /// [`VMR::handle_page_fault()`] is called to handle the page fault within that + /// range. + /// + /// # Arguments + /// + /// * `vaddr` - Virtual memory address that was the subject of the page fault + /// + /// * 'write' - 'true' if a write was attempted. 'false' if a read was attempted. + /// + /// # Returns + /// + /// '()' if the page fault was successfully handled. + /// + /// 'SvsmError::Mem' if the page fault should propogate to the next handler. + pub fn handle_page_fault(&self, vaddr: VirtAddr, _write: bool) -> Result<(), SvsmError> { + // Get the mapping that contains the faulting address and check if the + // fault happened on a mapped part of the range. + + let tree = self.tree.lock_read(); + let pfn = vaddr.pfn(); + let cursor = tree.upper_bound(Bound::Included(&pfn)); + let node = cursor.get().ok_or(SvsmError::Mem)?; + let (start, end) = node.range(); + if vaddr < start || vaddr >= end { + return Err(SvsmError::Mem); + } + + Ok(()) + } +} + +#[derive(Debug)] +pub struct VMRMapping<'a> { + vmr: &'a VMR, + va: VirtAddr, +} + +impl<'a> VMRMapping<'a> { + pub fn new(vmr: &'a VMR, mapping: Arc) -> Result { + let va = vmr.insert(mapping)?; + Ok(Self { vmr, va }) + } + + pub fn virt_addr(&self) -> VirtAddr { + self.va + } +} + +impl Drop for VMRMapping<'_> { + fn drop(&mut self) { + self.vmr + .remove(self.va) + .expect("Error removing VRMapping virtual memory range"); + } +} diff --git a/stage2/src/platform/guest_cpu.rs b/stage2/src/platform/guest_cpu.rs new file mode 100644 index 000000000..b2a43bca7 --- /dev/null +++ b/stage2/src/platform/guest_cpu.rs @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange (jlange@microsoft.com) + +pub trait GuestCpuState { + fn get_tpr(&self) -> u8; + fn set_tpr(&mut self, tpr: u8); + fn request_nmi(&mut self); + fn queue_interrupt(&mut self, irq: u8); + fn try_deliver_interrupt_immediately(&mut self, irq: u8) -> bool; + fn in_intr_shadow(&self) -> bool; + fn interrupts_enabled(&self) -> bool; + fn check_and_clear_pending_nmi(&mut self) -> bool; + fn check_and_clear_pending_interrupt_event(&mut self) -> u8; + fn check_and_clear_pending_virtual_interrupt(&mut self) -> u8; + fn disable_alternate_injection(&mut self); +} diff --git a/stage2/src/platform/mod.rs b/stage2/src/platform/mod.rs new file mode 100644 index 000000000..ae32f9a38 --- /dev/null +++ b/stage2/src/platform/mod.rs @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange + +use core::ops::{Deref, DerefMut}; + +use crate::address::{PhysAddr, VirtAddr}; +use crate::cpu::cpuid::CpuidResult; +use crate::cpu::percpu::PerCpu; +use crate::error::SvsmError; +use crate::io::IOPort; +use crate::platform::native::NativePlatform; +use crate::platform::snp::SnpPlatform; +use crate::platform::tdp::TdpPlatform; +use crate::types::PageSize; +use crate::utils; +use crate::utils::immut_after_init::ImmutAfterInitCell; +use crate::utils::MemoryRegion; + +use 
bootlib::platform::SvsmPlatformType; + +pub mod guest_cpu; +pub mod native; +pub mod snp; +pub mod tdp; + +static SVSM_PLATFORM_TYPE: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); +pub static SVSM_PLATFORM: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +#[derive(Clone, Copy, Debug)] +pub struct PageEncryptionMasks { + pub private_pte_mask: usize, + pub shared_pte_mask: usize, + pub addr_mask_width: u32, + pub phys_addr_sizes: u32, +} + +#[derive(Debug, Clone, Copy)] +pub enum PageStateChangeOp { + Private, + Shared, + Psmash, + Unsmash, +} + +#[derive(Debug, Clone, Copy)] +pub enum PageValidateOp { + Validate, + Invalidate, +} + +/// This defines a platform abstraction to permit the SVSM to run on different +/// underlying architectures. +pub trait SvsmPlatform { + /// Halts the system as required by the platform. + fn halt() + where + Self: Sized, + { + utils::halt(); + } + + /// Performs basic early initialization of the runtime environment. + fn env_setup(&mut self, debug_serial_port: u16, vtom: usize) -> Result<(), SvsmError>; + + /// Performs initialization of the platform runtime environment after + /// the core system environment has been initialized. + fn env_setup_late(&mut self, debug_serial_port: u16) -> Result<(), SvsmError>; + + /// Performs initialiation of the environment specfic to the SVSM kernel + /// (for services not used by stage2). + fn env_setup_svsm(&self) -> Result<(), SvsmError>; + + /// Completes initialization of a per-CPU object during construction. + fn setup_percpu(&self, cpu: &PerCpu) -> Result<(), SvsmError>; + + /// Completes initialization of a per-CPU object on the target CPU. + fn setup_percpu_current(&self, cpu: &PerCpu) -> Result<(), SvsmError>; + + /// Determines the paging encryption masks for the current architecture. + fn get_page_encryption_masks(&self) -> PageEncryptionMasks; + + /// Obtain CPUID using platform-specific tables. + fn cpuid(&self, eax: u32) -> Option; + + /// Establishes state required for guest/host communication. + fn setup_guest_host_comm(&mut self, cpu: &PerCpu, is_bsp: bool); + + /// Obtains a reference to an I/O port implemetation appropriate to the + /// platform. + fn get_io_port(&self) -> &'static dyn IOPort; + + /// Performs a page state change between private and shared states. + fn page_state_change( + &self, + region: MemoryRegion, + size: PageSize, + op: PageStateChangeOp, + ) -> Result<(), SvsmError>; + + /// Marks a physical range of pages as valid or invalid for use as private + /// pages. Not usable in stage2. + fn validate_physical_page_range( + &self, + region: MemoryRegion, + op: PageValidateOp, + ) -> Result<(), SvsmError>; + + /// Marks a virtual range of pages as valid or invalid for use as private + /// pages. Provided primarily for use in stage2 where validation by + /// physical address cannot be supported. + fn validate_virtual_page_range( + &self, + region: MemoryRegion, + op: PageValidateOp, + ) -> Result<(), SvsmError>; + + /// Configures the use of alternate injection as requested. + fn configure_alternate_injection(&mut self, alt_inj_requested: bool) -> Result<(), SvsmError>; + + /// Changes the state of APIC registration on this system, returning either + /// the current registration state or an error. + fn change_apic_registration_state(&self, incr: bool) -> Result; + + /// Queries the state of APIC registration on this system. + fn query_apic_registration_state(&self) -> bool; + + /// Determines whether the platform supports interrupts to the SVSM. 
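+    ///
+    /// A sketch of how callers might consult this flag (illustrative only;
+    /// the branch bodies stand in for real setup code elsewhere):
+    ///
+    /// ```ignore
+    /// if SVSM_PLATFORM.use_interrupts() {
+    ///     // take the interrupt-driven path
+    /// } else {
+    ///     // fall back to polling
+    /// }
+    /// ```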
+ fn use_interrupts(&self) -> bool; + + /// Signal an IRQ on one or more CPUs. + fn post_irq(&self, icr: u64) -> Result<(), SvsmError>; + + /// Perform an EOI of the current interrupt. + fn eoi(&self); + + /// Determines whether a given interrupt vector was invoked as an external + /// interrupt. + fn is_external_interrupt(&self, vector: usize) -> bool; + + /// Start an additional processor. + fn start_cpu(&self, cpu: &PerCpu, start_rip: u64) -> Result<(), SvsmError>; +} + +//FIXME - remove Copy trait +#[derive(Clone, Copy, Debug)] +pub enum SvsmPlatformCell { + Snp(SnpPlatform), + Tdp(TdpPlatform), + Native(NativePlatform), +} + +impl SvsmPlatformCell { + pub fn new(platform_type: SvsmPlatformType) -> Self { + assert_eq!(platform_type, *SVSM_PLATFORM_TYPE); + match platform_type { + SvsmPlatformType::Native => SvsmPlatformCell::Native(NativePlatform::new()), + SvsmPlatformType::Snp => SvsmPlatformCell::Snp(SnpPlatform::new()), + SvsmPlatformType::Tdp => SvsmPlatformCell::Tdp(TdpPlatform::new()), + } + } +} + +impl Deref for SvsmPlatformCell { + type Target = dyn SvsmPlatform; + + fn deref(&self) -> &Self::Target { + match self { + SvsmPlatformCell::Native(platform) => platform, + SvsmPlatformCell::Snp(platform) => platform, + SvsmPlatformCell::Tdp(platform) => platform, + } + } +} + +impl DerefMut for SvsmPlatformCell { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + SvsmPlatformCell::Native(platform) => platform, + SvsmPlatformCell::Snp(platform) => platform, + SvsmPlatformCell::Tdp(platform) => platform, + } + } +} + +pub fn init_platform_type(platform_type: SvsmPlatformType) { + SVSM_PLATFORM_TYPE.init(&platform_type).unwrap(); +} + +pub fn halt() { + // Use a platform-specific halt. However, the SVSM_PLATFORM global may not + // yet be initialized, so go choose the halt implementation based on the + // platform-specific halt instead. + match *SVSM_PLATFORM_TYPE { + SvsmPlatformType::Native => NativePlatform::halt(), + SvsmPlatformType::Snp => SnpPlatform::halt(), + SvsmPlatformType::Tdp => TdpPlatform::halt(), + } +} diff --git a/stage2/src/platform/native.rs b/stage2/src/platform/native.rs new file mode 100644 index 000000000..17f3d4421 --- /dev/null +++ b/stage2/src/platform/native.rs @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange + +use crate::address::{PhysAddr, VirtAddr}; +use crate::console::init_svsm_console; +use crate::cpu::cpuid::CpuidResult; +use crate::cpu::msr::write_msr; +use crate::cpu::percpu::PerCpu; +use crate::error::SvsmError; +use crate::io::{IOPort, DEFAULT_IO_DRIVER}; +use crate::platform::{PageEncryptionMasks, PageStateChangeOp, PageValidateOp, SvsmPlatform}; +use crate::types::PageSize; +use crate::utils::MemoryRegion; + +#[cfg(debug_assertions)] +use crate::mm::virt_to_phys; + +const APIC_MSR_ICR: u32 = 0x830; + +#[derive(Clone, Copy, Debug)] +pub struct NativePlatform {} + +impl NativePlatform { + pub fn new() -> Self { + Self {} + } +} + +impl Default for NativePlatform { + fn default() -> Self { + Self::new() + } +} + +impl SvsmPlatform for NativePlatform { + fn env_setup(&mut self, debug_serial_port: u16, _vtom: usize) -> Result<(), SvsmError> { + // In the native platform, console output does not require the use of + // any platform services, so it can be initialized immediately. 
+ init_svsm_console(&DEFAULT_IO_DRIVER, debug_serial_port) + } + + fn env_setup_late(&mut self, _debug_serial_port: u16) -> Result<(), SvsmError> { + Ok(()) + } + + fn env_setup_svsm(&self) -> Result<(), SvsmError> { + Ok(()) + } + + fn setup_percpu(&self, _cpu: &PerCpu) -> Result<(), SvsmError> { + Ok(()) + } + + fn setup_percpu_current(&self, _cpu: &PerCpu) -> Result<(), SvsmError> { + Ok(()) + } + + fn get_page_encryption_masks(&self) -> PageEncryptionMasks { + // Find physical address size. + let res = CpuidResult::get(0x80000008, 0); + PageEncryptionMasks { + private_pte_mask: 0, + shared_pte_mask: 0, + addr_mask_width: 64, + phys_addr_sizes: res.eax, + } + } + + fn cpuid(&self, eax: u32) -> Option { + Some(CpuidResult::get(eax, 0)) + } + + fn setup_guest_host_comm(&mut self, _cpu: &PerCpu, _is_bsp: bool) {} + + fn get_io_port(&self) -> &'static dyn IOPort { + &DEFAULT_IO_DRIVER + } + + fn page_state_change( + &self, + _region: MemoryRegion, + _size: PageSize, + _op: PageStateChangeOp, + ) -> Result<(), SvsmError> { + Ok(()) + } + + fn validate_physical_page_range( + &self, + _region: MemoryRegion, + _op: PageValidateOp, + ) -> Result<(), SvsmError> { + Ok(()) + } + + fn validate_virtual_page_range( + &self, + _region: MemoryRegion, + _op: PageValidateOp, + ) -> Result<(), SvsmError> { + #[cfg(debug_assertions)] + { + // Ensure that it is possible to translate this virtual address to + // a physical address. This is not necessary for correctness + // here, but since other platformss may rely on virtual-to-physical + // translation, it is helpful to force a translation here for + // debugging purposes just to help catch potential errors when + // testing on native. + for va in _region.iter_pages(PageSize::Regular) { + let _ = virt_to_phys(va); + } + } + Ok(()) + } + + fn configure_alternate_injection(&mut self, _alt_inj_requested: bool) -> Result<(), SvsmError> { + Ok(()) + } + + fn change_apic_registration_state(&self, _incr: bool) -> Result { + Err(SvsmError::NotSupported) + } + + fn query_apic_registration_state(&self) -> bool { + false + } + + fn use_interrupts(&self) -> bool { + true + } + + fn post_irq(&self, icr: u64) -> Result<(), SvsmError> { + write_msr(APIC_MSR_ICR, icr); + Ok(()) + } + + fn eoi(&self) { + todo!(); + } + + fn is_external_interrupt(&self, _vector: usize) -> bool { + // For a native platform, the hypervisor is fully trusted with all + // event delivery, so all events are assumed not to be external + // interrupts. 
+ false + } + + fn start_cpu(&self, _cpu: &PerCpu, _start_rip: u64) -> Result<(), SvsmError> { + todo!(); + } +} diff --git a/stage2/src/platform/snp.rs b/stage2/src/platform/snp.rs new file mode 100644 index 000000000..79863690c --- /dev/null +++ b/stage2/src/platform/snp.rs @@ -0,0 +1,370 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange + +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::console::init_svsm_console; +use crate::cpu::cpuid::{cpuid_table, CpuidResult}; +use crate::cpu::percpu::{current_ghcb, this_cpu, PerCpu}; +use crate::error::ApicError::Registration; +use crate::error::SvsmError; +use crate::greq::driver::guest_request_driver_init; +use crate::io::IOPort; +use crate::mm::{PerCPUPageMappingGuard, PAGE_SIZE, PAGE_SIZE_2M}; +use crate::platform::{PageEncryptionMasks, PageStateChangeOp, PageValidateOp, SvsmPlatform}; +use crate::sev::ghcb::GHCBIOSize; +use crate::sev::hv_doorbell::current_hv_doorbell; +use crate::sev::msr_protocol::{ + hypervisor_ghcb_features, request_termination_msr, verify_ghcb_version, GHCBHvFeatures, +}; +use crate::sev::status::{sev_restricted_injection, vtom_enabled}; +use crate::sev::{ + init_hypervisor_ghcb_features, pvalidate_range, sev_status_init, sev_status_verify, PvalidateOp, +}; +use crate::types::PageSize; +use crate::utils::immut_after_init::ImmutAfterInitCell; +use crate::utils::MemoryRegion; + +#[cfg(debug_assertions)] +use crate::mm::virt_to_phys; + +use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; + +static SVSM_ENV_INITIALIZED: AtomicBool = AtomicBool::new(false); + +static GHCB_IO_DRIVER: GHCBIOPort = GHCBIOPort::new(); + +static VTOM: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +static APIC_EMULATION_REG_COUNT: AtomicU32 = AtomicU32::new(0); + +fn pvalidate_page_range(range: MemoryRegion, op: PvalidateOp) -> Result<(), SvsmError> { + // In the future, it is likely that this function will need to be prepared + // to execute both PVALIDATE and RMPADJUST over the same set of addresses, + // so the loop is structured to anticipate that possibility. + let mut paddr = range.start(); + let paddr_end = range.end(); + while paddr < paddr_end { + // Check whether a 2 MB page can be attempted. + let len = if paddr.is_aligned(PAGE_SIZE_2M) && paddr + PAGE_SIZE_2M <= paddr_end { + PAGE_SIZE_2M + } else { + PAGE_SIZE + }; + let mapping = PerCPUPageMappingGuard::create(paddr, paddr + len, 0)?; + pvalidate_range(MemoryRegion::new(mapping.virt_addr(), len), op)?; + paddr = paddr + len; + } + + Ok(()) +} + +impl From for PvalidateOp { + fn from(op: PageValidateOp) -> PvalidateOp { + match op { + PageValidateOp::Validate => PvalidateOp::Valid, + PageValidateOp::Invalidate => PvalidateOp::Invalid, + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct SnpPlatform { + can_use_interrupts: bool, +} + +impl SnpPlatform { + pub fn new() -> Self { + Self { + can_use_interrupts: false, + } + } +} + +impl Default for SnpPlatform { + fn default() -> Self { + Self::new() + } +} + +impl SvsmPlatform for SnpPlatform { + fn env_setup(&mut self, _debug_serial_port: u16, vtom: usize) -> Result<(), SvsmError> { + sev_status_init(); + VTOM.init(&vtom).map_err(|_| SvsmError::PlatformInit)?; + + // Now that SEV status is initialized, determine whether this platform + // supports the use of SVSM interrupts. SVSM interrupts are supported + // if this system uses restricted injection. 
+ if sev_restricted_injection() { + self.can_use_interrupts = true; + } + + Ok(()) + } + + fn env_setup_late(&mut self, debug_serial_port: u16) -> Result<(), SvsmError> { + init_svsm_console(&GHCB_IO_DRIVER, debug_serial_port)?; + sev_status_verify(); + init_hypervisor_ghcb_features()?; + Ok(()) + } + + fn env_setup_svsm(&self) -> Result<(), SvsmError> { + this_cpu().configure_hv_doorbell()?; + guest_request_driver_init(); + SVSM_ENV_INITIALIZED.store(true, Ordering::Relaxed); + Ok(()) + } + + fn setup_percpu(&self, cpu: &PerCpu) -> Result<(), SvsmError> { + // Setup GHCB + cpu.setup_ghcb() + } + + fn setup_percpu_current(&self, cpu: &PerCpu) -> Result<(), SvsmError> { + cpu.register_ghcb()?; + + // #HV doorbell allocation can only occur if the SVSM environment has + // already been initialized. Skip allocation if not; it will be done + // during environment initialization. + if SVSM_ENV_INITIALIZED.load(Ordering::Relaxed) { + cpu.configure_hv_doorbell()?; + } + + Ok(()) + } + + fn get_page_encryption_masks(&self) -> PageEncryptionMasks { + // Find physical address size. + let processor_capacity = + cpuid_table(0x80000008).expect("Can not get physical address size from CPUID table"); + if vtom_enabled() { + let vtom = *VTOM; + PageEncryptionMasks { + private_pte_mask: 0, + shared_pte_mask: vtom, + addr_mask_width: vtom.leading_zeros(), + phys_addr_sizes: processor_capacity.eax, + } + } else { + // Find C-bit position. + let sev_capabilities = + cpuid_table(0x8000001f).expect("Can not get C-Bit position from CPUID table"); + let c_bit = sev_capabilities.ebx & 0x3f; + PageEncryptionMasks { + private_pte_mask: 1 << c_bit, + shared_pte_mask: 0, + addr_mask_width: c_bit, + phys_addr_sizes: processor_capacity.eax, + } + } + } + + fn cpuid(&self, eax: u32) -> Option { + cpuid_table(eax) + } + + fn setup_guest_host_comm(&mut self, cpu: &PerCpu, is_bsp: bool) { + if is_bsp { + verify_ghcb_version(); + } + + cpu.setup_ghcb().unwrap_or_else(|_| { + if is_bsp { + panic!("Failed to setup BSP GHCB"); + } else { + panic!("Failed to setup AP GHCB"); + } + }); + cpu.register_ghcb().expect("Failed to register GHCB"); + } + + fn get_io_port(&self) -> &'static dyn IOPort { + &GHCB_IO_DRIVER + } + + fn page_state_change( + &self, + region: MemoryRegion, + size: PageSize, + op: PageStateChangeOp, + ) -> Result<(), SvsmError> { + current_ghcb().page_state_change(region, size, op) + } + + fn validate_physical_page_range( + &self, + region: MemoryRegion, + op: PageValidateOp, + ) -> Result<(), SvsmError> { + pvalidate_page_range(region, PvalidateOp::from(op)) + } + + fn validate_virtual_page_range( + &self, + region: MemoryRegion, + op: PageValidateOp, + ) -> Result<(), SvsmError> { + #[cfg(debug_assertions)] + { + // Ensure that it is possible to translate this virtual address to + // a physical address. This is not necessary for correctness + // here, but since other platformss may rely on virtual-to-physical + // translation, it is helpful to force a translation here for + // debugging purposes just to help catch potential errors when + // testing on SNP. + for va in region.iter_pages(PageSize::Regular) { + let _ = virt_to_phys(va); + } + } + pvalidate_range(region, PvalidateOp::from(op)) + } + + fn configure_alternate_injection(&mut self, alt_inj_requested: bool) -> Result<(), SvsmError> { + if !alt_inj_requested { + return Ok(()); + } + + // If alternate injection was requested, then it must be supported by + // the hypervisor. 
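+        // The capability is reported through the GHCB MSR protocol as the
+        // GHCBHvFeatures::SEV_SNP_EXT_INTERRUPTS feature bit checked below.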
+ if !hypervisor_ghcb_features().contains(GHCBHvFeatures::SEV_SNP_EXT_INTERRUPTS) { + return Err(SvsmError::NotSupported); + } + + APIC_EMULATION_REG_COUNT.store(1, Ordering::Relaxed); + Ok(()) + } + + fn change_apic_registration_state(&self, incr: bool) -> Result { + let mut current = APIC_EMULATION_REG_COUNT.load(Ordering::Relaxed); + loop { + let new = if incr { + // Incrementing is only possible if the registration count + // has not already dropped to zero, and only if the + // registration count will not wrap around. + if current == 0 { + return Err(SvsmError::Apic(Registration)); + } + current + .checked_add(1) + .ok_or(SvsmError::Apic(Registration))? + } else { + // An attempt to decrement when the count is already zero is + // considered a benign race, which will not result in any + // actual change but will indicate that emulation is being + // disabled for the guest. + if current == 0 { + return Ok(false); + } + current - 1 + }; + match APIC_EMULATION_REG_COUNT.compare_exchange_weak( + current, + new, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => { + return Ok(new > 0); + } + Err(val) => current = val, + } + } + } + + fn query_apic_registration_state(&self) -> bool { + APIC_EMULATION_REG_COUNT.load(Ordering::Relaxed) > 0 + } + + fn use_interrupts(&self) -> bool { + self.can_use_interrupts + } + + fn post_irq(&self, icr: u64) -> Result<(), SvsmError> { + current_ghcb().hv_ipi(icr)?; + Ok(()) + } + + fn eoi(&self) { + // Issue an explicit EOI unless no explicit EOI is required. + if !current_hv_doorbell().no_eoi_required() { + // 0x80B is the X2APIC EOI MSR. + // Errors here cannot be handled but should not be grounds for + // panic. + let _ = current_ghcb().wrmsr(0x80B, 0); + } + } + + fn is_external_interrupt(&self, _vector: usize) -> bool { + // When restricted injection is active, the event disposition is + // already known to the caller and thus need not be examined. When + // restricted injection is not active, the hypervisor must be trusted + // with all event delivery, so all events are assumed not to be + // external interrupts. 
+ false + } + + fn start_cpu(&self, cpu: &PerCpu, start_rip: u64) -> Result<(), SvsmError> { + let (vmsa_pa, sev_features) = cpu.alloc_svsm_vmsa(*VTOM as u64, start_rip)?; + + current_ghcb().ap_create(vmsa_pa, cpu.get_apic_id().into(), 0, sev_features) + } +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct GHCBIOPort {} + +impl GHCBIOPort { + pub const fn new() -> Self { + GHCBIOPort {} + } +} + +impl IOPort for GHCBIOPort { + fn outb(&self, port: u16, value: u8) { + let ret = current_ghcb().ioio_out(port, GHCBIOSize::Size8, value as u64); + if ret.is_err() { + request_termination_msr(); + } + } + + fn inb(&self, port: u16) -> u8 { + let ret = current_ghcb().ioio_in(port, GHCBIOSize::Size8); + match ret { + Ok(v) => (v & 0xff) as u8, + Err(_e) => request_termination_msr(), + } + } + + fn outw(&self, port: u16, value: u16) { + let ret = current_ghcb().ioio_out(port, GHCBIOSize::Size16, value as u64); + if ret.is_err() { + request_termination_msr(); + } + } + + fn inw(&self, port: u16) -> u16 { + let ret = current_ghcb().ioio_in(port, GHCBIOSize::Size16); + match ret { + Ok(v) => (v & 0xffff) as u16, + Err(_e) => request_termination_msr(), + } + } + + fn outl(&self, port: u16, value: u32) { + let ret = current_ghcb().ioio_out(port, GHCBIOSize::Size32, value as u64); + if ret.is_err() { + request_termination_msr(); + } + } + + fn inl(&self, port: u16) -> u32 { + let ret = current_ghcb().ioio_in(port, GHCBIOSize::Size32); + match ret { + Ok(v) => (v & 0xffffffff) as u32, + Err(_e) => request_termination_msr(), + } + } +} diff --git a/stage2/src/platform/tdp.rs b/stage2/src/platform/tdp.rs new file mode 100644 index 000000000..819f69d2f --- /dev/null +++ b/stage2/src/platform/tdp.rs @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (C) 2024 Intel Corporation +// +// Author: Peter Fang + +use crate::address::{PhysAddr, VirtAddr}; +use crate::console::init_svsm_console; +use crate::cpu::cpuid::CpuidResult; +use crate::cpu::percpu::PerCpu; +use crate::error::SvsmError; +use crate::io::IOPort; +use crate::mm::{virt_to_frame, PerCPUPageMappingGuard}; +use crate::platform::{PageEncryptionMasks, PageStateChangeOp, PageValidateOp, SvsmPlatform}; +use crate::types::PageSize; +use crate::utils::immut_after_init::ImmutAfterInitCell; +use crate::utils::{zero_mem_region, MemoryRegion}; +use tdx_tdcall::tdx::{ + td_accept_memory, tdvmcall_halt, tdvmcall_io_read_16, tdvmcall_io_read_32, tdvmcall_io_read_8, + tdvmcall_io_write_16, tdvmcall_io_write_32, tdvmcall_io_write_8, +}; + +static GHCI_IO_DRIVER: GHCIIOPort = GHCIIOPort::new(); +static VTOM: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +#[derive(Clone, Copy, Debug)] +pub struct TdpPlatform {} + +impl TdpPlatform { + pub fn new() -> Self { + Self {} + } +} + +impl Default for TdpPlatform { + fn default() -> Self { + Self::new() + } +} + +impl SvsmPlatform for TdpPlatform { + fn halt() { + tdvmcall_halt(); + } + + fn env_setup(&mut self, debug_serial_port: u16, vtom: usize) -> Result<(), SvsmError> { + VTOM.init(&vtom).map_err(|_| SvsmError::PlatformInit)?; + // Serial console device can be initialized immediately + init_svsm_console(&GHCI_IO_DRIVER, debug_serial_port) + } + + fn env_setup_late(&mut self, _debug_serial_port: u16) -> Result<(), SvsmError> { + Ok(()) + } + + fn env_setup_svsm(&self) -> Result<(), SvsmError> { + Ok(()) + } + + fn setup_percpu(&self, _cpu: &PerCpu) -> Result<(), SvsmError> { + Err(SvsmError::Tdx) + } + + fn setup_percpu_current(&self, _cpu: &PerCpu) -> Result<(), SvsmError> { + 
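+        // Not supported on the TDP platform; report a TDX error.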
Err(SvsmError::Tdx) + } + + fn get_page_encryption_masks(&self) -> PageEncryptionMasks { + // Find physical address size. + let res = CpuidResult::get(0x80000008, 0); + let vtom = *VTOM; + PageEncryptionMasks { + private_pte_mask: 0, + shared_pte_mask: vtom, + addr_mask_width: vtom.trailing_zeros(), + phys_addr_sizes: res.eax, + } + } + + fn cpuid(&self, eax: u32) -> Option { + Some(CpuidResult::get(eax, 0)) + } + + fn setup_guest_host_comm(&mut self, _cpu: &PerCpu, _is_bsp: bool) {} + + fn get_io_port(&self) -> &'static dyn IOPort { + &GHCI_IO_DRIVER + } + + fn page_state_change( + &self, + _region: MemoryRegion, + _size: PageSize, + _op: PageStateChangeOp, + ) -> Result<(), SvsmError> { + Err(SvsmError::Tdx) + } + + fn validate_physical_page_range( + &self, + region: MemoryRegion, + op: PageValidateOp, + ) -> Result<(), SvsmError> { + match op { + PageValidateOp::Validate => { + td_accept_memory(region.start().into(), region.len().try_into().unwrap()); + } + PageValidateOp::Invalidate => { + let mapping = PerCPUPageMappingGuard::create(region.start(), region.end(), 0)?; + zero_mem_region(mapping.virt_addr(), mapping.virt_addr() + region.len()); + } + } + Ok(()) + } + + fn validate_virtual_page_range( + &self, + region: MemoryRegion, + op: PageValidateOp, + ) -> Result<(), SvsmError> { + match op { + PageValidateOp::Validate => { + let mut va = region.start(); + while va < region.end() { + let pa = virt_to_frame(va); + let sz = pa.end() - pa.address(); + // td_accept_memory() will take care of alignment + td_accept_memory(pa.address().into(), sz.try_into().unwrap()); + va = va + sz; + } + } + PageValidateOp::Invalidate => { + zero_mem_region(region.start(), region.end()); + } + } + Ok(()) + } + + fn configure_alternate_injection(&mut self, _alt_inj_requested: bool) -> Result<(), SvsmError> { + Err(SvsmError::Tdx) + } + + fn change_apic_registration_state(&self, _incr: bool) -> Result { + Err(SvsmError::NotSupported) + } + + fn query_apic_registration_state(&self) -> bool { + false + } + + fn use_interrupts(&self) -> bool { + true + } + + fn post_irq(&self, _icr: u64) -> Result<(), SvsmError> { + Err(SvsmError::Tdx) + } + + fn eoi(&self) {} + + fn is_external_interrupt(&self, _vector: usize) -> bool { + // Examine the APIC ISR to determine whether this interrupt vector is + // active. If so, it is assumed to be an external interrupt. + // TODO - add code to read the APIC ISR. 
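+        //
+        // One possible shape for that check, left as a comment because the
+        // APIC ISR accessor it assumes (a hypothetical `read_apic_isr(index)`)
+        // is not available here:
+        //
+        //     let index = _vector / 32;
+        //     let bit = _vector % 32;
+        //     (read_apic_isr(index) & (1 << bit)) != 0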
+ todo!(); + } + + fn start_cpu(&self, _cpu: &PerCpu, _start_rip: u64) -> Result<(), SvsmError> { + todo!(); + } +} + +#[derive(Clone, Copy, Debug, Default)] +struct GHCIIOPort {} + +impl GHCIIOPort { + pub const fn new() -> Self { + GHCIIOPort {} + } +} + +impl IOPort for GHCIIOPort { + fn outb(&self, port: u16, value: u8) { + tdvmcall_io_write_8(port, value); + } + + fn inb(&self, port: u16) -> u8 { + tdvmcall_io_read_8(port) + } + + fn outw(&self, port: u16, value: u16) { + tdvmcall_io_write_16(port, value); + } + + fn inw(&self, port: u16) -> u16 { + tdvmcall_io_read_16(port) + } + + fn outl(&self, port: u16, value: u32) { + tdvmcall_io_write_32(port, value); + } + + fn inl(&self, port: u16) -> u32 { + tdvmcall_io_read_32(port) + } +} diff --git a/stage2/src/protocols/apic.rs b/stage2/src/protocols/apic.rs new file mode 100644 index 000000000..ce44d3a32 --- /dev/null +++ b/stage2/src/protocols/apic.rs @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) Microsoft Corporation +// +// Author: Jon Lange (jlange@microsoft.com) + +use crate::cpu::percpu::this_cpu; +use crate::platform::SVSM_PLATFORM; +use crate::protocols::errors::SvsmReqError; +use crate::protocols::RequestParams; + +const SVSM_REQ_APIC_QUERY_FEATURES: u32 = 0; +const SVSM_REQ_APIC_CONFIGURE: u32 = 1; +const SVSM_REQ_APIC_READ_REGISTER: u32 = 2; +const SVSM_REQ_APIC_WRITE_REGISTER: u32 = 3; +const SVSM_REQ_APIC_CONFIGURE_VECTOR: u32 = 4; + +pub const APIC_PROTOCOL: u32 = 3; +pub const APIC_PROTOCOL_VERSION_MIN: u32 = 1; +pub const APIC_PROTOCOL_VERSION_MAX: u32 = 1; + +fn apic_query_features(params: &mut RequestParams) -> Result<(), SvsmReqError> { + // No features are supported beyond the base feature set. + params.rcx = 0; + Ok(()) +} + +fn apic_configure(params: &RequestParams) -> Result<(), SvsmReqError> { + let enabled = match params.rcx { + 0b00 => { + // Query the current registration state of APIC emulation to + // determine whether it should be disabled on the current CPU. + SVSM_PLATFORM.query_apic_registration_state() + } + + 0b01 => { + // Deregister APIC emulation if possible, noting whether it is now + // disabled for the platform. This cannot fail. + SVSM_PLATFORM.change_apic_registration_state(false).unwrap() + } + + 0b10 => { + // Increment the APIC emulation registration count. If successful, + // this will not cause any change to the state of the current CPU. + SVSM_PLATFORM.change_apic_registration_state(true)?; + return Ok(()); + } + + _ => { + return Err(SvsmReqError::invalid_parameter()); + } + }; + + // Disable APIC emulation on the current CPU if required. 
+ if !enabled { + this_cpu().disable_apic_emulation(); + } + + Ok(()) +} + +fn apic_read_register(params: &mut RequestParams) -> Result<(), SvsmReqError> { + let cpu = this_cpu(); + let value = cpu.read_apic_register(params.rcx)?; + params.rdx = value; + Ok(()) +} + +fn apic_write_register(params: &RequestParams) -> Result<(), SvsmReqError> { + let cpu = this_cpu(); + cpu.write_apic_register(params.rcx, params.rdx)?; + Ok(()) +} + +fn apic_configure_vector(params: &RequestParams) -> Result<(), SvsmReqError> { + let cpu = this_cpu(); + if !cpu.use_apic_emulation() { + return Err(SvsmReqError::invalid_request()); + } + if params.rcx <= 0x1FF { + let vector: u8 = (params.rcx & 0xFF) as u8; + let allowed = (params.rcx & 0x100) != 0; + cpu.configure_apic_vector(vector, allowed)?; + Ok(()) + } else { + Err(SvsmReqError::invalid_parameter()) + } +} + +pub fn apic_protocol_request(request: u32, params: &mut RequestParams) -> Result<(), SvsmReqError> { + if !this_cpu().use_apic_emulation() { + return Err(SvsmReqError::unsupported_protocol()); + } + match request { + SVSM_REQ_APIC_QUERY_FEATURES => apic_query_features(params), + SVSM_REQ_APIC_CONFIGURE => apic_configure(params), + SVSM_REQ_APIC_READ_REGISTER => apic_read_register(params), + SVSM_REQ_APIC_WRITE_REGISTER => apic_write_register(params), + SVSM_REQ_APIC_CONFIGURE_VECTOR => apic_configure_vector(params), + + _ => Err(SvsmReqError::unsupported_call()), + } +} diff --git a/stage2/src/protocols/core.rs b/stage2/src/protocols/core.rs new file mode 100644 index 000000000..cd0f92ea6 --- /dev/null +++ b/stage2/src/protocols/core.rs @@ -0,0 +1,432 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::cpu::flush_tlb_global_sync; +use crate::cpu::percpu::{this_cpu, this_cpu_shared, PERCPU_AREAS, PERCPU_VMSAS}; +use crate::cpu::vmsa::{vmsa_mut_ref_from_vaddr, vmsa_ref_from_vaddr}; +use crate::error::SvsmError; +use crate::locking::RWLock; +use crate::mm::virtualrange::{VIRT_ALIGN_2M, VIRT_ALIGN_4K}; +use crate::mm::PerCPUPageMappingGuard; +use crate::mm::{valid_phys_address, writable_phys_addr, GuestPtr}; +use crate::protocols::apic::{APIC_PROTOCOL, APIC_PROTOCOL_VERSION_MAX, APIC_PROTOCOL_VERSION_MIN}; +use crate::protocols::errors::SvsmReqError; +use crate::protocols::RequestParams; +use crate::requests::SvsmCaa; +use crate::sev::utils::{ + pvalidate, rmp_clear_guest_vmsa, rmp_grant_guest_access, rmp_revoke_guest_access, + rmp_set_guest_vmsa, PvalidateOp, RMPFlags, SevSnpError, +}; +use crate::sev::vmsa::VMSAControl; +use crate::types::{PageSize, PAGE_SIZE, PAGE_SIZE_2M}; +use crate::utils::zero_mem_region; +use cpuarch::vmsa::VMSA; + +const SVSM_REQ_CORE_REMAP_CA: u32 = 0; +const SVSM_REQ_CORE_PVALIDATE: u32 = 1; +const SVSM_REQ_CORE_CREATE_VCPU: u32 = 2; +const SVSM_REQ_CORE_DELETE_VCPU: u32 = 3; +const SVSM_REQ_CORE_DEPOSIT_MEM: u32 = 4; +const SVSM_REQ_CORE_WITHDRAW_MEM: u32 = 5; +const SVSM_REQ_CORE_QUERY_PROTOCOL: u32 = 6; +const SVSM_REQ_CORE_CONFIGURE_VTOM: u32 = 7; + +const CORE_PROTOCOL: u32 = 1; +const CORE_PROTOCOL_VERSION_MIN: u32 = 1; +const CORE_PROTOCOL_VERSION_MAX: u32 = 1; + +// This lock prevents races around PVALIDATE and CREATE_VCPU +// +// Without the lock there is a possible attack where the error path of +// core_create_vcpu() could give the guest OS access to a SVSM page. +// +// The PValidate path will take the lock for read, the create_vcpu path takes +// the lock for write. 
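+//
+// For reference, the two acquisition sites below follow this pattern (shown
+// only to document the intended lock discipline):
+//
+//     let lock = PVALIDATE_LOCK.lock_read();   // core_pvalidate_one()
+//     let lock = PVALIDATE_LOCK.lock_write();  // core_create_vcpu()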
+static PVALIDATE_LOCK: RWLock<()> = RWLock::new(()); + +#[repr(C, packed)] +#[derive(Copy, Clone)] +struct PValidateRequest { + entries: u16, + next: u16, + resv: u32, +} + +fn core_create_vcpu_error_restore(paddr: Option, vaddr: Option) { + if let Some(v) = vaddr { + if let Err(err) = rmp_clear_guest_vmsa(v) { + log::error!("Failed to restore page permissions: {:#?}", err); + } + } + // In case mappings have been changed + flush_tlb_global_sync(); + + if let Some(p) = paddr { + // SAFETY: This can only fail if another CPU unregisters our + // unused VMSA. This is not possible, since unregistration of + // an unused VMSA only happens in the error path of core_create_vcpu(), + // with a physical address that only this CPU managed to register. + PERCPU_VMSAS.unregister(p, false).unwrap(); + } +} + +// VMSA validity checks according to SVSM spec +fn check_vmsa(new: &VMSA, sev_features: u64, svme_mask: u64) -> bool { + new.vmpl == RMPFlags::GUEST_VMPL.bits() as u8 + && new.efer & svme_mask == svme_mask + && new.sev_features == sev_features +} + +/// per-cpu request mapping area size (1GB) +fn core_create_vcpu(params: &RequestParams) -> Result<(), SvsmReqError> { + let paddr = PhysAddr::from(params.rcx); + let pcaa = PhysAddr::from(params.rdx); + let apic_id: u32 = (params.r8 & 0xffff_ffff) as u32; + + // Check VMSA address + if !valid_phys_address(paddr) || !paddr.is_page_aligned() { + return Err(SvsmReqError::invalid_address()); + } + + // Check CAA address + if !valid_phys_address(pcaa) || !pcaa.is_page_aligned() { + return Err(SvsmReqError::invalid_address()); + } + + // Check whether VMSA page and CAA region overlap + // + // Since both areas are 4kb aligned and 4kb in size, and correct alignment + // was already checked, it is enough here to check whether VMSA and CAA + // page have the same starting address. + if paddr == pcaa { + return Err(SvsmReqError::invalid_address()); + } + + let target_cpu = PERCPU_AREAS + .get(apic_id) + .ok_or_else(SvsmReqError::invalid_parameter)?; + + // Got valid gPAs and APIC ID, register VMSA immediately to avoid races + PERCPU_VMSAS.register(paddr, apic_id, true)?; + + // Time to map the VMSA. No need to clean up the registered VMSA on the + // error path since this is a fatal error anyway. 
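+    //
+    // The remaining sequence is: map the VMSA page, revoke guest write
+    // access, validate the VMSA contents, set the VMSA bit in the RMP, and
+    // finally hand the VMSA and CAA addresses to the target CPU.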
+ let mapping_guard = PerCPUPageMappingGuard::create_4k(paddr)?; + let vaddr = mapping_guard.virt_addr(); + + // Prevent any parallel PVALIDATE requests from being processed + let lock = PVALIDATE_LOCK.lock_write(); + + // Make sure the guest can't make modifications to the VMSA page + rmp_revoke_guest_access(vaddr, PageSize::Regular).inspect_err(|_| { + core_create_vcpu_error_restore(Some(paddr), None); + })?; + + // TLB flush needed to propagate new permissions + flush_tlb_global_sync(); + + let new_vmsa = vmsa_ref_from_vaddr(vaddr); + let svme_mask: u64 = 1u64 << 12; + + // VMSA validity checks according to SVSM spec + if !check_vmsa(new_vmsa, params.sev_features, svme_mask) { + core_create_vcpu_error_restore(Some(paddr), Some(vaddr)); + return Err(SvsmReqError::invalid_parameter()); + } + + // Set the VMSA bit + rmp_set_guest_vmsa(vaddr).inspect_err(|_| { + core_create_vcpu_error_restore(Some(paddr), Some(vaddr)); + })?; + + drop(lock); + + assert!(PERCPU_VMSAS.set_used(paddr) == Some(apic_id)); + target_cpu.update_guest_vmsa_caa(paddr, pcaa); + + Ok(()) +} + +fn core_delete_vcpu(params: &RequestParams) -> Result<(), SvsmReqError> { + let paddr = PhysAddr::from(params.rcx); + + PERCPU_VMSAS + .unregister(paddr, true) + .map_err(|_| SvsmReqError::invalid_parameter())?; + + // Map the VMSA + let mapping_guard = PerCPUPageMappingGuard::create_4k(paddr)?; + let vaddr = mapping_guard.virt_addr(); + + // Clear EFER.SVME on deleted VMSA. If the VMSA is executing + // disable() will loop until that is not the case + let del_vmsa = vmsa_mut_ref_from_vaddr(vaddr); + del_vmsa.disable(); + + // Do not return early here, as we need to do a TLB flush + let res = rmp_clear_guest_vmsa(vaddr).map_err(|_| SvsmReqError::invalid_address()); + + // Unmap the page + drop(mapping_guard); + + // Tell everyone the news and flush temporary mapping + flush_tlb_global_sync(); + + res +} + +fn core_deposit_mem(_params: &RequestParams) -> Result<(), SvsmReqError> { + log::info!("Request SVSM_REQ_CORE_DEPOSIT_MEM not yet supported"); + Err(SvsmReqError::unsupported_call()) +} + +fn core_withdraw_mem(_params: &RequestParams) -> Result<(), SvsmReqError> { + log::info!("Request SVSM_REQ_CORE_WITHDRAW_MEM not yet supported"); + Err(SvsmReqError::unsupported_call()) +} + +fn protocol_supported(version: u32, version_min: u32, version_max: u32) -> u64 { + if version >= version_min && version <= version_max { + let ret_low: u64 = version_min.into(); + let ret_high: u64 = version_max.into(); + + ret_low | (ret_high << 32) + } else { + 0 + } +} + +fn core_query_protocol(params: &mut RequestParams) -> Result<(), SvsmReqError> { + let rcx: u64 = params.rcx; + let protocol: u32 = (rcx >> 32).try_into().unwrap(); + let version: u32 = (rcx & 0xffff_ffffu64).try_into().unwrap(); + + let ret_val = match protocol { + CORE_PROTOCOL => protocol_supported( + version, + CORE_PROTOCOL_VERSION_MIN, + CORE_PROTOCOL_VERSION_MAX, + ), + APIC_PROTOCOL => { + // The APIC protocol is only supported if the calling CPU supports + // alternate injection. 
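+            //
+            // protocol_supported() packs the supported range as
+            // (version_max << 32) | version_min, so a min/max of 1/1 yields
+            // 0x0000_0001_0000_0001.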
+ if this_cpu().use_apic_emulation() { + protocol_supported( + version, + APIC_PROTOCOL_VERSION_MIN, + APIC_PROTOCOL_VERSION_MAX, + ) + } else { + 0 + } + } + _ => 0, + }; + + params.rcx = ret_val; + + Ok(()) +} + +fn core_configure_vtom(params: &mut RequestParams) -> Result<(), SvsmReqError> { + let query: bool = (params.rcx & 1) == 1; + + // Report that vTOM configuration is unsupported + if query { + params.rcx = 0; + Ok(()) + } else { + Err(SvsmReqError::invalid_request()) + } +} + +fn core_pvalidate_one(entry: u64, flush: &mut bool) -> Result<(), SvsmReqError> { + let (page_size_bytes, valign, huge) = match entry & 3 { + 0 => (PAGE_SIZE, VIRT_ALIGN_4K, PageSize::Regular), + 1 => (PAGE_SIZE_2M, VIRT_ALIGN_2M, PageSize::Huge), + _ => return Err(SvsmReqError::invalid_parameter()), + }; + + let valid = match (entry & 4) == 4 { + true => PvalidateOp::Valid, + false => PvalidateOp::Invalid, + }; + let ign_cf = (entry & 8) == 8; + + let paddr = PhysAddr::from(entry).page_align(); + + if !paddr.is_aligned(page_size_bytes) { + return Err(SvsmReqError::invalid_parameter()); + } + + if !valid_phys_address(paddr) { + log::debug!("Invalid phys address: {:#x}", paddr); + return Err(SvsmReqError::invalid_address()); + } + + let guard = PerCPUPageMappingGuard::create(paddr, paddr + page_size_bytes, valign)?; + let vaddr = guard.virt_addr(); + + // Take lock to prevent races with CREATE_VCPU calls + let lock = PVALIDATE_LOCK.lock_read(); + + if valid == PvalidateOp::Invalid { + *flush |= true; + rmp_revoke_guest_access(vaddr, huge)?; + } + + pvalidate(vaddr, huge, valid).or_else(|err| match err { + SvsmError::SevSnp(SevSnpError::FAIL_UNCHANGED(_)) if ign_cf => Ok(()), + _ => Err(err), + })?; + + drop(lock); + + if valid == PvalidateOp::Valid { + // Zero out a page when it is validated and before giving other VMPLs + // access to it. This is necessary to prevent a possible HV attack: + // + // Attack scenario: + // 1) SVSM stores secrets in VMPL0 memory at GPA A + // 2) HV invalidates GPA A and maps the SPA to GPA B, which is in the + // OS range of GPAs + // 3) Guest OS asks SVSM to validate GPA B + // 4) SVSM validates page and gives OS access + // 5) OS can now read SVSM secrets from GPA B + // + // The SVSM will not notice the attack until it tries to access GPA A + // again. Prevent it by clearing every page before giving access to + // other VMPLs. + // + // Be careful to not clear GPAs which the HV might have mapped + // read-only, as the write operation might cause infinite #NPF loops. + // + // Special thanks to Tom Lendacky for reporting the issue and tracking + // down the #NPF loops. + // + if writable_phys_addr(paddr) { + // FIXME: This check leaves a window open for the attack described + // above. Remove the check once OVMF and Linux have been fixed and + // no longer try to pvalidate MMIO memory. + zero_mem_region(vaddr, vaddr + page_size_bytes); + } else { + log::warn!("Not clearing possible read-only page at PA {:#x}", paddr); + } + rmp_grant_guest_access(vaddr, huge)?; + } + + Ok(()) +} + +fn core_pvalidate(params: &RequestParams) -> Result<(), SvsmReqError> { + let gpa = PhysAddr::from(params.rcx); + + if !gpa.is_aligned(8) || !valid_phys_address(gpa) { + return Err(SvsmReqError::invalid_parameter()); + } + + let paddr = gpa.page_align(); + let offset = gpa.page_offset(); + + let guard = PerCPUPageMappingGuard::create_4k(paddr)?; + let start = guard.virt_addr(); + + let guest_page = GuestPtr::::new(start + offset); + // SAFETY: start is a new mapped page address, thus valid. 
+ // offset can't exceed a page size, so guest_page belongs to mapped memory. + let mut request = unsafe { guest_page.read()? }; + + let entries = request.entries; + let next = request.next; + + // Each entry is 8 bytes in size, 8 bytes for the request header + let max_entries: u16 = ((PAGE_SIZE - offset - 8) / 8).try_into().unwrap(); + + if entries == 0 || entries > max_entries || entries <= next { + return Err(SvsmReqError::invalid_parameter()); + } + + let mut loop_result = Ok(()); + let mut flush = false; + + let guest_entries = guest_page.offset(1).cast::(); + for i in next..entries { + let index = i as isize; + // SAFETY: guest_entries comes from guest_page which is a new mapped + // page. index is between [next, entries) and both values have been + // validated. + let entry = match unsafe { guest_entries.offset(index).read() } { + Ok(v) => v, + Err(e) => { + loop_result = Err(e.into()); + break; + } + }; + + loop_result = core_pvalidate_one(entry, &mut flush); + match loop_result { + Ok(()) => request.next += 1, + Err(SvsmReqError::RequestError(..)) => break, + Err(SvsmReqError::FatalError(..)) => return loop_result, + } + } + + // SAFETY: guest_page is obtained from a guest-provided physical address + // (untrusted), so it needs to be valid (ie. belongs to the guest and only + // the guest). The physical address is validated by valid_phys_address() + // called at the beginning of SVSM_CORE_PVALIDATE handler (this one). + if let Err(e) = unsafe { guest_page.write_ref(&request) } { + loop_result = Err(e.into()); + } + + if flush { + flush_tlb_global_sync(); + } + + loop_result +} + +fn core_remap_ca(params: &RequestParams) -> Result<(), SvsmReqError> { + let gpa = PhysAddr::from(params.rcx); + + if !gpa.is_aligned(8) || !valid_phys_address(gpa) || gpa.crosses_page(8) { + return Err(SvsmReqError::invalid_parameter()); + } + + let offset = gpa.page_offset(); + let paddr = gpa.page_align(); + + // Temporarily map new CAA to clear it + let mapping_guard = PerCPUPageMappingGuard::create_4k(paddr)?; + let vaddr = mapping_guard.virt_addr() + offset; + + let pending = GuestPtr::::new(vaddr); + // SAFETY: pending points to a new allocated page + unsafe { pending.write(SvsmCaa::zeroed())? }; + + // Clear any pending interrupt state before remapping the calling area to + // ensure that any pending lazy EOI has been processed. 
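+    // (Context, not normative: the "lazy EOI" state referred to here is the
+    // `no_eoi_required` hint that APIC emulation publishes through the calling
+    // area, so the old calling area must be drained before it is replaced.)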
+ this_cpu().clear_pending_interrupts(); + + this_cpu_shared().update_guest_caa(gpa); + + Ok(()) +} + +pub fn core_protocol_request(request: u32, params: &mut RequestParams) -> Result<(), SvsmReqError> { + match request { + SVSM_REQ_CORE_REMAP_CA => core_remap_ca(params), + SVSM_REQ_CORE_PVALIDATE => core_pvalidate(params), + SVSM_REQ_CORE_CREATE_VCPU => core_create_vcpu(params), + SVSM_REQ_CORE_DELETE_VCPU => core_delete_vcpu(params), + SVSM_REQ_CORE_DEPOSIT_MEM => core_deposit_mem(params), + SVSM_REQ_CORE_WITHDRAW_MEM => core_withdraw_mem(params), + SVSM_REQ_CORE_QUERY_PROTOCOL => core_query_protocol(params), + SVSM_REQ_CORE_CONFIGURE_VTOM => core_configure_vtom(params), + _ => Err(SvsmReqError::unsupported_call()), + } +} diff --git a/stage2/src/protocols/errors.rs b/stage2/src/protocols/errors.rs new file mode 100644 index 000000000..70f4d4bbd --- /dev/null +++ b/stage2/src/protocols/errors.rs @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::error::ApicError; +use crate::error::SvsmError; + +#[derive(Debug, Clone, Copy)] +#[expect(non_camel_case_types)] +pub enum SvsmResultCode { + SUCCESS, + INCOMPLETE, + UNSUPPORTED_PROTOCOL, + UNSUPPORTED_CALL, + INVALID_ADDRESS, + INVALID_FORMAT, + INVALID_PARAMETER, + INVALID_REQUEST, + BUSY, + PROTOCOL_BASE(u64), +} + +impl From for u64 { + fn from(res: SvsmResultCode) -> u64 { + match res { + SvsmResultCode::SUCCESS => 0x0000_0000, + SvsmResultCode::INCOMPLETE => 0x8000_0000, + SvsmResultCode::UNSUPPORTED_PROTOCOL => 0x8000_0001, + SvsmResultCode::UNSUPPORTED_CALL => 0x8000_0002, + SvsmResultCode::INVALID_ADDRESS => 0x8000_0003, + SvsmResultCode::INVALID_FORMAT => 0x8000_0004, + SvsmResultCode::INVALID_PARAMETER => 0x8000_0005, + SvsmResultCode::INVALID_REQUEST => 0x8000_0006, + SvsmResultCode::BUSY => 0x8000_0007, + SvsmResultCode::PROTOCOL_BASE(code) => 0x8000_1000 + code, + } + } +} + +const SVSM_ERR_APIC_CANNOT_REGISTER: u64 = 0; + +#[derive(Debug, Clone, Copy)] +pub enum SvsmReqError { + RequestError(SvsmResultCode), + FatalError(SvsmError), +} + +macro_rules! impl_req_err { + ($name:ident, $v:ident) => { + pub fn $name() -> Self { + Self::RequestError(SvsmResultCode::$v) + } + }; +} + +impl SvsmReqError { + impl_req_err!(incomplete, INCOMPLETE); + impl_req_err!(unsupported_protocol, UNSUPPORTED_PROTOCOL); + impl_req_err!(unsupported_call, UNSUPPORTED_CALL); + impl_req_err!(invalid_address, INVALID_ADDRESS); + impl_req_err!(invalid_format, INVALID_FORMAT); + impl_req_err!(invalid_parameter, INVALID_PARAMETER); + impl_req_err!(invalid_request, INVALID_REQUEST); + impl_req_err!(busy, BUSY); + pub fn protocol(code: u64) -> Self { + Self::RequestError(SvsmResultCode::PROTOCOL_BASE(code)) + } +} + +impl From for SvsmReqError { + fn from(err: SvsmError) -> Self { + match err { + SvsmError::Mem => Self::FatalError(err), + // SEV-SNP errors obtained from PVALIDATE or RMPADJUST are returned + // to the guest as protocol-specific errors. 
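+            // For illustration: a PVALIDATE/RMPADJUST failure whose hardware
+            // return code is N reaches the guest as 0x8000_1000 + N via
+            // SvsmResultCode::PROTOCOL_BASE above.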
+ SvsmError::SevSnp(e) => Self::protocol(e.ret()), + SvsmError::InvalidAddress => Self::invalid_address(), + SvsmError::Apic(e) => match e { + ApicError::Disabled => Self::unsupported_protocol(), + ApicError::Emulation => Self::invalid_parameter(), + ApicError::Registration => Self::protocol(SVSM_ERR_APIC_CANNOT_REGISTER), + }, + // Use a fatal error for now + _ => Self::FatalError(err), + } + } +} diff --git a/stage2/src/protocols/mod.rs b/stage2/src/protocols/mod.rs new file mode 100644 index 000000000..6166fa91a --- /dev/null +++ b/stage2/src/protocols/mod.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2023 IBM Corp +// +// Author: Dov Murik + +pub mod apic; +pub mod core; +pub mod errors; +#[cfg(all(feature = "mstpm", not(test)))] +pub mod vtpm; + +use cpuarch::vmsa::{GuestVMExit, VMSA}; + +// SVSM protocols +pub const SVSM_CORE_PROTOCOL: u32 = 0; +pub const SVSM_VTPM_PROTOCOL: u32 = 2; +pub const SVSM_APIC_PROTOCOL: u32 = 3; + +#[derive(Debug, Default, Clone, Copy)] +pub struct RequestParams { + pub guest_exit_code: GuestVMExit, + sev_features: u64, + rcx: u64, + rdx: u64, + r8: u64, +} + +impl RequestParams { + pub fn from_vmsa(vmsa: &VMSA) -> Self { + RequestParams { + guest_exit_code: vmsa.guest_exit_code, + sev_features: vmsa.sev_features, + rcx: vmsa.rcx, + rdx: vmsa.rdx, + r8: vmsa.r8, + } + } + + pub fn write_back(&self, vmsa: &mut VMSA) { + vmsa.rcx = self.rcx; + vmsa.rdx = self.rdx; + vmsa.r8 = self.r8; + } +} diff --git a/stage2/src/protocols/vtpm.rs b/stage2/src/protocols/vtpm.rs new file mode 100644 index 000000000..1c7ad3df8 --- /dev/null +++ b/stage2/src/protocols/vtpm.rs @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2023 IBM Corp +// +// Author: Claudio Carvalho + +//! vTPM protocol implementation (SVSM spec, chapter 8). + +extern crate alloc; + +use core::{mem::size_of, slice::from_raw_parts_mut}; + +use alloc::vec::Vec; + +use crate::{ + address::{Address, PhysAddr}, + mm::{valid_phys_address, GuestPtr, PerCPUPageMappingGuard}, + protocols::{errors::SvsmReqError, RequestParams}, + types::PAGE_SIZE, + vtpm::{vtpm_get_locked, MsTpmSimulatorInterface, VtpmProtocolInterface}, +}; + +/// vTPM platform commands (SVSM spec, section 8.1 - SVSM_VTPM_QUERY) +/// +/// The platform commmand values follow the values used by the +/// Official TPM 2.0 Reference Implementation by Microsoft. 
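+/// (For example, TPM_SEND_COMMAND is 8 in that header, which is mirrored by
+/// `SendCommand = 8` below.)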
+/// +/// `ms-tpm-20-ref/TPMCmd/Simulator/include/TpmTcpProtocol.h` +#[repr(u32)] +#[derive(PartialEq, Copy, Clone, Debug)] +pub enum TpmPlatformCommand { + SendCommand = 8, +} + +impl TryFrom for TpmPlatformCommand { + type Error = SvsmReqError; + + fn try_from(value: u32) -> Result { + let cmd = match value { + x if x == TpmPlatformCommand::SendCommand as u32 => TpmPlatformCommand::SendCommand, + other => { + log::warn!("Failed to convert {} to a TPM platform command", other); + return Err(SvsmReqError::invalid_parameter()); + } + }; + + Ok(cmd) + } +} + +fn vtpm_platform_commands_supported_bitmap() -> u64 { + let mut bitmap: u64 = 0; + let vtpm = vtpm_get_locked(); + + for cmd in vtpm.get_supported_commands() { + bitmap |= 1u64 << *cmd as u32; + } + + bitmap +} + +fn is_vtpm_platform_command_supported(cmd: TpmPlatformCommand) -> bool { + let vtpm = vtpm_get_locked(); + vtpm.get_supported_commands().iter().any(|x| *x == cmd) +} + +const SEND_COMMAND_REQ_INBUF_SIZE: usize = PAGE_SIZE - 9; + +// vTPM protocol services (SVSM spec, table 14) +const SVSM_VTPM_QUERY: u32 = 0; +const SVSM_VTPM_COMMAND: u32 = 1; + +/// TPM_SEND_COMMAND request structure (SVSM spec, table 16) +#[derive(Clone, Copy, Debug)] +#[repr(C, packed)] +struct TpmSendCommandRequest { + /// MSSIM platform command ID + command: u32, + /// Locality usage for the vTPM is not defined yet (must be zero) + locality: u8, + /// Size of the input buffer + inbuf_size: u32, + /// Input buffer that contains the TPM command + inbuf: [u8; SEND_COMMAND_REQ_INBUF_SIZE], +} + +impl TpmSendCommandRequest { + // Take as slice and return a reference for Self + pub fn try_from_as_ref(buffer: &[u8]) -> Result<&Self, SvsmReqError> { + let buffer = buffer + .get(..size_of::()) + .ok_or_else(SvsmReqError::invalid_parameter)?; + + // SAFETY: TpmSendCommandRequest has no invalid representations, as it + // is comprised entirely of integer types. It is repr(packed), so its + // required alignment is simply 1. We have checked the size, so this + // is entirely safe. 
+ let request = unsafe { &*buffer.as_ptr().cast::() }; + + Ok(request) + } + + pub fn send(&self) -> Result, SvsmReqError> { + // TODO: Before implementing locality, we need to agree what it means + // to the platform + if self.locality != 0 { + return Err(SvsmReqError::invalid_parameter()); + } + + let mut length = self.inbuf_size as usize; + + let tpm_cmd = self + .inbuf + .get(..length) + .ok_or_else(SvsmReqError::invalid_parameter)?; + let mut buffer: Vec = Vec::with_capacity(SEND_COMMAND_RESP_OUTBUF_SIZE); + buffer.extend_from_slice(tpm_cmd); + + // The buffer slice must be large enough to hold the TPM command response + buffer.resize(SEND_COMMAND_RESP_OUTBUF_SIZE, 0); + + let vtpm = vtpm_get_locked(); + vtpm.send_tpm_command(buffer.as_mut_slice(), &mut length, self.locality)?; + + if length > buffer.len() { + return Err(SvsmReqError::invalid_request()); + } + buffer.truncate(length); + + Ok(buffer) + } +} + +const SEND_COMMAND_RESP_OUTBUF_SIZE: usize = PAGE_SIZE - 4; + +/// TPM_SEND_COMMAND response structure (SVSM spec, table 17) +#[derive(Clone, Copy, Debug)] +#[repr(C, packed)] +struct TpmSendCommandResponse { + /// Size of the output buffer + outbuf_size: u32, + /// Output buffer that will hold the command response + outbuf: [u8; SEND_COMMAND_RESP_OUTBUF_SIZE], +} + +impl TpmSendCommandResponse { + // Take as slice and return a &mut Self + pub fn try_from_as_mut_ref(buffer: &mut [u8]) -> Result<&mut Self, SvsmReqError> { + let buffer = buffer + .get_mut(..size_of::()) + .ok_or_else(SvsmReqError::invalid_parameter)?; + + // SAFETY: TpmSendCommandResponse has no invalid representations, as it + // is comprised entirely of integer types. It is repr(packed), so its + // required alignment is simply 1. We have checked the size, so this + // is entirely safe. + let response = unsafe { &mut *buffer.as_mut_ptr().cast::() }; + + Ok(response) + } + + /// Write the response to the outbuf + /// + /// # Arguments + /// + /// * `response`: TPM_SEND_COMMAND response slice + pub fn set_outbuf(&mut self, response: &[u8]) -> Result<(), SvsmReqError> { + self.outbuf + .get_mut(..response.len()) + .ok_or_else(SvsmReqError::invalid_request)? + .copy_from_slice(response); + self.outbuf_size = response.len() as u32; + + Ok(()) + } +} + +fn vtpm_query_request(params: &mut RequestParams) -> Result<(), SvsmReqError> { + // Bitmap of the supported vTPM commands + params.rcx = vtpm_platform_commands_supported_bitmap(); + // Supported vTPM features. Must-be-zero + params.rdx = 0; + + Ok(()) +} + +/// Send a TpmSendCommandRequest to the vTPM +/// +/// # Arguments +/// +/// * `buffer`: Contains the TpmSendCommandRequest. It will also be +/// used to store the TpmSendCommandResponse as a byte slice +/// +/// # Returns +/// +/// * `u32`: Number of bytes written back to `buffer` as part of +/// the TpmSendCommandResponse +fn tpm_send_command_request(buffer: &mut [u8]) -> Result { + let outbuf: Vec = { + let request = TpmSendCommandRequest::try_from_as_ref(buffer)?; + request.send()? + }; + let response = TpmSendCommandResponse::try_from_as_mut_ref(buffer)?; + let _ = response.set_outbuf(outbuf.as_slice()); + + Ok(outbuf.len() as u32) +} + +fn vtpm_command_request(params: &RequestParams) -> Result<(), SvsmReqError> { + let paddr = PhysAddr::from(params.rcx); + + if paddr.is_null() { + return Err(SvsmReqError::invalid_parameter()); + } + if !valid_phys_address(paddr) { + return Err(SvsmReqError::invalid_address()); + } + + // The vTPM buffer size is one page, but it not required to be page aligned. 
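+    // For illustration: an unaligned paddr of 0x1_2345_0100 yields
+    // start = 0x1_2345_0000, offset = 0x100 and end = 0x1_2345_2000, so the
+    // guard below maps two pages to cover the whole one-page buffer.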
+ let start = paddr.page_align(); + let offset = paddr.page_offset(); + let end = (paddr + PAGE_SIZE).page_align_up(); + + let guard = PerCPUPageMappingGuard::create(start, end, 0)?; + let vaddr = guard.virt_addr() + offset; + + // vTPM common request/response structure (SVSM spec, table 15) + // + // First 4 bytes are used as input and output. + // IN: platform command + // OUT: platform command response size + + // SAFETY: vaddr comes from a new mapped region. + let command = unsafe { GuestPtr::::new(vaddr).read()? }; + + let cmd = TpmPlatformCommand::try_from(command)?; + + if !is_vtpm_platform_command_supported(cmd) { + return Err(SvsmReqError::unsupported_call()); + } + + let buffer = unsafe { from_raw_parts_mut(vaddr.as_mut_ptr::(), PAGE_SIZE) }; + + let response_size = match cmd { + TpmPlatformCommand::SendCommand => tpm_send_command_request(buffer)?, + }; + + // SAFETY: vaddr points to a new mapped region. + // if paddr + sizeof::() goes to the folowing page, it should + // not be a problem since the end of the requested region is + // (paddr + PAGE_SIZE), which requests another page. So + // write(response_size) can only happen on valid memory, mapped + // by PerCPUPageMappingGuard::create(). + unsafe { + GuestPtr::::new(vaddr).write(response_size)?; + } + + Ok(()) +} + +pub fn vtpm_protocol_request(request: u32, params: &mut RequestParams) -> Result<(), SvsmReqError> { + match request { + SVSM_VTPM_QUERY => vtpm_query_request(params), + SVSM_VTPM_COMMAND => vtpm_command_request(params), + _ => Err(SvsmReqError::unsupported_call()), + } +} diff --git a/stage2/src/requests.rs b/stage2/src/requests.rs new file mode 100644 index 000000000..b8eb68dd5 --- /dev/null +++ b/stage2/src/requests.rs @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::cpu::percpu::{process_requests, this_cpu, wait_for_requests}; +use crate::cpu::{flush_tlb_global_sync, IrqGuard}; +use crate::error::SvsmError; +use crate::mm::GuestPtr; +use crate::protocols::apic::apic_protocol_request; +use crate::protocols::core::core_protocol_request; +use crate::protocols::errors::{SvsmReqError, SvsmResultCode}; +use crate::sev::ghcb::switch_to_vmpl; + +#[cfg(all(feature = "mstpm", not(test)))] +use crate::protocols::{vtpm::vtpm_protocol_request, SVSM_VTPM_PROTOCOL}; +use crate::protocols::{RequestParams, SVSM_APIC_PROTOCOL, SVSM_CORE_PROTOCOL}; +use crate::sev::vmsa::VMSAControl; +use crate::types::GUEST_VMPL; +use crate::utils::halt; +use cpuarch::vmsa::GuestVMExit; + +/// The SVSM Calling Area (CAA) +#[repr(C, packed)] +#[derive(Debug, Clone, Copy)] +pub struct SvsmCaa { + call_pending: u8, + mem_available: u8, + pub no_eoi_required: u8, + _rsvd: [u8; 5], +} + +impl SvsmCaa { + /// Returns a copy of the this CAA with the `call_pending` field cleared. + #[inline] + const fn serviced(self) -> Self { + Self { + call_pending: 0, + ..self + } + } + + /// Returns a copy of the this CAA with the `no_eoi_required` flag updated + #[inline] + pub const fn update_no_eoi_required(self, no_eoi_required: u8) -> Self { + Self { + no_eoi_required, + ..self + } + } + + /// A CAA with all of its fields set to zero. 
+ #[inline] + pub const fn zeroed() -> Self { + Self { + call_pending: 0, + mem_available: 0, + no_eoi_required: 0, + _rsvd: [0; 5], + } + } +} + +const _: () = assert!(core::mem::size_of::() == 8); + +/// Returns true if there is a valid VMSA mapping +pub fn update_mappings() -> Result<(), SvsmError> { + let cpu = this_cpu(); + let mut locked = cpu.guest_vmsa_ref(); + let mut ret = Ok(()); + + if !locked.needs_update() { + return Ok(()); + } + + cpu.unmap_guest_vmsa(); + cpu.unmap_caa(); + + match locked.vmsa_phys() { + Some(paddr) => cpu.map_guest_vmsa(paddr)?, + None => ret = Err(SvsmError::MissingVMSA), + } + + if let Some(paddr) = locked.caa_phys() { + cpu.map_guest_caa(paddr)? + } + + locked.set_updated(); + + ret +} + +struct RequestInfo { + protocol: u32, + request: u32, + params: RequestParams, +} + +fn request_loop_once( + params: &mut RequestParams, + protocol: u32, + request: u32, +) -> Result { + if !matches!(params.guest_exit_code, GuestVMExit::VMGEXIT) { + return Ok(false); + } + + match protocol { + SVSM_CORE_PROTOCOL => core_protocol_request(request, params).map(|_| true), + #[cfg(all(feature = "mstpm", not(test)))] + SVSM_VTPM_PROTOCOL => vtpm_protocol_request(request, params).map(|_| true), + SVSM_APIC_PROTOCOL => apic_protocol_request(request, params).map(|_| true), + _ => Err(SvsmReqError::unsupported_protocol()), + } +} + +fn check_requests() -> Result { + let cpu = this_cpu(); + let vmsa_ref = cpu.guest_vmsa_ref(); + if let Some(caa_addr) = vmsa_ref.caa_addr() { + let calling_area = GuestPtr::::new(caa_addr); + // SAFETY: guest vmsa and ca are always validated before beeing updated + // (core_remap_ca(), core_create_vcpu() or prepare_fw_launch()) so + // they're safe to use. + let caa = unsafe { calling_area.read()? }; + + let caa_serviced = caa.serviced(); + + // SAFETY: guest vmsa is always validated before beeing updated + // (core_remap_ca() or core_create_vcpu()) so it's safe to use. + unsafe { + calling_area.write(caa_serviced)?; + } + + Ok(caa.call_pending != 0) + } else { + Ok(false) + } +} + +pub fn request_loop() { + loop { + // Determine whether the guest is runnable. If not, halt and wait for + // the guest to execute. When halting, assume that the hypervisor + // will schedule the guest VMPL on its own. + if update_mappings().is_ok() { + // No interrupts may be processed once guest APIC state is + // updated, since handling an interrupt may modify the guest + // APIC state calculations, which could cause state corruption. + // If interrupts are disabled, then any additional guest APIC + // updates generated by the host will block the VMPL transition and + // permit reevaluation of guest APIC state. + let guard = IrqGuard::new(); + + // Make VMSA runnable again by setting EFER.SVME. This requires a + // separate scope so the CPU reference does not outlive the use of + // the VMSA reference. + { + let cpu = this_cpu(); + let mut vmsa_ref = cpu.guest_vmsa_ref(); + let caa_addr = vmsa_ref.caa_addr(); + let vmsa = vmsa_ref.vmsa(); + + // Update APIC interrupt emulation state if required. + cpu.update_apic_emulation(vmsa, caa_addr); + + // Make VMSA runnable again by setting EFER.SVME + vmsa.enable(); + } + + flush_tlb_global_sync(); + + switch_to_vmpl(GUEST_VMPL as u32); + + drop(guard); + } else { + log::debug!("No VMSA or CAA! Halting"); + halt(); + } + + // Update mappings again on return from the guest VMPL or halt. If this + // is an AP it may have been created from the context of another CPU. 
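+        // (For example, an AP whose VMSA was installed by another CPU via
+        // SVSM_CORE_CREATE_VCPU only picks up the new mapping here, when
+        // update_mappings() observes the pending update.)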
+ if update_mappings().is_err() { + continue; + } + + // Obtain a reference to the VMSA just long enough to extract the + // request parameters. + let (protocol, request) = { + let cpu = this_cpu(); + let mut vmsa_ref = cpu.guest_vmsa_ref(); + let vmsa = vmsa_ref.vmsa(); + + // Clear EFER.SVME in guest VMSA + vmsa.disable(); + + let rax = vmsa.rax; + + ((rax >> 32) as u32, (rax & 0xffff_ffff) as u32) + }; + + match check_requests() { + Ok(pending) => { + if pending { + process_requests(); + } + } + Err(SvsmReqError::RequestError(code)) => { + log::debug!( + "Soft error handling protocol {} request {}: {:?}", + protocol, + request, + code + ); + } + Err(SvsmReqError::FatalError(err)) => { + log::error!( + "Fatal error handling core protocol request {}: {:?}", + request, + err + ); + break; + } + } + } +} + +#[no_mangle] +pub extern "C" fn request_processing_main() { + let apic_id = this_cpu().get_apic_id(); + + log::info!("Launching request-processing task on CPU {}", apic_id); + + loop { + wait_for_requests(); + + // Obtain a reference to the VMSA just long enough to extract the + // request parameters. + let mut rax: u64; + let mut request_info = { + let cpu = this_cpu(); + let mut vmsa_ref = cpu.guest_vmsa_ref(); + let vmsa = vmsa_ref.vmsa(); + + // Clear EFER.SVME in guest VMSA + vmsa.disable(); + + rax = vmsa.rax; + RequestInfo { + protocol: (rax >> 32) as u32, + request: (rax & 0xffff_ffff) as u32, + params: RequestParams::from_vmsa(vmsa), + } + }; + + rax = match request_loop_once( + &mut request_info.params, + request_info.protocol, + request_info.request, + ) { + Ok(success) => match success { + true => SvsmResultCode::SUCCESS.into(), + false => rax, + }, + Err(SvsmReqError::RequestError(code)) => { + log::debug!( + "Soft error handling protocol {} request {}: {:?}", + request_info.protocol, + request_info.request, + code + ); + code.into() + } + Err(SvsmReqError::FatalError(err)) => { + log::error!( + "Fatal error handling core protocol request {}: {:?}", + request_info.request, + err + ); + break; + } + }; + + // Write back results + { + let cpu = this_cpu(); + let mut vmsa_ref = cpu.guest_vmsa_ref(); + let vmsa = vmsa_ref.vmsa(); + vmsa.rax = rax; + request_info.params.write_back(vmsa); + } + } + + panic!("Request processing task died unexpectedly"); +} diff --git a/stage2/src/serial.rs b/stage2/src/serial.rs new file mode 100644 index 000000000..37256830b --- /dev/null +++ b/stage2/src/serial.rs @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::io::{IOPort, DEFAULT_IO_DRIVER}; +use core::fmt::Debug; + +pub const SERIAL_PORT: u16 = 0x3f8; +const BAUD: u32 = 9600; +const DLAB: u8 = 0x80; + +pub const TXR: u16 = 0; // Transmit register +pub const _RXR: u16 = 0; // Receive register +pub const IER: u16 = 1; // Interrupt enable +pub const _IIR: u16 = 2; // Interrupt ID +pub const FCR: u16 = 2; // FIFO Control +pub const LCR: u16 = 3; // Line Control +pub const MCR: u16 = 4; // Modem Control +pub const LSR: u16 = 5; // Line Status +pub const _MSR: u16 = 6; // Modem Status +pub const DLL: u16 = 0; // Divisor Latch Low +pub const DLH: u16 = 1; // Divisor Latch High + +pub const RCVRDY: u8 = 0x01; +pub const XMTRDY: u8 = 0x20; + +pub trait Terminal: Sync + Debug { + fn put_byte(&self, _ch: u8) {} + fn get_byte(&self) -> u8 { + 0 + } +} + +#[derive(Debug, Copy, Clone)] +pub struct SerialPort<'a> { + driver: &'a dyn IOPort, + port: u16, +} + +impl<'a> SerialPort<'a> { + pub const fn 
new(driver: &'a dyn IOPort, p: u16) -> Self { + SerialPort { driver, port: p } + } + + pub fn init(&self) { + let divisor: u32 = 115200 / BAUD; + + self.outb(LCR, 0x3); // 8n1 + self.outb(IER, 0x0); // No Interrupt + self.outb(FCR, 0x0); // No FIFO + self.outb(MCR, 0x3); // DTR + RTS + + let c = self.inb(LCR); + self.outb(LCR, c | DLAB); + self.outb(DLL, (divisor & 0xff) as u8); + self.outb(DLH, ((divisor >> 8) & 0xff) as u8); + self.outb(LCR, c & !DLAB); + } + + #[inline] + fn inb(&self, port: u16) -> u8 { + self.driver.inb(self.port + port) + } + + #[inline] + fn outb(&self, port: u16, val: u8) { + self.driver.outb(self.port + port, val); + } +} + +impl Terminal for SerialPort<'_> { + fn put_byte(&self, ch: u8) { + loop { + let xmt = self.inb(LSR); + if (xmt & XMTRDY) == XMTRDY { + break; + } + } + + self.outb(TXR, ch) + } + + fn get_byte(&self) -> u8 { + loop { + let rcv = self.inb(LSR); + if (rcv & RCVRDY) == RCVRDY { + return self.inb(0); + } + } + } +} + +pub static DEFAULT_SERIAL_PORT: SerialPort<'_> = SerialPort::new(&DEFAULT_IO_DRIVER, SERIAL_PORT); diff --git a/stage2/src/sev/ghcb.rs b/stage2/src/sev/ghcb.rs new file mode 100644 index 000000000..ff0d76888 --- /dev/null +++ b/stage2/src/sev/ghcb.rs @@ -0,0 +1,776 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::cpu::msr::{write_msr, SEV_GHCB}; +use crate::cpu::percpu::this_cpu; +use crate::cpu::{flush_tlb_global_sync, IrqGuard, X86GeneralRegs}; +use crate::error::SvsmError; +use crate::mm::validate::{ + valid_bitmap_clear_valid_4k, valid_bitmap_set_valid_4k, valid_bitmap_valid_addr, +}; +use crate::mm::virt_to_phys; +use crate::platform::PageStateChangeOp; +use crate::sev::hv_doorbell::HVDoorbell; +use crate::sev::utils::raw_vmgexit; +use crate::types::{Bytes, PageSize, GUEST_VMPL, PAGE_SIZE_2M}; +use crate::utils::MemoryRegion; + +use crate::mm::PageBox; +use core::arch::global_asm; +use core::mem::{self, offset_of}; +use core::ops::Deref; +use core::ptr; +use core::sync::atomic::{AtomicU16, AtomicU32, AtomicU64, AtomicU8, Ordering}; + +use super::msr_protocol::{invalidate_page_msr, register_ghcb_gpa_msr, validate_page_msr}; +use super::{pvalidate, PvalidateOp}; + +use zerocopy::{FromZeros, Immutable, IntoBytes}; + +#[repr(C, packed)] +#[derive(Debug, Default, Clone, Copy, IntoBytes, Immutable)] +pub struct PageStateChangeHeader { + cur_entry: u16, + end_entry: u16, + reserved: u32, +} + +const PSC_GFN_MASK: u64 = ((1u64 << 52) - 1) & !0xfffu64; + +const PSC_OP_SHIFT: u8 = 52; +const PSC_OP_PRIVATE: u64 = 1 << PSC_OP_SHIFT; +const PSC_OP_SHARED: u64 = 2 << PSC_OP_SHIFT; +const PSC_OP_PSMASH: u64 = 3 << PSC_OP_SHIFT; +const PSC_OP_UNSMASH: u64 = 4 << PSC_OP_SHIFT; + +const PSC_FLAG_HUGE_SHIFT: u8 = 56; +const PSC_FLAG_HUGE: u64 = 1 << PSC_FLAG_HUGE_SHIFT; + +const GHCB_BUFFER_SIZE: usize = 0x7f0; + +macro_rules! ghcb_getter { + ($name:ident, $field:ident,$t:ty) => { + #[allow(unused)] + fn $name(&self) -> Result<$t, GhcbError> { + self.is_valid(offset_of!(Self, $field)) + .then(|| self.$field.load(Ordering::Relaxed)) + .ok_or(GhcbError::VmgexitInvalid) + } + }; +} + +macro_rules! 
ghcb_setter { + ($name:ident, $field:ident, $t:ty) => { + #[allow(unused)] + fn $name(&self, val: $t) { + self.$field.store(val, Ordering::Relaxed); + self.set_valid(offset_of!(Self, $field)); + } + }; +} + +#[derive(Clone, Copy, Debug)] +pub enum GhcbError { + // Attempted to write at an invalid offset in the GHCB + InvalidOffset, + // A response from the hypervisor after VMGEXIT is invalid + VmgexitInvalid, + // A response from the hypervisor included an error code + VmgexitError(u64, u64), +} + +impl From for SvsmError { + fn from(e: GhcbError) -> Self { + Self::Ghcb(e) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u64)] +#[expect(non_camel_case_types, clippy::upper_case_acronyms)] +enum GHCBExitCode { + RDTSC = 0x6e, + IOIO = 0x7b, + MSR = 0x7c, + RDTSCP = 0x87, + SNP_PSC = 0x8000_0010, + GUEST_REQUEST = 0x8000_0011, + GUEST_EXT_REQUEST = 0x8000_0012, + AP_CREATE = 0x80000013, + HV_DOORBELL = 0x8000_0014, + HV_IPI = 0x8000_0015, + CONFIGURE_INT_INJ = 0x8000_0019, + DISABLE_ALT_INJ = 0x8000_001A, + SPECIFIC_EOI = 0x8000_001B, +} + +#[derive(Clone, Copy, Debug)] +pub enum GHCBIOSize { + Size8, + Size16, + Size32, +} + +impl TryFrom for GHCBIOSize { + type Error = SvsmError; + + fn try_from(size: Bytes) -> Result { + match size { + Bytes::One => Ok(GHCBIOSize::Size8), + Bytes::Two => Ok(GHCBIOSize::Size16), + Bytes::Four => Ok(GHCBIOSize::Size32), + _ => Err(SvsmError::InvalidBytes), + } + } +} + +#[derive(Debug)] +pub struct GhcbPage(PageBox); + +impl GhcbPage { + pub fn new() -> Result { + let page = PageBox::::try_new_zeroed()?; + let vaddr = page.vaddr(); + let paddr = virt_to_phys(vaddr); + + // Make page invalid + pvalidate(vaddr, PageSize::Regular, PvalidateOp::Invalid)?; + + // Let the Hypervisor take the page back + invalidate_page_msr(paddr)?; + + // Needs guarding for Stage2 GHCB + if valid_bitmap_valid_addr(paddr) { + valid_bitmap_clear_valid_4k(paddr); + } + + // Map page unencrypted + this_cpu().get_pgtable().set_shared_4k(vaddr)?; + flush_tlb_global_sync(); + + // SAFETY: all zeros is a valid representation for the GHCB. + Ok(Self(page)) + } +} + +impl Drop for GhcbPage { + fn drop(&mut self) { + let vaddr = self.0.vaddr(); + let paddr = virt_to_phys(vaddr); + + // Re-encrypt page + this_cpu() + .get_pgtable() + .set_encrypted_4k(vaddr) + .expect("Could not re-encrypt page"); + + // Unregister GHCB PA + register_ghcb_gpa_msr(PhysAddr::null()).expect("Could not unregister GHCB"); + + // Ask the hypervisor to change the page back to the private page state. 
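+        // (validate_page_msr() issues the GHCB MSR page-state-change request;
+        // set_page_valid_status_msr() in msr_protocol.rs encodes the requested
+        // assignment at bit 52: 1 = private/validated, 2 = shared.)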
+ validate_page_msr(paddr).expect("Could not change page state"); + + // Make page guest-valid + pvalidate(vaddr, PageSize::Regular, PvalidateOp::Valid).expect("Could not pvalidate page"); + + // Needs guarding for Stage2 GHCB + if valid_bitmap_valid_addr(paddr) { + valid_bitmap_set_valid_4k(paddr); + } + } +} + +impl Deref for GhcbPage { + type Target = GHCB; + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +#[repr(C)] +#[derive(Debug, FromZeros)] +pub struct GHCB { + reserved_1: [AtomicU8; 0xcb], + cpl: AtomicU8, + reserved_2: [AtomicU8; 0x74], + xss: AtomicU64, + reserved_3: [AtomicU8; 0x18], + dr7: AtomicU64, + reserved_4: [AtomicU8; 0x90], + rax: AtomicU64, + reserved_5: [AtomicU8; 0x100], + reserved_6: AtomicU64, + rcx: AtomicU64, + rdx: AtomicU64, + rbx: AtomicU64, + reserved_7: [AtomicU8; 0x70], + sw_exit_code: AtomicU64, + sw_exit_info_1: AtomicU64, + sw_exit_info_2: AtomicU64, + sw_scratch: AtomicU64, + reserved_8: [AtomicU8; 0x38], + xcr0: AtomicU64, + valid_bitmap: [AtomicU64; 2], + x87_state_gpa: AtomicU64, + reserved_9: [AtomicU8; 0x3f8], + buffer: [AtomicU8; GHCB_BUFFER_SIZE], + reserved_10: [AtomicU8; 0xa], + version: AtomicU16, + usage: AtomicU32, +} + +impl GHCB { + ghcb_getter!(get_cpl_valid, cpl, u8); + ghcb_setter!(set_cpl_valid, cpl, u8); + + ghcb_getter!(get_xss_valid, xss, u64); + ghcb_setter!(set_xss_valid, xss, u64); + + ghcb_getter!(get_dr7_valid, dr7, u64); + ghcb_setter!(set_dr7_valid, dr7, u64); + + ghcb_getter!(get_rax_valid, rax, u64); + ghcb_setter!(set_rax_valid, rax, u64); + + ghcb_getter!(get_rcx_valid, rcx, u64); + ghcb_setter!(set_rcx_valid, rcx, u64); + + ghcb_getter!(get_rdx_valid, rdx, u64); + ghcb_setter!(set_rdx_valid, rdx, u64); + + ghcb_getter!(get_rbx_valid, rbx, u64); + ghcb_setter!(set_rbx_valid, rbx, u64); + + ghcb_getter!(get_exit_code_valid, sw_exit_code, u64); + ghcb_setter!(set_exit_code_valid, sw_exit_code, u64); + + ghcb_getter!(get_exit_info_1_valid, sw_exit_info_1, u64); + ghcb_setter!(set_exit_info_1_valid, sw_exit_info_1, u64); + + ghcb_getter!(get_exit_info_2_valid, sw_exit_info_2, u64); + ghcb_setter!(set_exit_info_2_valid, sw_exit_info_2, u64); + + ghcb_getter!(get_sw_scratch_valid, sw_scratch, u64); + ghcb_setter!(set_sw_scratch_valid, sw_scratch, u64); + + ghcb_getter!(get_sw_xcr0_valid, xcr0, u64); + ghcb_setter!(set_sw_xcr0_valid, xcr0, u64); + + ghcb_getter!(get_sw_x87_state_gpa_valid, x87_state_gpa, u64); + ghcb_setter!(set_sw_x87_state_gpa_valid, x87_state_gpa, u64); + + ghcb_getter!(get_version_valid, version, u16); + ghcb_setter!(set_version_valid, version, u16); + + ghcb_getter!(get_usage_valid, usage, u32); + ghcb_setter!(set_usage_valid, usage, u32); + + pub fn rdtscp_regs(&self, regs: &mut X86GeneralRegs) -> Result<(), SvsmError> { + self.clear(); + self.vmgexit(GHCBExitCode::RDTSCP, 0, 0)?; + let rax = self.get_rax_valid()?; + let rdx = self.get_rdx_valid()?; + let rcx = self.get_rcx_valid()?; + regs.rax = rax as usize; + regs.rdx = rdx as usize; + regs.rcx = rcx as usize; + Ok(()) + } + + pub fn rdtsc_regs(&self, regs: &mut X86GeneralRegs) -> Result<(), SvsmError> { + self.clear(); + self.vmgexit(GHCBExitCode::RDTSC, 0, 0)?; + let rax = self.get_rax_valid()?; + let rdx = self.get_rdx_valid()?; + regs.rax = rax as usize; + regs.rdx = rdx as usize; + Ok(()) + } + + pub fn wrmsr(&self, msr_index: u32, value: u64) -> Result<(), SvsmError> { + self.wrmsr_raw(msr_index as u64, value & 0xFFFF_FFFF, value >> 32) + } + + pub fn wrmsr_regs(&self, regs: &X86GeneralRegs) -> Result<(), SvsmError> { + 
self.wrmsr_raw(regs.rcx as u64, regs.rax as u64, regs.rdx as u64) + } + + pub fn wrmsr_raw(&self, rcx: u64, rax: u64, rdx: u64) -> Result<(), SvsmError> { + self.clear(); + + self.set_rcx_valid(rcx); + self.set_rax_valid(rax); + self.set_rdx_valid(rdx); + + self.vmgexit(GHCBExitCode::MSR, 1, 0)?; + Ok(()) + } + + pub fn rdmsr_regs(&self, regs: &mut X86GeneralRegs) -> Result<(), SvsmError> { + self.clear(); + + self.set_rcx_valid(regs.rcx as u64); + + self.vmgexit(GHCBExitCode::MSR, 0, 0)?; + let rdx = self.get_rdx_valid()?; + let rax = self.get_rax_valid()?; + regs.rdx = rdx as usize; + regs.rax = rax as usize; + Ok(()) + } + + pub fn register(&self) -> Result<(), SvsmError> { + let vaddr = VirtAddr::from(self as *const GHCB); + let paddr = virt_to_phys(vaddr); + + // Register GHCB GPA + Ok(register_ghcb_gpa_msr(paddr)?) + } + + pub fn clear(&self) { + // Clear valid bitmap + self.valid_bitmap[0].store(0, Ordering::Relaxed); + self.valid_bitmap[1].store(0, Ordering::Relaxed); + + // Mark valid_bitmap valid + let off = offset_of!(Self, valid_bitmap); + self.set_valid(off); + self.set_valid(off + mem::size_of::()); + } + + fn set_valid(&self, offset: usize) { + let bit: usize = (offset >> 3) & 0x3f; + let index: usize = (offset >> 9) & 0x1; + let mask: u64 = 1 << bit; + + self.valid_bitmap[index].fetch_or(mask, Ordering::Relaxed); + } + + fn is_valid(&self, offset: usize) -> bool { + let bit: usize = (offset >> 3) & 0x3f; + let index: usize = (offset >> 9) & 0x1; + let mask: u64 = 1 << bit; + + (self.valid_bitmap[index].load(Ordering::Relaxed) & mask) == mask + } + + fn vmgexit( + &self, + exit_code: GHCBExitCode, + exit_info_1: u64, + exit_info_2: u64, + ) -> Result<(), GhcbError> { + // GHCB is version 2 + self.set_version_valid(2); + // GHCB Follows standard format + self.set_usage_valid(0); + self.set_exit_code_valid(exit_code as u64); + self.set_exit_info_1_valid(exit_info_1); + self.set_exit_info_2_valid(exit_info_2); + + let ghcb_address = VirtAddr::from(self as *const GHCB); + let ghcb_pa = u64::from(virt_to_phys(ghcb_address)); + // Disable interrupts between writing the MSR and making the GHCB call + // to prevent reentrant use of the GHCB MSR. + let guard = IrqGuard::new(); + write_msr(SEV_GHCB, ghcb_pa); + unsafe { + raw_vmgexit(); + } + drop(guard); + + let sw_exit_info_1 = self.get_exit_info_1_valid()?; + if sw_exit_info_1 != 0 { + return Err(GhcbError::VmgexitError( + sw_exit_info_1, + self.sw_exit_info_2.load(Ordering::Relaxed), + )); + } + + Ok(()) + } + + pub fn ioio_in(&self, port: u16, size: GHCBIOSize) -> Result { + self.clear(); + + let mut info: u64 = 1; // IN instruction + + info |= (port as u64) << 16; + + match size { + GHCBIOSize::Size8 => info |= 1 << 4, + GHCBIOSize::Size16 => info |= 1 << 5, + GHCBIOSize::Size32 => info |= 1 << 6, + } + + self.vmgexit(GHCBExitCode::IOIO, info, 0)?; + let rax = self.get_rax_valid()?; + Ok(rax) + } + + pub fn ioio_out(&self, port: u16, size: GHCBIOSize, value: u64) -> Result<(), SvsmError> { + self.clear(); + + let mut info: u64 = 0; // OUT instruction + + info |= (port as u64) << 16; + + match size { + GHCBIOSize::Size8 => info |= 1 << 4, + GHCBIOSize::Size16 => info |= 1 << 5, + GHCBIOSize::Size32 => info |= 1 << 6, + } + + self.set_rax_valid(value); + self.vmgexit(GHCBExitCode::IOIO, info, 0)?; + Ok(()) + } + + fn write_buffer(&self, data: &T, offset: usize) -> Result<(), GhcbError> + where + T: IntoBytes + Immutable, + { + let src = data.as_bytes(); + let dst = &self + .buffer + .get(offset..) 
+ .ok_or(GhcbError::InvalidOffset)? + .get(..src.len()) + .ok_or(GhcbError::InvalidOffset)?; + for (dst, src) in dst.iter().zip(src.iter().copied()) { + dst.store(src, Ordering::Relaxed); + } + Ok(()) + } + + pub fn psc_entry( + &self, + paddr: PhysAddr, + op_mask: u64, + current_page: u64, + size: PageSize, + ) -> u64 { + assert!(size == PageSize::Regular || paddr.is_aligned(PAGE_SIZE_2M)); + + let mut entry: u64 = + ((paddr.bits() as u64) & PSC_GFN_MASK) | op_mask | (current_page & 0xfffu64); + if size == PageSize::Huge { + entry |= PSC_FLAG_HUGE; + } + + entry + } + + pub fn page_state_change( + &self, + region: MemoryRegion, + size: PageSize, + op: PageStateChangeOp, + ) -> Result<(), SvsmError> { + // Maximum entries (8 bytes each_ minus 8 bytes for header + let max_entries: u16 = ((GHCB_BUFFER_SIZE - 8) / 8).try_into().unwrap(); + let mut entries: u16 = 0; + let mut paddr = region.start(); + let end = region.end(); + let op_mask: u64 = match op { + PageStateChangeOp::Private => PSC_OP_PRIVATE, + PageStateChangeOp::Shared => PSC_OP_SHARED, + PageStateChangeOp::Psmash => PSC_OP_PSMASH, + PageStateChangeOp::Unsmash => PSC_OP_UNSMASH, + }; + + self.clear(); + + while paddr < end { + let size = if size == PageSize::Huge + && paddr.is_aligned(PAGE_SIZE_2M) + && paddr + PAGE_SIZE_2M <= end + { + PageSize::Huge + } else { + PageSize::Regular + }; + let pgsize = usize::from(size); + let entry = self.psc_entry(paddr, op_mask, 0, size); + let offset = usize::from(entries) * 8 + 8; + self.write_buffer(&entry, offset)?; + entries += 1; + paddr = paddr + pgsize; + + if entries == max_entries || paddr >= end { + let header = PageStateChangeHeader { + cur_entry: 0, + end_entry: entries - 1, + reserved: 0, + }; + self.write_buffer(&header, 0)?; + + let buffer_va = VirtAddr::from(self.buffer.as_ptr()); + let buffer_pa = u64::from(virt_to_phys(buffer_va)); + self.set_sw_scratch_valid(buffer_pa); + + if let Err(mut e) = self.vmgexit(GHCBExitCode::SNP_PSC, 0, 0) { + if let Err(err) = self.get_exit_info_2_valid() { + e = err; + } + + if let GhcbError::VmgexitError(_, info2) = e { + let info_high: u32 = (info2 >> 32) as u32; + let info_low: u32 = (info2 & 0xffff_ffffu64) as u32; + log::error!( + "GHCB SnpPageStateChange failed err_high: {:#x} err_low: {:#x}", + info_high, + info_low + ); + } + return Err(e.into()); + } + + entries = 0; + } + } + + Ok(()) + } + + pub fn ap_create( + &self, + vmsa_gpa: PhysAddr, + apic_id: u64, + vmpl: u64, + sev_features: u64, + ) -> Result<(), SvsmError> { + self.clear(); + let exit_info_1: u64 = 1 | (vmpl & 0xf) << 16 | apic_id << 32; + let exit_info_2: u64 = vmsa_gpa.into(); + self.set_rax_valid(sev_features); + self.vmgexit(GHCBExitCode::AP_CREATE, exit_info_1, exit_info_2)?; + Ok(()) + } + + pub fn register_guest_vmsa( + &self, + vmsa_gpa: PhysAddr, + apic_id: u64, + vmpl: u64, + sev_features: u64, + ) -> Result<(), SvsmError> { + self.clear(); + let exit_info_1: u64 = (vmpl & 0xf) << 16 | apic_id << 32; + let exit_info_2: u64 = vmsa_gpa.into(); + self.set_rax_valid(sev_features); + self.vmgexit(GHCBExitCode::AP_CREATE, exit_info_1, exit_info_2)?; + Ok(()) + } + + pub fn register_hv_doorbell(&self, paddr: PhysAddr) -> Result<(), SvsmError> { + self.clear(); + self.vmgexit(GHCBExitCode::HV_DOORBELL, 1, u64::from(paddr))?; + Ok(()) + } + + pub fn guest_request(&self, req_page: VirtAddr, resp_page: VirtAddr) -> Result<(), SvsmError> { + self.clear(); + + let info1: u64 = u64::from(virt_to_phys(req_page)); + let info2: u64 = u64::from(virt_to_phys(resp_page)); + + 
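+        // For the SNP guest request exit, exit_info_1 carries the GPA of the
+        // request page and exit_info_2 the GPA of the response page; the
+        // messages exchanged through them are VMPCK-encrypted guest messages
+        // destined for the PSP.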
self.vmgexit(GHCBExitCode::GUEST_REQUEST, info1, info2)?; + + let sw_exit_info_2 = self.get_exit_info_2_valid()?; + if sw_exit_info_2 != 0 { + return Err(GhcbError::VmgexitError( + self.sw_exit_info_1.load(Ordering::Relaxed), + sw_exit_info_2, + ) + .into()); + } + + Ok(()) + } + + pub fn guest_ext_request( + &self, + req_page: VirtAddr, + resp_page: VirtAddr, + data_pages: VirtAddr, + data_size: u64, + ) -> Result<(), SvsmError> { + self.clear(); + + let info1: u64 = u64::from(virt_to_phys(req_page)); + let info2: u64 = u64::from(virt_to_phys(resp_page)); + let rax: u64 = u64::from(virt_to_phys(data_pages)); + + self.set_rax_valid(rax); + self.set_rbx_valid(data_size); + + self.vmgexit(GHCBExitCode::GUEST_EXT_REQUEST, info1, info2)?; + + let sw_exit_info_2 = self.get_exit_info_2_valid()?; + + // On error, RBX and exit_info_2 are returned for proper error handling. + // For an extended request, if the buffer provided is too small, the hypervisor + // will return in RBX the number of contiguous pages required + if sw_exit_info_2 != 0 { + return Err( + GhcbError::VmgexitError(self.rbx.load(Ordering::Relaxed), sw_exit_info_2).into(), + ); + } + + Ok(()) + } + + pub fn hv_ipi(&self, icr: u64) -> Result<(), SvsmError> { + self.clear(); + self.vmgexit(GHCBExitCode::HV_IPI, icr, 0)?; + Ok(()) + } + + pub fn configure_interrupt_injection(&self, vector: usize) -> Result<(), SvsmError> { + self.clear(); + self.vmgexit(GHCBExitCode::CONFIGURE_INT_INJ, vector as u64, 0)?; + Ok(()) + } + + pub fn specific_eoi(&self, vector: u8, vmpl: u8) -> Result<(), SvsmError> { + self.clear(); + let exit_info = ((vmpl as u64) << 16) | (vector as u64); + self.vmgexit(GHCBExitCode::SPECIFIC_EOI, exit_info, 0)?; + Ok(()) + } + + pub fn disable_alternate_injection( + &self, + tpr: u8, + in_intr_shadow: bool, + interrupts_enabled: bool, + ) -> Result<(), SvsmError> { + let mut exit_info = (GUEST_VMPL as u64) << 16; + exit_info |= (tpr as u64) << 8; + if in_intr_shadow { + exit_info |= 2; + } + if interrupts_enabled { + exit_info |= 1; + } + self.clear(); + self.vmgexit(GHCBExitCode::DISABLE_ALT_INJ, exit_info, 0)?; + Ok(()) + } + + #[inline] + #[cfg(test)] + pub fn fill(&self, val: u8) { + let bytes = unsafe { + // SAFETY: All bytes in `Self` are part of an atomic integer type. + // This allows us to cast `Self` to a slice of `AtomicU8`s. + core::slice::from_raw_parts(self as *const _ as *const AtomicU8, size_of::()) + }; + for byte in bytes { + byte.store(val, Ordering::Relaxed); + } + } +} + +extern "C" { + pub fn switch_to_vmpl_unsafe(hv_doorbell: *const HVDoorbell, vmpl: u32) -> bool; +} + +pub fn switch_to_vmpl(vmpl: u32) { + // The switch to a lower VMPL must be done with an assembly sequence in + // order to ensure that any #HV that occurs during the sequence will + // correctly block the VMPL switch so that events can be processed. + let hv_doorbell = this_cpu().hv_doorbell(); + let ptr = match hv_doorbell { + Some(doorbell) => ptr::from_ref(doorbell), + None => ptr::null(), + }; + unsafe { + if !switch_to_vmpl_unsafe(ptr, vmpl) { + panic!("Failed to switch to VMPL {}", vmpl); + } + } +} + +global_asm!( + r#" + .globl switch_to_vmpl_unsafe + switch_to_vmpl_unsafe: + + /* Upon entry, + * rdi = pointer to the HV doorbell page + * esi = target VMPL + */ + /* Check if NoFurtherSignal is set (bit 15 of the first word of the + * #HV doorbell page). If so, abort the transition. 
*/ + test %rdi, %rdi + jz switch_vmpl_proceed + testw $0x8000, (%rdi) + + /* From this point until the vmgexit, if a #HV arrives, the #HV handler + * must prevent the VMPL transition. */ + .globl switch_vmpl_window_start + switch_vmpl_window_start: + jnz switch_vmpl_cancel + + switch_vmpl_proceed: + /* Use the MSR-based VMPL switch request to avoid any need to use the + * GHCB page. Run VMPL request is 0x16 and response is 0x17. */ + movl $0x16, %eax + movl %esi, %edx + movl ${SEV_GHCB}, %ecx + wrmsr + rep; vmmcall + + .globl switch_vmpl_window_end + switch_vmpl_window_end: + /* Verify that the request was honored. ECX still contains the MSR + * number. */ + rdmsr + andl $0xFFF, %eax + cmpl $0x17, %eax + jz switch_vmpl_cancel + xorl %eax, %eax + ret + + /* An aborted VMPL switch is treated as a successful switch. */ + .globl switch_vmpl_cancel + switch_vmpl_cancel: + movl $1, %eax + ret + "#, + SEV_GHCB = const SEV_GHCB, + options(att_syntax) +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ghcb_layout() { + assert_eq!(offset_of!(GHCB, cpl), 0x0cb); + assert_eq!(offset_of!(GHCB, xss), 0x140); + assert_eq!(offset_of!(GHCB, dr7), 0x160); + assert_eq!(offset_of!(GHCB, rax), 0x1f8); + assert_eq!(offset_of!(GHCB, rcx), 0x308); + assert_eq!(offset_of!(GHCB, rdx), 0x310); + assert_eq!(offset_of!(GHCB, rbx), 0x318); + assert_eq!(offset_of!(GHCB, sw_exit_code), 0x390); + assert_eq!(offset_of!(GHCB, sw_exit_info_1), 0x398); + assert_eq!(offset_of!(GHCB, sw_exit_info_2), 0x3a0); + assert_eq!(offset_of!(GHCB, sw_scratch), 0x3a8); + assert_eq!(offset_of!(GHCB, xcr0), 0x3e8); + assert_eq!(offset_of!(GHCB, valid_bitmap), 0x3f0); + assert_eq!(offset_of!(GHCB, x87_state_gpa), 0x400); + assert_eq!(offset_of!(GHCB, buffer), 0x800); + assert_eq!(offset_of!(GHCB, version), 0xffa); + assert_eq!(offset_of!(GHCB, usage), 0xffc); + assert_eq!(mem::size_of::(), 0x1000); + } +} diff --git a/stage2/src/sev/hv_doorbell.rs b/stage2/src/sev/hv_doorbell.rs new file mode 100644 index 000000000..7593d472e --- /dev/null +++ b/stage2/src/sev/hv_doorbell.rs @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 Copyright (c) Microsoft Corporation +// Author: Jon Lange (jlange@microsoft.com) + +use crate::cpu::idt::svsm::common_isr_handler; +use crate::cpu::percpu::this_cpu; +use crate::error::SvsmError; +use crate::mm::page_visibility::SharedBox; +use crate::mm::virt_to_phys; +use crate::sev::ghcb::GHCB; + +use bitfield_struct::bitfield; +use core::cell::UnsafeCell; +use core::sync::atomic::{AtomicU32, AtomicU8, Ordering}; + +#[bitfield(u8)] +pub struct HVDoorbellFlags { + pub nmi_pending: bool, + pub mc_pending: bool, + #[bits(5)] + rsvd_6_2: u8, + pub no_further_signal: bool, +} + +#[bitfield(u32)] +pub struct HVExtIntStatus { + pub pending_vector: u8, + pub nmi_pending: bool, + pub mc_pending: bool, + pub level_sensitive: bool, + #[bits(3)] + rsvd_13_11: u32, + pub multiple_vectors: bool, + #[bits(12)] + rsvd_26_15: u32, + ipi_requested: bool, + #[bits(3)] + rsvd_30_28: u32, + pub vector_31: bool, +} + +#[repr(C)] +#[derive(Debug)] +pub struct HVExtIntInfo { + pub status: AtomicU32, + pub irr: [AtomicU32; 7], + pub isr: [AtomicU32; 8], +} + +/// Allocates a new HV doorbell page and registers it on the hypervisor +/// using the given GHCB. 
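+/// The page is allocated host-shared and intentionally leaked, which is what
+/// makes the returned `&'static HVDoorbell` reference valid for the lifetime
+/// of the SVSM.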
+pub fn allocate_hv_doorbell_page(ghcb: &GHCB) -> Result<&'static HVDoorbell, SvsmError> { + let page = SharedBox::::try_new_zeroed()?; + + let vaddr = page.addr(); + let paddr = virt_to_phys(vaddr); + ghcb.register_hv_doorbell(paddr)?; + + // Create a static shared reference. + let ptr = page.leak(); + let doorbell = unsafe { + // SAFETY: Any bit-pattern is valid for `HVDoorbell` and it tolerates + // unsynchronized writes from the host. + ptr.as_ref() + }; + + Ok(doorbell) +} + +#[repr(C)] +#[derive(Debug)] +pub struct HVDoorbell { + pub vector: AtomicU8, + pub flags: AtomicU8, + pub no_eoi_required: AtomicU8, + pub per_vmpl_events: AtomicU8, + reserved_63_4: UnsafeCell<[u8; 60]>, + pub per_vmpl: [HVExtIntInfo; 3], +} + +impl HVDoorbell { + pub fn process_pending_events(&self) { + // Clear the NoFurtherSignal bit before processing. If any additional + // signal comes in after processing has commenced, it may be missed by + // this loop, but it will be detected when interrupts are processed + // again. Also clear the NMI bit, since NMIs are not expected. + let no_further_signal_mask: u8 = HVDoorbellFlags::new() + .with_no_further_signal(true) + .with_nmi_pending(true) + .into(); + let flags = HVDoorbellFlags::from( + self.flags + .fetch_and(!no_further_signal_mask, Ordering::Relaxed), + ); + + // #MC handling is not possible, so panic if a machine check has + // occurred. + if flags.mc_pending() { + panic!("#MC exception delivered via #HV"); + } + + // Consume interrupts as long as they are available. + loop { + // Consume any interrupt that may be present. + let vector = self.vector.swap(0, Ordering::Relaxed); + if vector == 0 { + break; + } + common_isr_handler(vector as usize); + } + + // Ignore per-VMPL events; these will be consumed when APIC emulation + // is performed. + } + + pub fn process_if_required(&self) { + let flags = HVDoorbellFlags::from(self.flags.load(Ordering::Relaxed)); + if flags.no_further_signal() { + self.process_pending_events(); + } + } + + pub fn no_eoi_required(&self) -> bool { + // Check to see if the "no EOI required" flag is set to determine + // whether an explicit EOI can be avoided. + let mut no_eoi_required = self.no_eoi_required.load(Ordering::Relaxed); + loop { + // If the flag is not set, then an explicit EOI is required. + if (no_eoi_required & 1) == 0 { + return false; + } + // Attempt to atomically clear the flag. + match self.no_eoi_required.compare_exchange_weak( + no_eoi_required, + no_eoi_required & !1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(new) => no_eoi_required = new, + } + } + + // If the flag was successfully cleared, then no explicit EOI is + // required. + true + } +} + +/// Gets the HV doorbell page configured for this CPU. +/// +/// # Panics +/// +/// Panics if te HV doorbell page has not been set up beforehand. +pub fn current_hv_doorbell() -> &'static HVDoorbell { + this_cpu() + .hv_doorbell() + .expect("HV doorbell page dereferenced before allocating") +} + +/// # Safety +/// This function takes a raw pointer to the #HV doorbell page because it is +/// called directly from assembly, and should not be invoked directly from +/// Rust code. 
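+/// The caller must pass a non-null pointer to the current CPU's #HV doorbell
+/// page; the pointer is dereferenced unconditionally.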
+#[no_mangle] +pub unsafe extern "C" fn process_hv_events(hv_doorbell: *const HVDoorbell) { + unsafe { + (*hv_doorbell).process_pending_events(); + } +} diff --git a/stage2/src/sev/mod.rs b/stage2/src/sev/mod.rs new file mode 100644 index 000000000..1a9ce874d --- /dev/null +++ b/stage2/src/sev/mod.rs @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +pub mod ghcb; +pub mod hv_doorbell; +pub mod msr_protocol; +pub mod secrets_page; +pub mod status; +pub mod vmsa; + +pub mod utils; + +pub use msr_protocol::init_hypervisor_ghcb_features; +pub use secrets_page::{secrets_page, secrets_page_mut, SecretsPage, VMPCK_SIZE}; +pub use status::sev_status_init; +pub use status::sev_status_verify; +pub use utils::{pvalidate, pvalidate_range, PvalidateOp, SevSnpError}; +pub use utils::{rmp_adjust, RMPFlags}; diff --git a/stage2/src/sev/msr_protocol.rs b/stage2/src/sev/msr_protocol.rs new file mode 100644 index 000000000..dcb59c1e6 --- /dev/null +++ b/stage2/src/sev/msr_protocol.rs @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr}; +use crate::cpu::irq_state::raw_irqs_disable; +use crate::cpu::msr::{read_msr, write_msr, SEV_GHCB}; +use crate::cpu::{irqs_enabled, IrqGuard}; +use crate::error::SvsmError; +use crate::utils::halt; +use crate::utils::immut_after_init::ImmutAfterInitCell; + +use super::utils::raw_vmgexit; + +use bitflags::bitflags; +use core::fmt; +use core::fmt::Display; + +#[derive(Clone, Copy, Debug)] +pub enum GhcbMsrError { + // The info section of the response did not match our request + InfoMismatch, + // The data section of the response did not match our request, + // or it was malformed altogether. + DataMismatch, +} + +impl From for SvsmError { + fn from(e: GhcbMsrError) -> Self { + Self::GhcbMsr(e) + } +} + +#[derive(Clone, Copy, Debug)] +#[non_exhaustive] +pub enum GHCBMsr {} + +impl GHCBMsr { + pub const SEV_INFO_REQ: u64 = 0x02; + pub const SEV_INFO_RESP: u64 = 0x01; + pub const SNP_REG_GHCB_GPA_REQ: u64 = 0x12; + pub const SNP_REG_GHCB_GPA_RESP: u64 = 0x13; + pub const SNP_STATE_CHANGE_REQ: u64 = 0x14; + pub const SNP_STATE_CHANGE_RESP: u64 = 0x15; + pub const SNP_HV_FEATURES_REQ: u64 = 0x80; + pub const SNP_HV_FEATURES_RESP: u64 = 0x81; + pub const TERM_REQ: u64 = 0x100; +} + +bitflags! { + #[derive(Clone, Copy, Debug)] + pub struct GHCBHvFeatures: u64 { + const SEV_SNP = 1 << 0; + const SEV_SNP_AP_CREATION = 1 << 1; + const SEV_SNP_RESTR_INJ = 1 << 2; + const SEV_SNP_RESTR_INJ_TIMER = 1 << 3; + const APIC_ID_LIST = 1 << 4; + const SEV_SNP_MULTI_VMPL = 1 << 5; + const SEV_PAGE_STATE_CHANGE = 1 << 6; + const SEV_SNP_EXT_INTERRUPTS = 1 << 7; + } +} + +impl Display for GHCBHvFeatures { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!("{:#x}", self.bits())) + } +} + +static GHCB_HV_FEATURES: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +/// Check that we support the hypervisor's advertised GHCB versions. +pub fn verify_ghcb_version() { + // This function is normally only called early during initializtion before + // interrupts have been enabled, and before interrupt guards can safely be + // used. + assert!(!irqs_enabled()); + // Request SEV information. + write_msr(SEV_GHCB, GHCBMsr::SEV_INFO_REQ); + unsafe { + raw_vmgexit(); + } + let sev_info = read_msr(SEV_GHCB); + + // Parse the results. 
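+    // For illustration, the SEV information response is laid out as:
+    //   bits 11:0  - response type (must be SEV_INFO_RESP)
+    //   bits 47:32 - minimum supported GHCB MSR protocol version
+    //   bits 63:48 - maximum supported GHCB MSR protocol version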
+ + let response_ty = sev_info & 0xfff; + assert_eq!( + response_ty, + GHCBMsr::SEV_INFO_RESP, + "unexpected response type: {response_ty:#05x}" + ); + + // Compare announced supported GHCB MSR protocol version range + // for compatibility. + let min_version = (sev_info >> 32) & 0xffff; + let max_version = (sev_info >> 48) & 0xffff; + assert!( + (min_version..=max_version).contains(&2), + "the hypervisor doesn't support GHCB version 2 (min: {min_version}, max: {max_version})" + ); +} + +pub fn hypervisor_ghcb_features() -> GHCBHvFeatures { + *GHCB_HV_FEATURES +} + +pub fn init_hypervisor_ghcb_features() -> Result<(), GhcbMsrError> { + let guard = IrqGuard::new(); + write_msr(SEV_GHCB, GHCBMsr::SNP_HV_FEATURES_REQ); + unsafe { + raw_vmgexit(); + } + let result = read_msr(SEV_GHCB); + drop(guard); + if (result & 0xFFF) == GHCBMsr::SNP_HV_FEATURES_RESP { + let features = GHCBHvFeatures::from_bits_truncate(result >> 12); + + // Verify that the required features are supported. + let required = GHCBHvFeatures::SEV_SNP + | GHCBHvFeatures::SEV_SNP_AP_CREATION + | GHCBHvFeatures::SEV_SNP_MULTI_VMPL; + let missing = !features & required; + if !missing.is_empty() { + log::error!( + "Required hypervisor GHCB features not available: present={:#x}, required={:#x}, missing={:#x}", + features, required, missing + ); + // FIXME - enforce this panic once KVM advertises the required + // features. + // panic!("Required hypervisor GHCB features not available"); + } + + GHCB_HV_FEATURES + .init(&features) + .expect("Already initialized GHCB HV features"); + Ok(()) + } else { + Err(GhcbMsrError::InfoMismatch) + } +} + +pub fn register_ghcb_gpa_msr(addr: PhysAddr) -> Result<(), GhcbMsrError> { + let mut info = addr.bits() as u64; + + info |= GHCBMsr::SNP_REG_GHCB_GPA_REQ; + let guard = IrqGuard::new(); + write_msr(SEV_GHCB, info); + unsafe { + raw_vmgexit(); + } + info = read_msr(SEV_GHCB); + drop(guard); + + if (info & 0xfff) != GHCBMsr::SNP_REG_GHCB_GPA_RESP { + return Err(GhcbMsrError::InfoMismatch); + } + + if (info & !0xfff) != (addr.bits() as u64) { + return Err(GhcbMsrError::DataMismatch); + } + + Ok(()) +} + +fn set_page_valid_status_msr(addr: PhysAddr, valid: bool) -> Result<(), GhcbMsrError> { + let mut info: u64 = (addr.bits() as u64) & 0x000f_ffff_ffff_f000; + + if valid { + info |= 1u64 << 52; + } else { + info |= 2u64 << 52; + } + + info |= GHCBMsr::SNP_STATE_CHANGE_REQ; + let guard = IrqGuard::new(); + write_msr(SEV_GHCB, info); + unsafe { + raw_vmgexit(); + } + let response = read_msr(SEV_GHCB); + drop(guard); + + if (response & 0xfff) != GHCBMsr::SNP_STATE_CHANGE_RESP { + return Err(GhcbMsrError::InfoMismatch); + } + + if (response & !0xfff) != 0 { + return Err(GhcbMsrError::DataMismatch); + } + + Ok(()) +} + +pub fn validate_page_msr(addr: PhysAddr) -> Result<(), GhcbMsrError> { + set_page_valid_status_msr(addr, true) +} + +pub fn invalidate_page_msr(addr: PhysAddr) -> Result<(), GhcbMsrError> { + set_page_valid_status_msr(addr, false) +} + +pub fn request_termination_msr() -> ! { + let info: u64 = GHCBMsr::TERM_REQ; + + // Safety + // + // Since this processor is destined for a fatal termination, there is + // no reason to preserve interrupt state. Interrupts can be disabled + // outright prior to shutdown. 
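+    // (Only the request type, TERM_REQ = 0x100, is encoded here; no specific
+    // reason-code-set/reason-code is placed in the upper bits of the request.)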
+ unsafe { + raw_irqs_disable(); + write_msr(SEV_GHCB, info); + raw_vmgexit(); + } + loop { + halt(); + } +} diff --git a/stage2/src/sev/secrets_page.rs b/stage2/src/sev/secrets_page.rs new file mode 100644 index 000000000..7f0f5486e --- /dev/null +++ b/stage2/src/sev/secrets_page.rs @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::VirtAddr; +use crate::locking::{RWLock, ReadLockGuard, WriteLockGuard}; +use crate::sev::vmsa::VMPL_MAX; +use crate::types::GUEST_VMPL; + +extern crate alloc; +use alloc::boxed::Box; + +pub const VMPCK_SIZE: usize = 32; + +#[derive(Copy, Clone, Debug)] +#[repr(C, packed)] +pub struct SecretsPage { + version: u32, + gctxt: u32, + fms: u32, + reserved_00c: u32, + gosvw: [u8; 16], + vmpck: [[u8; VMPCK_SIZE]; VMPL_MAX], + reserved_0a0: [u8; 96], + vmsa_tweak_bmp: [u64; 8], + svsm_base: u64, + svsm_size: u64, + svsm_caa: u64, + svsm_max_version: u32, + svsm_guest_vmpl: u8, + reserved_15d: [u8; 3], + tsc_factor: u32, + reserved_164: [u8; 3740], +} + +impl SecretsPage { + pub const fn new() -> Self { + Self { + version: 0, + gctxt: 0, + fms: 0, + reserved_00c: 0, + gosvw: [0; 16], + vmpck: [[0; VMPCK_SIZE]; VMPL_MAX], + reserved_0a0: [0; 96], + vmsa_tweak_bmp: [0; 8], + svsm_base: 0, + svsm_size: 0, + svsm_caa: 0, + svsm_max_version: 0, + svsm_guest_vmpl: 0, + reserved_15d: [0; 3], + tsc_factor: 0, + reserved_164: [0; 3740], + } + } + + /// Copy secrets page's content pointed by a [`VirtAddr`] + /// + /// # Safety + /// + /// The caller should verify that `source` points to mapped memory whose + /// size is at least the size of the [`SecretsPage`] structure. + pub unsafe fn copy_from(&mut self, source: VirtAddr) { + let from = source.as_ptr::(); + + unsafe { + *self = *from; + } + } + + /// Copy a secrets page's content to memory pointed by a [`VirtAddr`] + /// + /// # Safety + /// + /// The caller should verify that `target` points to mapped memory whose + /// size is at least the size of the [`SecretsPage`] structure. + /// + /// The caller should verify not to corrupt arbitrary memory, as this function + /// doesn't make any checks in that regard. 
+ pub unsafe fn copy_to(&self, target: VirtAddr) { + let to = target.as_mut_ptr::(); + + unsafe { + *to = *self; + } + } + + pub fn copy_for_vmpl(&self, vmpl: usize) -> Box { + let mut sp = Box::new(*self); + for idx in 0..vmpl { + sp.clear_vmpck(idx); + } + + sp + } + + pub fn set_svsm_data(&mut self, base: u64, size: u64, caa_addr: u64) { + self.svsm_base = base; + self.svsm_size = size; + self.svsm_caa = caa_addr; + self.svsm_max_version = 1; + self.svsm_guest_vmpl = GUEST_VMPL as u8; + } + + pub fn get_vmpck(&self, idx: usize) -> [u8; VMPCK_SIZE] { + self.vmpck[idx] + } + + pub fn is_vmpck_clear(&self, idx: usize) -> bool { + self.vmpck[idx].iter().all(|e| *e == 0) + } + + pub fn clear_vmpck(&mut self, idx: usize) { + self.vmpck[idx].iter_mut().for_each(|e| *e = 0); + } +} + +impl Default for SecretsPage { + fn default() -> Self { + Self::new() + } +} + +static SECRETS_PAGE: RWLock = RWLock::new(SecretsPage::new()); + +pub fn secrets_page() -> ReadLockGuard<'static, SecretsPage> { + SECRETS_PAGE.lock_read() +} + +pub fn secrets_page_mut() -> WriteLockGuard<'static, SecretsPage> { + SECRETS_PAGE.lock_write() +} diff --git a/stage2/src/sev/status.rs b/stage2/src/sev/status.rs new file mode 100644 index 000000000..06a0503b1 --- /dev/null +++ b/stage2/src/sev/status.rs @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::cpu::msr::{read_msr, SEV_STATUS}; +use crate::utils::immut_after_init::ImmutAfterInitCell; +use bitflags::bitflags; +use core::fmt::{self, Write}; + +bitflags! { + #[derive(Copy, Clone, PartialEq, Eq)] + pub struct SEVStatusFlags: u64 { + const SEV = 1 << 0; + const SEV_ES = 1 << 1; + const SEV_SNP = 1 << 2; + const VTOM = 1 << 3; + const REFLECT_VC = 1 << 4; + const REST_INJ = 1 << 5; + const ALT_INJ = 1 << 6; + const DBGSWP = 1 << 7; + const PREV_HOST_IBS = 1 << 8; + const BTB_ISOLATION = 1 << 9; + const VMPL_SSS = 1 << 10; + const SECURE_TSC = 1 << 11; + const VMSA_REG_PROT = 1 << 16; + const SMT_PROT = 1 << 17; + } +} + +impl fmt::Display for SEVStatusFlags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut first = true; + + if self.contains(SEVStatusFlags::SEV) { + f.write_str("SEV")?; + first = false; + } + + if self.contains(SEVStatusFlags::SEV_ES) { + if !first { + f.write_char(' ')?; + } + f.write_str("SEV-ES")?; + first = false; + } + + if self.contains(SEVStatusFlags::SEV_SNP) { + if !first { + f.write_char(' ')?; + } + f.write_str("SEV-SNP")?; + first = false; + } + + if self.contains(SEVStatusFlags::VTOM) { + if !first { + f.write_char(' ')?; + } + f.write_str("VTOM")?; + first = false; + } + + if self.contains(SEVStatusFlags::REFLECT_VC) { + if !first { + f.write_char(' ')?; + } + f.write_str("REFLECT_VC")?; + first = false; + } + + if self.contains(SEVStatusFlags::REST_INJ) { + if !first { + f.write_char(' ')?; + } + f.write_str("RESTRICTED_INJECTION")?; + first = false; + } + + if self.contains(SEVStatusFlags::ALT_INJ) { + if !first { + f.write_char(' ')?; + } + f.write_str("ALTERNATE_INJECTION")?; + first = false; + } + + if self.contains(SEVStatusFlags::DBGSWP) { + if !first { + f.write_char(' ')?; + } + f.write_str("DEBUG_SWAP")?; + first = false; + } + + if self.contains(SEVStatusFlags::PREV_HOST_IBS) { + if !first { + f.write_char(' ')?; + } + f.write_str("PREVENT_HOST_IBS")?; + first = false; + } + + if self.contains(SEVStatusFlags::BTB_ISOLATION) { + if !first { + f.write_char(' ')?; + } + f.write_str("SNP_BTB_ISOLATION")?; + first = 
false; + } + + if self.contains(SEVStatusFlags::SECURE_TSC) { + if !first { + f.write_char(' ')?; + } + f.write_str("SECURE_TSC")?; + first = false; + } + + if self.contains(SEVStatusFlags::VMSA_REG_PROT) { + if !first { + f.write_char(' ')?; + } + f.write_str("VMSA_REG_PROT")?; + } + + Ok(()) + } +} + +static SEV_FLAGS: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +fn read_sev_status() -> SEVStatusFlags { + SEVStatusFlags::from_bits_truncate(read_msr(SEV_STATUS)) +} + +pub fn sev_flags() -> SEVStatusFlags { + *SEV_FLAGS +} + +pub fn sev_status_init() { + let status: SEVStatusFlags = read_sev_status(); + SEV_FLAGS + .init(&status) + .expect("Already initialized SEV flags"); +} + +pub fn vtom_enabled() -> bool { + sev_flags().contains(SEVStatusFlags::VTOM) +} + +pub fn sev_restricted_injection() -> bool { + sev_flags().contains(SEVStatusFlags::REST_INJ) +} + +pub fn sev_status_verify() { + let required = SEVStatusFlags::SEV | SEVStatusFlags::SEV_ES | SEVStatusFlags::SEV_SNP; + let supported = SEVStatusFlags::DBGSWP + | SEVStatusFlags::VTOM + | SEVStatusFlags::REST_INJ + | SEVStatusFlags::PREV_HOST_IBS + | SEVStatusFlags::BTB_ISOLATION + | SEVStatusFlags::SMT_PROT; + + let status = sev_flags(); + let required_check = status & required; + let not_supported_check = status & !(supported | required); + + if required_check != required { + log::error!( + "Required features not available: {}", + required & !required_check + ); + panic!("Required SEV features not available"); + } + + if !not_supported_check.is_empty() { + log::error!("Unsupported features enabled: {}", not_supported_check); + panic!("Unsupported SEV features enabled"); + } +} + +impl SEVStatusFlags { + pub fn from_sev_features(sev_features: u64) -> Self { + SEVStatusFlags::from_bits(sev_features << 2).unwrap() + } + + pub fn as_sev_features(&self) -> u64 { + let sev_features = self.bits(); + sev_features >> 2 + } +} diff --git a/stage2/src/sev/utils.rs b/stage2/src/sev/utils.rs new file mode 100644 index 000000000..8dbeb875f --- /dev/null +++ b/stage2/src/sev/utils.rs @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, VirtAddr}; +use crate::error::SvsmError; +use crate::types::{PageSize, GUEST_VMPL, PAGE_SIZE, PAGE_SIZE_2M}; +use crate::utils::MemoryRegion; +use core::arch::asm; +use core::fmt; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[expect(non_camel_case_types)] +pub enum SevSnpError { + FAIL_INPUT(u64), + FAIL_PERMISSION(u64), + FAIL_SIZEMISMATCH(u64), + // Not a real error value, but we want to keep track of this, + // especially for protocol-specific messaging + FAIL_UNCHANGED(u64), +} + +impl From for SvsmError { + fn from(e: SevSnpError) -> Self { + Self::SevSnp(e) + } +} + +impl SevSnpError { + // This should get optimized away by the compiler to a single instruction + pub fn ret(&self) -> u64 { + match self { + Self::FAIL_INPUT(ret) + | Self::FAIL_UNCHANGED(ret) + | Self::FAIL_PERMISSION(ret) + | Self::FAIL_SIZEMISMATCH(ret) => *ret, + } + } +} + +impl fmt::Display for SevSnpError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::FAIL_INPUT(_) => write!(f, "FAIL_INPUT"), + Self::FAIL_UNCHANGED(_) => write!(f, "FAIL_UNCHANGED"), + Self::FAIL_PERMISSION(_) => write!(f, "FAIL_PERMISSION"), + Self::FAIL_SIZEMISMATCH(_) => write!(f, "FAIL_SIZEMISMATCH"), + } + } +} + +fn pvalidate_range_4k(region: MemoryRegion, valid: PvalidateOp) -> Result<(), SvsmError> { + 
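+    // Fallback path used by `pvalidate_range()` below when a 2M PVALIDATE
+    // fails with FAIL_SIZEMISMATCH: every 4K page in the region is processed
+    // individually.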
for addr in region.iter_pages(PageSize::Regular) { + pvalidate(addr, PageSize::Regular, valid)?; + } + + Ok(()) +} + +pub fn pvalidate_range( + region: MemoryRegion, + valid: PvalidateOp, +) -> Result<(), SvsmError> { + let mut addr = region.start(); + let end = region.end(); + + while addr < end { + if addr.is_aligned(PAGE_SIZE_2M) && addr + PAGE_SIZE_2M <= end { + // Try to validate as a huge page. + // If we fail, try to fall back to regular-sized pages. + pvalidate(addr, PageSize::Huge, valid).or_else(|err| match err { + SvsmError::SevSnp(SevSnpError::FAIL_SIZEMISMATCH(_)) => { + pvalidate_range_4k(MemoryRegion::new(addr, PAGE_SIZE_2M), valid) + } + _ => Err(err), + })?; + addr = addr + PAGE_SIZE_2M; + } else { + pvalidate(addr, PageSize::Regular, valid)?; + addr = addr + PAGE_SIZE; + } + } + + Ok(()) +} + +/// The desired state of the page passed to PVALIDATE. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u64)] +pub enum PvalidateOp { + Invalid = 0, + Valid = 1, +} + +pub fn pvalidate(vaddr: VirtAddr, size: PageSize, valid: PvalidateOp) -> Result<(), SvsmError> { + let rax = vaddr.bits(); + let rcx: u64 = match size { + PageSize::Regular => 0, + PageSize::Huge => 1, + }; + let rdx = valid as u64; + let ret: u64; + let cf: u64; + + unsafe { + asm!("xorq %r8, %r8", + "pvalidate", + "jnc 1f", + "incq %r8", + "1:", + in("rax") rax, + in("rcx") rcx, + in("rdx") rdx, + lateout("rax") ret, + lateout("r8") cf, + options(att_syntax)); + } + + let changed = cf == 0; + + match ret { + 0 if changed => Ok(()), + 0 if !changed => Err(SevSnpError::FAIL_UNCHANGED(0x10).into()), + 1 => Err(SevSnpError::FAIL_INPUT(ret).into()), + 6 => Err(SevSnpError::FAIL_SIZEMISMATCH(ret).into()), + _ => { + log::error!("PVALIDATE: unexpected return value: {}", ret); + unreachable!(); + } + } +} + +/// Executes the vmmcall instruction. +/// # Safety +/// See cpu vendor documentation for what this can do. +pub unsafe fn raw_vmmcall(eax: u32, ebx: u32, ecx: u32, edx: u32) -> i32 { + let new_eax; + asm!( + // bx register is reserved by llvm so it can't be passed in directly and must be + // restored + "xchg %rbx, {0:r}", + "vmmcall", + "xchg %rbx, {0:r}", + in(reg) ebx as u64, + inout("eax") eax => new_eax, + in("ecx") ecx, + in("edx") edx, + options(att_syntax)); + new_eax +} + +/// Sets the dr7 register to the given value +/// # Safety +/// See cpu vendor documentation for what this can do. +pub unsafe fn set_dr7(new_val: u64) { + asm!("mov {0}, %dr7", in(reg) new_val, options(att_syntax)); +} + +pub fn get_dr7() -> u64 { + let out; + unsafe { asm!("mov %dr7, {0}", out(reg) out, options(att_syntax)) }; + out +} + +/// # Safety +/// VMGEXIT operations generally need to be performed with interrupts disabled +/// to ensure that an interrupt cannot cause the GHCB MSR to change prior to +/// exiting to the host. It is the caller's responsibility to ensure that +/// interrupt handling is configured correctly for the attemtped operation. +pub unsafe fn raw_vmgexit() { + asm!("rep; vmmcall", options(att_syntax)); +} + +bitflags::bitflags! 
{ + pub struct RMPFlags: u64 { + const VMPL0 = 0; + const VMPL1 = 1; + const VMPL2 = 2; + const VMPL3 = 3; + const GUEST_VMPL = GUEST_VMPL as u64; + const READ = 1u64 << 8; + const WRITE = 1u64 << 9; + const X_USER = 1u64 << 10; + const X_SUPER = 1u64 << 11; + const BIT_VMSA = 1u64 << 16; + const NONE = 0; + const RWX = Self::READ.bits() | Self::WRITE.bits() | Self::X_USER.bits() | Self::X_SUPER.bits(); + const VMSA = Self::READ.bits() | Self::BIT_VMSA.bits(); + } +} + +pub fn rmp_adjust(addr: VirtAddr, flags: RMPFlags, size: PageSize) -> Result<(), SvsmError> { + let rcx: u64 = match size { + PageSize::Regular => 0, + PageSize::Huge => 1, + }; + let rax: u64 = addr.bits() as u64; + let rdx: u64 = flags.bits(); + let mut ret: u64; + let mut ex: u64; + + unsafe { + asm!("1: rmpadjust + xorq %rcx, %rcx + 2: + .pushsection \"__exception_table\",\"a\" + .balign 16 + .quad (1b) + .quad (2b) + .popsection", + inout("rax") rax => ret, + inout("rcx") rcx => ex, + in("rdx") rdx, + options(att_syntax)); + } + + if ex != 0 { + // Report exceptions just as FAIL_INPUT + return Err(SevSnpError::FAIL_INPUT(1).into()); + } + + match ret { + 0 => Ok(()), + 1 => Err(SevSnpError::FAIL_INPUT(ret).into()), + 2 => Err(SevSnpError::FAIL_PERMISSION(ret).into()), + 6 => Err(SevSnpError::FAIL_SIZEMISMATCH(ret).into()), + _ => { + log::error!("RMPADJUST: Unexpected return value: {:#x}", ret); + unreachable!(); + } + } +} + +pub fn rmp_revoke_guest_access(vaddr: VirtAddr, size: PageSize) -> Result<(), SvsmError> { + for vmpl in RMPFlags::GUEST_VMPL.bits()..=RMPFlags::VMPL3.bits() { + let vmpl = RMPFlags::from_bits_truncate(vmpl); + rmp_adjust(vaddr, vmpl | RMPFlags::NONE, size)?; + } + Ok(()) +} + +pub fn rmp_grant_guest_access(vaddr: VirtAddr, size: PageSize) -> Result<(), SvsmError> { + rmp_adjust(vaddr, RMPFlags::GUEST_VMPL | RMPFlags::RWX, size) +} + +pub fn rmp_set_guest_vmsa(vaddr: VirtAddr) -> Result<(), SvsmError> { + rmp_revoke_guest_access(vaddr, PageSize::Regular)?; + rmp_adjust( + vaddr, + RMPFlags::GUEST_VMPL | RMPFlags::VMSA, + PageSize::Regular, + ) +} + +pub fn rmp_clear_guest_vmsa(vaddr: VirtAddr) -> Result<(), SvsmError> { + rmp_revoke_guest_access(vaddr, PageSize::Regular)?; + rmp_grant_guest_access(vaddr, PageSize::Regular) +} diff --git a/stage2/src/sev/vmsa.rs b/stage2/src/sev/vmsa.rs new file mode 100644 index 000000000..f830a8623 --- /dev/null +++ b/stage2/src/sev/vmsa.rs @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use super::utils::{rmp_adjust, RMPFlags}; +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::error::SvsmError; +use crate::mm::{virt_to_phys, PageBox}; +use crate::platform::guest_cpu::GuestCpuState; +use crate::sev::status::SEVStatusFlags; +use crate::types::{PageSize, PAGE_SIZE_2M}; +use core::mem::{size_of, ManuallyDrop}; +use core::ops::{Deref, DerefMut}; +use core::ptr; + +use cpuarch::vmsa::{VmsaEventInject, VmsaEventType, VMSA}; + +pub const VMPL_MAX: usize = 4; + +/// An allocated page containing a VMSA structure. +#[derive(Debug)] +pub struct VmsaPage { + page: PageBox<[VMSA; 2]>, + idx: usize, +} + +impl VmsaPage { + /// Allocates a new VMSA for the given VPML. + pub fn new(vmpl: RMPFlags) -> Result { + assert!(vmpl.bits() < (VMPL_MAX as u64)); + + let page = PageBox::<[VMSA; 2]>::try_new_zeroed()?; + // Make sure the VMSA page is not 2M-aligned, as some hardware + // generations can't handle this properly. 
To ensure this property, we + // allocate 2 VMSAs and choose whichever is not 2M-aligned. + let idx = if page.vaddr().is_aligned(PAGE_SIZE_2M) { + 1 + } else { + 0 + }; + + let vaddr = page.vaddr() + idx * size_of::(); + rmp_adjust(vaddr, RMPFlags::VMSA | vmpl, PageSize::Regular)?; + Ok(Self { page, idx }) + } + + /// Returns the virtual address fro this VMSA. + #[inline] + fn vaddr(&self) -> VirtAddr { + let ptr: *const VMSA = ptr::from_ref(&self.page[self.idx]); + VirtAddr::from(ptr) + } + + /// Returns the physical address for this VMSA. + #[inline] + pub fn paddr(&self) -> PhysAddr { + virt_to_phys(self.vaddr()) + } + + /// Leaks the allocation for this VMSA, ensuring it never gets freed. + pub fn leak(self) -> &'static mut VMSA { + let mut vmsa = ManuallyDrop::new(self); + // SAFETY: `self.idx` is either 0 or 1, so this will never overflow + let ptr = unsafe { ptr::from_mut(&mut vmsa).add(vmsa.idx) }; + // SAFETY: this pointer will never be freed because of ManuallyDrop, + // so we can create a static mutable reference. We go through a raw + // pointer to promote the lifetime to static. + unsafe { &mut *ptr } + } +} + +impl Drop for VmsaPage { + fn drop(&mut self) { + rmp_adjust( + self.vaddr(), + RMPFlags::RWX | RMPFlags::VMPL0, + PageSize::Regular, + ) + .expect("Failed to RMPADJUST VMSA page"); + } +} + +impl Deref for VmsaPage { + type Target = VMSA; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.page[self.idx] + } +} + +impl DerefMut for VmsaPage { + #[inline] + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.page[self.idx] + } +} + +pub trait VMSAControl { + fn enable(&mut self); + fn disable(&mut self); +} + +impl VMSAControl for VMSA { + fn enable(&mut self) { + self.efer |= 1u64 << 12; + } + + fn disable(&mut self) { + self.efer &= !(1u64 << 12); + } +} + +impl GuestCpuState for VMSA { + fn get_tpr(&self) -> u8 { + let vintr_ctrl = self.vintr_ctrl; + + // The VMSA holds a 4-bit TPR but this routine must return an 8-bit + // TPR to maintain consistency with PPR. + vintr_ctrl.v_tpr() << 4 + } + + fn set_tpr(&mut self, tpr: u8) { + let mut vintr_ctrl = self.vintr_ctrl; + vintr_ctrl.set_v_tpr(tpr >> 4) + } + + fn request_nmi(&mut self) { + self.event_inj = VmsaEventInject::new() + .with_valid(true) + .with_event_type(VmsaEventType::NMI); + } + + fn queue_interrupt(&mut self, irq: u8) { + // Schedule the interrupt vector for delivery as a virtual interrupt. + let mut vintr_ctrl = self.vintr_ctrl; + vintr_ctrl.set_v_intr_vector(irq); + vintr_ctrl.set_v_intr_prio(irq >> 4); + vintr_ctrl.set_v_ign_tpr(false); + vintr_ctrl.set_v_irq(true); + self.vintr_ctrl = vintr_ctrl; + } + + fn try_deliver_interrupt_immediately(&mut self, irq: u8) -> bool { + // Attempt to inject the interrupt immediately using event injection. + // If the event injection field already contains a pending event, then + // injection is not possible. + let event_inj = self.event_inj; + if event_inj.valid() { + false + } else { + self.event_inj = VmsaEventInject::new() + .with_vector(irq) + .with_valid(true) + .with_event_type(VmsaEventType::Interrupt); + true + } + } + + fn in_intr_shadow(&self) -> bool { + let vintr_ctrl = self.vintr_ctrl; + vintr_ctrl.int_shadow() + } + + fn interrupts_enabled(&self) -> bool { + (self.rflags & 0x200) != 0 + } + + fn check_and_clear_pending_nmi(&mut self) -> bool { + // Check to see whether the current event injection is for an + // NMI. If so, clear the pending event.. 
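+        // Clearing is done by overwriting event_inj with a default
+        // VmsaEventInject, whose valid bit is not set.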
+ let event_inj = self.event_inj; + if event_inj.valid() && event_inj.event_type() == VmsaEventType::NMI { + self.event_inj = VmsaEventInject::new(); + true + } else { + false + } + } + + fn check_and_clear_pending_interrupt_event(&mut self) -> u8 { + // Check to see whether the current event injection is for an + // interrupt. If so, clear the pending event.. + let event_inj = self.event_inj; + if event_inj.valid() && event_inj.event_type() == VmsaEventType::Interrupt { + self.event_inj = VmsaEventInject::new(); + event_inj.vector() + } else { + 0 + } + } + + fn check_and_clear_pending_virtual_interrupt(&mut self) -> u8 { + // Check to see whether a virtual interrupt is queued for delivery. + // If so, clear the virtual interrupt request. + let mut vintr_ctrl = self.vintr_ctrl; + if vintr_ctrl.v_irq() { + vintr_ctrl.set_v_irq(false); + self.vintr_ctrl = vintr_ctrl; + vintr_ctrl.v_intr_vector() + } else { + 0 + } + } + fn disable_alternate_injection(&mut self) { + let mut sev_status = SEVStatusFlags::from_sev_features(self.sev_features); + sev_status.remove(SEVStatusFlags::ALT_INJ); + self.sev_features = sev_status.as_sev_features(); + } +} diff --git a/stage2/src/stage2.lds b/stage2/src/stage2.lds new file mode 100644 index 000000000..c0ad0c09f --- /dev/null +++ b/stage2/src/stage2.lds @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: MIT OR Apache-2.0 */ + +/* + * Copyright (c) 2022-2023 SUSE LLC + * + * Author: Joerg Roedel + */ + +OUTPUT_ARCH(i386:x86-64) + +SECTIONS +{ + /* Base address is 8 MB + 32 KB */ + . = 8m + 32k; + .stext = .; + .text : { + *(.startup.*) + *(.text) + *(.text.*) + . = ALIGN(16); + exception_table_start = .; + KEEP(*(__exception_table)) + exception_table_end = .; + } + . = ALIGN(16); + .data : { *(.data) } + edata = .; + . = ALIGN(16); + .bss : { + _bss = .; + *(.bss) *(.bss.[0-9a-zA-Z_]*) + . = ALIGN(16); + _ebss = .; + } + /* Move rodata to follow bss so that the in-memory image has the same + * length as the ELF image. This is required so that the IGVM + * builder does not have to parse the ELF file to know how much space + * to reserve for BSS. */ + . 
= ALIGN(16); + .rodata : { *(.rodata) } +} + +ENTRY(startup_32) diff --git a/stage2/src/stage2.rs b/stage2/src/stage2.rs new file mode 100755 index 000000000..988918ceb --- /dev/null +++ b/stage2/src/stage2.rs @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +#![no_std] +#![no_main] + +pub mod boot_stage2; + +use bootlib::kernel_launch::{KernelLaunchInfo, Stage2LaunchInfo}; +use bootlib::platform::SvsmPlatformType; +use core::arch::asm; +use core::panic::PanicInfo; +use core::ptr::addr_of_mut; +use core::slice; +use cpuarch::snp_cpuid::SnpCpuidTable; +use elf::ElfError; +use svsm::address::{Address, PhysAddr, VirtAddr}; +use svsm::config::SvsmConfig; +use svsm::console::install_console_logger; +use svsm::cpu::cpuid::{dump_cpuid_table, register_cpuid_table}; +use svsm::cpu::gdt; +use svsm::cpu::idt::stage2::{early_idt_init, early_idt_init_no_ghcb}; +use svsm::cpu::percpu::{this_cpu, PerCpu}; +use svsm::error::SvsmError; +use svsm::fw_cfg::FwCfg; +use svsm::igvm_params::IgvmParams; +use svsm::mm::alloc::{memory_info, print_memory_info, root_mem_init}; +use svsm::mm::pagetable::{paging_init_early, PTEntryFlags, PageTable}; +use svsm::mm::validate::{ + init_valid_bitmap_alloc, valid_bitmap_addr, valid_bitmap_set_valid_range, +}; +use svsm::mm::{init_kernel_mapping_info, FixedAddressMappingRange, SVSM_PERCPU_BASE}; +use svsm::platform; +use svsm::platform::{ + init_platform_type, PageStateChangeOp, PageValidateOp, SvsmPlatform, SvsmPlatformCell, +}; +use svsm::types::{PageSize, PAGE_SIZE, PAGE_SIZE_2M}; +use svsm::utils::{is_aligned, MemoryRegion}; + +extern "C" { + static mut pgtable: PageTable; +} + +fn setup_stage2_allocator(heap_start: u64, heap_end: u64) { + let vstart = VirtAddr::from(heap_start); + let vend = VirtAddr::from(heap_end); + let pstart = PhysAddr::from(vstart.bits()); // Identity mapping + let nr_pages = (vend - vstart) / PAGE_SIZE; + + root_mem_init(pstart, vstart, nr_pages); +} + +fn init_percpu(platform: &mut dyn SvsmPlatform) -> Result<(), SvsmError> { + let bsp_percpu = PerCpu::alloc(0)?; + let init_pgtable = unsafe { + // SAFETY: pgtable is a static mut and this is the only place where we + // get a reference to it. + &mut *addr_of_mut!(pgtable) + }; + bsp_percpu.set_pgtable(init_pgtable); + bsp_percpu.map_self_stage2()?; + platform.setup_guest_host_comm(bsp_percpu, true); + Ok(()) +} + +/// Release all resources in the `PerCpu` instance associated with the current +/// CPU. +/// +/// # Safety +/// +/// The caller must ensure that the `PerCpu` is never used again. +unsafe fn shutdown_percpu() { + let ptr = SVSM_PERCPU_BASE.as_mut_ptr::(); + core::ptr::drop_in_place(ptr); +} + +fn setup_env( + config: &SvsmConfig<'_>, + platform: &mut dyn SvsmPlatform, + launch_info: &Stage2LaunchInfo, +) { + gdt().load(); + early_idt_init_no_ghcb(); + + let debug_serial_port = config.debug_serial_port(); + install_console_logger("Stage2").expect("Console logger already initialized"); + platform + .env_setup(debug_serial_port, launch_info.vtom.try_into().unwrap()) + .expect("Early environment setup failed"); + + let kernel_mapping = FixedAddressMappingRange::new( + VirtAddr::from(0x808000u64), + VirtAddr::from(launch_info.stage2_end as u64), + PhysAddr::from(0x808000u64), + ); + + // Use the low 640 KB of memory as the heap. 
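+    // The whole 0x00000..0xA0000 range is identity mapped and validated
+    // below, but only 0x10000..0xA0000 is later handed to the allocator
+    // (see setup_stage2_allocator()); the stage2 image itself lives at
+    // 0x808000, covered by the kernel mapping set up above.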
+ let lowmem_region = MemoryRegion::new(VirtAddr::from(0u64), 640 * 1024); + let heap_mapping = FixedAddressMappingRange::new( + lowmem_region.start(), + lowmem_region.end(), + PhysAddr::from(0u64), + ); + init_kernel_mapping_info(kernel_mapping, Some(heap_mapping)); + + // Now that the heap virtual-to-physical mapping has been established, + // validate the first 640 KB of memory so it can be used if necessary. + platform + .validate_virtual_page_range(lowmem_region, PageValidateOp::Validate) + .expect("failed to validate low 640 KB"); + + let cpuid_page = unsafe { &*(launch_info.cpuid_page as *const SnpCpuidTable) }; + + register_cpuid_table(cpuid_page); + paging_init_early(platform).expect("Failed to initialize early paging"); + + // Configure the heap to exist from 64 KB to 640 KB. + setup_stage2_allocator(0x10000, 0xA0000); + + init_percpu(platform).expect("Failed to initialize per-cpu area"); + + // Init IDT again with handlers requiring GHCB (eg. #VC handler) + early_idt_init(); + + // Complete initializtion of the platform. After that point, the console + // will be fully working and any unsupported configuration can be properly + // reported. + platform + .env_setup_late(debug_serial_port) + .expect("Late environment setup failed"); + + dump_cpuid_table(); +} + +/// Map and validate the specified virtual memory region at the given physical +/// address. +fn map_and_validate( + platform: &dyn SvsmPlatform, + config: &SvsmConfig<'_>, + vregion: MemoryRegion, + paddr: PhysAddr, +) -> Result<(), SvsmError> { + let flags = PTEntryFlags::PRESENT + | PTEntryFlags::WRITABLE + | PTEntryFlags::ACCESSED + | PTEntryFlags::DIRTY; + + let mut pgtbl = this_cpu().get_pgtable(); + pgtbl.map_region(vregion, paddr, flags)?; + + if config.page_state_change_required() { + platform.page_state_change( + MemoryRegion::new(paddr, vregion.len()), + PageSize::Huge, + PageStateChangeOp::Private, + )?; + } + platform.validate_virtual_page_range(vregion, PageValidateOp::Validate)?; + valid_bitmap_set_valid_range(paddr, paddr + vregion.len()); + Ok(()) +} + +#[inline] +fn check_launch_info(launch_info: &KernelLaunchInfo) { + let offset: u64 = launch_info.heap_area_virt_start - launch_info.heap_area_phys_start; + let align: u64 = PAGE_SIZE_2M.try_into().unwrap(); + + assert!(is_aligned(offset, align)); +} + +fn get_svsm_config( + launch_info: &Stage2LaunchInfo, + platform: &dyn SvsmPlatform, +) -> Result, SvsmError> { + if launch_info.igvm_params == 0 { + return Ok(SvsmConfig::FirmwareConfig(FwCfg::new( + platform.get_io_port(), + ))); + } + + IgvmParams::new(VirtAddr::from(launch_info.igvm_params as u64)).map(SvsmConfig::IgvmConfig) +} + +/// Loads a single ELF segment and returns its virtual memory region. +fn load_elf_segment( + segment: elf::Elf64ImageLoadSegment<'_>, + paddr: PhysAddr, + platform: &dyn SvsmPlatform, + config: &SvsmConfig<'_>, +) -> Result, SvsmError> { + // Find the segment's bounds + let segment_start = VirtAddr::from(segment.vaddr_range.vaddr_begin); + let segment_end = VirtAddr::from(segment.vaddr_range.vaddr_end).page_align_up(); + let segment_len = segment_end - segment_start; + let segment_region = MemoryRegion::new(segment_start, segment_len); + + // All ELF segments should be aligned to the page size. If not, there's + // the risk of pvalidating a page twice, bail out if so. Note that the + // ELF reading code had already verified that the individual segments, + // with bounds specified as in the ELF file, are non-overlapping. 
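+    // An unaligned segment start would share its first page with the
+    // previous (page-aligned-up) segment, and map_and_validate() below
+    // would then validate that page a second time.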
+ if !segment_start.is_page_aligned() { + return Err(SvsmError::Elf(ElfError::UnalignedSegmentAddress)); + } + + // Map and validate the segment at the next contiguous physical address + map_and_validate(platform, config, segment_region, paddr)?; + + // Copy the segment contents and pad with zeroes + let segment_buf = + unsafe { slice::from_raw_parts_mut(segment_start.as_mut_ptr::(), segment_len) }; + let contents_len = segment.file_contents.len(); + segment_buf[..contents_len].copy_from_slice(segment.file_contents); + segment_buf[contents_len..].fill(0); + + Ok(segment_region) +} + +/// Loads the kernel ELF and returns the virtual memory region where it +/// resides, as well as its entry point. Updates the used physical memory +/// region accordingly. +fn load_kernel_elf( + launch_info: &Stage2LaunchInfo, + loaded_phys: &mut MemoryRegion, + platform: &dyn SvsmPlatform, + config: &SvsmConfig<'_>, +) -> Result<(VirtAddr, MemoryRegion), SvsmError> { + // Find the bounds of the kernel ELF and load it into the ELF parser + let elf_start = PhysAddr::from(launch_info.kernel_elf_start as u64); + let elf_end = PhysAddr::from(launch_info.kernel_elf_end as u64); + let elf_len = elf_end - elf_start; + let bytes = unsafe { slice::from_raw_parts(elf_start.bits() as *const u8, elf_len) }; + let elf = elf::Elf64File::read(bytes)?; + + let vaddr_alloc_info = elf.image_load_vaddr_alloc_info(); + let vaddr_alloc_base = vaddr_alloc_info.range.vaddr_begin; + + // Map, validate and populate the SVSM kernel ELF's PT_LOAD segments. The + // segments' virtual address range might not necessarily be contiguous, + // track their total extent along the way. Physical memory is successively + // being taken from the physical memory region, the remaining space will be + // available as heap space for the SVSM kernel. Remember the end of all + // physical memory occupied by the loaded ELF image. + let mut load_virt_start = None; + let mut load_virt_end = VirtAddr::null(); + for segment in elf.image_load_segment_iter(vaddr_alloc_base) { + let region = load_elf_segment(segment, loaded_phys.end(), platform, config)?; + // Remember the mapping range's lower and upper bounds to pass it on + // the kernel later. Note that the segments are being iterated over + // here in increasing load order. + if load_virt_start.is_none() { + load_virt_start = Some(region.start()); + } + load_virt_end = region.end(); + + // Update to the next contiguous physical address + *loaded_phys = loaded_phys.expand(region.len()); + } + + let Some(load_virt_start) = load_virt_start else { + log::error!("No loadable segment found in kernel ELF"); + return Err(SvsmError::Mem); + }; + + // Apply relocations, if any + if let Some(dyn_relocs) = + elf.apply_dyn_relas(elf::Elf64X86RelocProcessor::new(), vaddr_alloc_base)? + { + for reloc in dyn_relocs { + let Some(reloc) = reloc? else { + continue; + }; + let dst = unsafe { slice::from_raw_parts_mut(reloc.dst as *mut u8, reloc.value_len) }; + let src = &reloc.value[..reloc.value_len]; + dst.copy_from_slice(src) + } + } + + let entry = VirtAddr::from(elf.get_entry(vaddr_alloc_base)); + let region = MemoryRegion::from_addresses(load_virt_start, load_virt_end); + Ok((entry, region)) +} + +/// Loads the IGVM params at the next contiguous location from the loaded +/// kernel image. Returns the virtual and physical memory regions hosting the +/// loaded data. 
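+///
+/// The first element of the returned pair is the virtual region, the second
+/// the physical region backing it.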
+fn load_igvm_params( + launch_info: &Stage2LaunchInfo, + params: &IgvmParams<'_>, + loaded_kernel_vregion: &MemoryRegion, + loaded_kernel_pregion: &MemoryRegion, + platform: &dyn SvsmPlatform, + config: &SvsmConfig<'_>, +) -> Result<(MemoryRegion, MemoryRegion), SvsmError> { + // Map and validate destination region + let igvm_vregion = MemoryRegion::new(loaded_kernel_vregion.end(), params.size()); + let igvm_pregion = MemoryRegion::new(loaded_kernel_pregion.end(), params.size()); + map_and_validate(platform, config, igvm_vregion, igvm_pregion.start())?; + + // Copy the contents over + let src_addr = VirtAddr::from(launch_info.igvm_params as u64); + let igvm_src = unsafe { slice::from_raw_parts(src_addr.as_ptr::(), igvm_vregion.len()) }; + let igvm_dst = unsafe { + slice::from_raw_parts_mut(igvm_vregion.start().as_mut_ptr::(), igvm_vregion.len()) + }; + igvm_dst.copy_from_slice(igvm_src); + + Ok((igvm_vregion, igvm_pregion)) +} + +/// Maps any remaining memory between the end of the kernel image and the end +/// of the allocated kernel memory region as heap space. Exclude any memory +/// reserved by the configuration. +/// +/// # Panics +/// +/// Panics if the allocated kernel region (`kernel_region`) is not sufficient +/// to host the loaded kernel region (`loaded_kernel_pregion`) plus memory +/// reserved for configuration. +fn prepare_heap( + kernel_region: MemoryRegion, + loaded_kernel_pregion: MemoryRegion, + loaded_kernel_vregion: MemoryRegion, + platform: &dyn SvsmPlatform, + config: &SvsmConfig<'_>, +) -> Result<(MemoryRegion, MemoryRegion), SvsmError> { + // Heap starts after kernel + let heap_pstart = loaded_kernel_pregion.end(); + let heap_vstart = loaded_kernel_vregion.end(); + + // Compute size, excluding any memory reserved by the configuration. + let heap_size = kernel_region + .end() + .checked_sub(heap_pstart.into()) + .and_then(|r| r.checked_sub(config.reserved_kernel_area_size())) + .expect("Insufficient physical space for kernel image") + .into(); + let heap_pregion = MemoryRegion::new(heap_pstart, heap_size); + let heap_vregion = MemoryRegion::new(heap_vstart, heap_size); + + map_and_validate(platform, config, heap_vregion, heap_pregion.start())?; + + Ok((heap_vregion, heap_pregion)) +} + +#[no_mangle] +pub extern "C" fn stage2_main(launch_info: &Stage2LaunchInfo) { + let platform_type = SvsmPlatformType::from(launch_info.platform_type); + + init_platform_type(platform_type); + let mut platform = SvsmPlatformCell::new(platform_type); + + let config = + get_svsm_config(launch_info, &*platform).expect("Failed to get SVSM configuration"); + setup_env(&config, &mut *platform, launch_info); + + // Get the available physical memory region for the kernel + let kernel_region = config + .find_kernel_region() + .expect("Failed to find memory region for SVSM kernel"); + + init_valid_bitmap_alloc(kernel_region).expect("Failed to allocate valid-bitmap"); + + // The physical memory region we've loaded so far + let mut loaded_kernel_pregion = MemoryRegion::new(kernel_region.start(), 0); + + // Load first the kernel ELF and update the loaded physical region + let (kernel_entry, mut loaded_kernel_vregion) = + load_kernel_elf(launch_info, &mut loaded_kernel_pregion, &*platform, &config) + .expect("Failed to load kernel ELF"); + + // Load the IGVM params, if present. Update loaded region accordingly. 
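+    // Both the virtual and the physical region grow by the same amount (the
+    // size of the IGVM parameter block), so the two layouts stay in lockstep.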
+ let (igvm_vregion, igvm_pregion) = if let SvsmConfig::IgvmConfig(ref igvm_params) = config { + let (igvm_vregion, igvm_pregion) = load_igvm_params( + launch_info, + igvm_params, + &loaded_kernel_vregion, + &loaded_kernel_pregion, + &*platform, + &config, + ) + .expect("Failed to load IGVM params"); + + // Update the loaded kernel region + loaded_kernel_pregion = loaded_kernel_pregion.expand(igvm_vregion.len()); + loaded_kernel_vregion = loaded_kernel_vregion.expand(igvm_pregion.len()); + (igvm_vregion, igvm_pregion) + } else { + ( + MemoryRegion::new(VirtAddr::null(), 0), + MemoryRegion::new(PhysAddr::null(), 0), + ) + }; + + // Use remaining space after kernel image as heap space. + let (heap_vregion, heap_pregion) = prepare_heap( + kernel_region, + loaded_kernel_pregion, + loaded_kernel_vregion, + &*platform, + &config, + ) + .expect("Failed to map and validate heap"); + + // Build the handover information describing the memory layout and hand + // control to the SVSM kernel. + let launch_info = KernelLaunchInfo { + kernel_region_phys_start: u64::from(kernel_region.start()), + kernel_region_phys_end: u64::from(kernel_region.end()), + heap_area_phys_start: u64::from(heap_pregion.start()), + heap_area_virt_start: u64::from(heap_vregion.start()), + heap_area_size: heap_vregion.len() as u64, + kernel_region_virt_start: u64::from(loaded_kernel_vregion.start()), + kernel_elf_stage2_virt_start: u64::from(launch_info.kernel_elf_start), + kernel_elf_stage2_virt_end: u64::from(launch_info.kernel_elf_end), + kernel_fs_start: u64::from(launch_info.kernel_fs_start), + kernel_fs_end: u64::from(launch_info.kernel_fs_end), + stage2_start: 0x800000u64, + stage2_end: launch_info.stage2_end as u64, + cpuid_page: launch_info.cpuid_page as u64, + secrets_page: launch_info.secrets_page as u64, + stage2_igvm_params_phys_addr: u64::from(launch_info.igvm_params), + stage2_igvm_params_size: igvm_pregion.len() as u64, + igvm_params_phys_addr: u64::from(igvm_pregion.start()), + igvm_params_virt_addr: u64::from(igvm_vregion.start()), + vtom: launch_info.vtom, + debug_serial_port: config.debug_serial_port(), + use_alternate_injection: config.use_alternate_injection(), + platform_type, + }; + + check_launch_info(&launch_info); + + let mem_info = memory_info(); + print_memory_info(&mem_info); + + log::info!( + " kernel_region_phys_start = {:#018x}", + kernel_region.start() + ); + log::info!(" kernel_region_phys_end = {:#018x}", kernel_region.end()); + log::info!( + " kernel_virtual_base = {:#018x}", + loaded_kernel_vregion.start() + ); + + let valid_bitmap = valid_bitmap_addr(); + + log::info!("Starting SVSM kernel..."); + + // Shut down the GHCB + unsafe { + shutdown_percpu(); + } + + unsafe { + asm!("jmp *%rax", + in("rax") u64::from(kernel_entry), + in("r8") &launch_info, + in("r9") valid_bitmap.bits(), + options(att_syntax)) + }; + + panic!("Road ends here!"); +} + +#[panic_handler] +fn panic(info: &PanicInfo<'_>) -> ! 
{ + log::error!("Panic: {}", info); + loop { + platform::halt(); + } +} diff --git a/stage2/src/string.rs b/stage2/src/string.rs new file mode 100644 index 000000000..34ddc883b --- /dev/null +++ b/stage2/src/string.rs @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use core::fmt; + +#[derive(Copy, Clone, Debug)] +pub struct FixedString { + len: usize, + data: [char; T], +} + +impl FixedString { + pub const fn new() -> Self { + FixedString { + len: 0, + data: ['\0'; T], + } + } + + pub fn push(&mut self, c: char) { + let l = self.len; + + if l > 0 && self.data[l - 1] == '\0' { + return; + } + + self.data[l] = c; + self.len += 1; + } + + pub fn length(&self) -> usize { + self.len + } +} + +impl Default for FixedString { + fn default() -> Self { + Self::new() + } +} + +impl From<[u8; N]> for FixedString { + fn from(arr: [u8; N]) -> FixedString { + let data = arr.map(char::from); + let len = arr.iter().position(|&b| b == 0).unwrap_or(N); + FixedString { data, len } + } +} + +impl From<&str> for FixedString { + fn from(st: &str) -> FixedString { + let mut fs = FixedString::new(); + for c in st.chars().take(N) { + fs.data[fs.len] = c; + fs.len += 1; + } + fs + } +} + +impl PartialEq<&str> for FixedString { + fn eq(&self, other: &&str) -> bool { + for (i, c) in other.chars().enumerate() { + if i >= N { + return false; + } + if self.data[i] != c { + return false; + } + } + true + } +} + +impl PartialEq> for FixedString { + fn eq(&self, other: &FixedString) -> bool { + if self.len != other.len { + return false; + } + + self.data + .iter() + .zip(&other.data) + .take(self.len) + .all(|(a, b)| *a == *b) + } +} + +impl fmt::Display for FixedString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for b in self.data.iter().take(self.len) { + write!(f, "{}", *b)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + extern crate alloc; + use super::*; + use alloc::string::String; + use core::fmt::Write; + + #[test] + fn from_u8_array1() { + let st = FixedString::from([b'a', b'b', b'c', b'd', b'z']); + assert_eq!(st, "abcdz"); + assert_eq!(st.len, 5); + } + + #[test] + fn from_u8_array2() { + let st = FixedString::from([b'a', b'b', b'c', b'\0', b'd', b'e']); + assert_eq!(st, "abc"); + assert_eq!(st.len, 3); + } + + #[test] + fn display() { + let mut buf = String::new(); + let st = FixedString::from([b's', b'v', b's', b'm', b'\0', b'x', b'y']); + write!(&mut buf, "{}", st).unwrap(); + assert_eq!(buf.as_str(), "svsm"); + } +} diff --git a/stage2/src/svsm.lds b/stage2/src/svsm.lds new file mode 100644 index 000000000..c7cd6f465 --- /dev/null +++ b/stage2/src/svsm.lds @@ -0,0 +1,31 @@ +OUTPUT_ARCH(i386:x86-64) + +SECTIONS +{ + . = 0xffffff8000000000; + .text : { + *(.startup.*) + *(.text) + *(.text.*) + . = ALIGN(16); + entry_code_start = .; + *(.entry.text) + entry_code_end = .; + . = ALIGN(16); + exception_table_start = .; + KEEP(*(__exception_table)) + exception_table_end = .; + } + . = ALIGN(4096); + .rodata : { *(.rodata) *(.rodata.*) } + . = ALIGN(4096); + .data : { *(.data) *(.data.*) } + . = ALIGN(4096); + .bss : { + *(.bss) *(.bss.*) + . = ALIGN(4096); + } + . 
= ALIGN(4096); +} + +ENTRY(startup_64) diff --git a/stage2/src/svsm.rs b/stage2/src/svsm.rs new file mode 100755 index 000000000..9126ee690 --- /dev/null +++ b/stage2/src/svsm.rs @@ -0,0 +1,487 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +#![cfg_attr(not(test), no_std)] +#![cfg_attr(not(test), no_main)] + +use svsm::fw_meta::{print_fw_meta, validate_fw_memory, SevFWMetaData}; + +use bootlib::kernel_launch::KernelLaunchInfo; +use core::arch::global_asm; +use core::panic::PanicInfo; +use core::slice; +use cpuarch::snp_cpuid::SnpCpuidTable; +use svsm::address::{PhysAddr, VirtAddr}; +use svsm::config::SvsmConfig; +use svsm::console::install_console_logger; +use svsm::cpu::control_regs::{cr0_init, cr4_init}; +use svsm::cpu::cpuid::{dump_cpuid_table, register_cpuid_table}; +use svsm::cpu::gdt; +use svsm::cpu::idt::svsm::{early_idt_init, idt_init}; +use svsm::cpu::percpu::current_ghcb; +use svsm::cpu::percpu::PerCpu; +use svsm::cpu::percpu::{this_cpu, this_cpu_shared}; +use svsm::cpu::smp::start_secondary_cpus; +use svsm::cpu::sse::sse_init; +use svsm::debug::gdbstub::svsm_gdbstub::{debug_break, gdbstub_start}; +use svsm::debug::stacktrace::print_stack; +use svsm::error::SvsmError; +use svsm::fs::{initialize_fs, populate_ram_fs}; +use svsm::fw_cfg::FwCfg; +use svsm::igvm_params::IgvmParams; +use svsm::kernel_region::new_kernel_region; +use svsm::mm::alloc::{memory_info, print_memory_info, root_mem_init}; +use svsm::mm::memory::{init_memory_map, write_guest_memory_map}; +use svsm::mm::pagetable::paging_init; +use svsm::mm::virtualrange::virt_log_usage; +use svsm::mm::{init_kernel_mapping_info, FixedAddressMappingRange, PerCPUPageMappingGuard}; +use svsm::platform; +use svsm::platform::{init_platform_type, SvsmPlatformCell, SVSM_PLATFORM}; +use svsm::requests::{request_loop, request_processing_main, update_mappings}; +use svsm::sev::utils::{rmp_adjust, RMPFlags}; +use svsm::sev::{secrets_page, secrets_page_mut}; +use svsm::svsm_paging::{init_page_table, invalidate_early_boot_memory}; +use svsm::task::exec_user; +use svsm::task::{create_kernel_task, schedule_init}; +use svsm::types::{PageSize, GUEST_VMPL, PAGE_SIZE}; +use svsm::utils::{immut_after_init::ImmutAfterInitCell, zero_mem_region}; +#[cfg(all(feature = "mstpm", not(test)))] +use svsm::vtpm::vtpm_init; + +use svsm::mm::validate::{init_valid_bitmap_ptr, migrate_valid_bitmap}; + +extern "C" { + pub static bsp_stack_end: u8; +} + +/* + * Launch protocol: + * + * The stage2 loader will map and load the svsm binary image and jump to + * startup_64. 
+ * + * %r8 Pointer to the KernelLaunchInfo structure + * %r9 Pointer to the valid-bitmap from stage2 + */ +global_asm!( + r#" + .text + .section ".startup.text","ax" + .code64 + + .globl startup_64 + startup_64: + /* Setup stack */ + leaq bsp_stack_end(%rip), %rsp + + /* Jump to rust code */ + movq %r8, %rdi + movq %r9, %rsi + jmp svsm_start + + .bss + + .align {PAGE_SIZE} + bsp_stack: + .fill 8*{PAGE_SIZE}, 1, 0 + bsp_stack_end: + "#, + PAGE_SIZE = const PAGE_SIZE, + options(att_syntax) +); + +static CPUID_PAGE: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); +static LAUNCH_INFO: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); + +const _: () = assert!(size_of::() <= PAGE_SIZE); + +fn copy_cpuid_table_to_fw(fw_addr: PhysAddr) -> Result<(), SvsmError> { + let guard = PerCPUPageMappingGuard::create_4k(fw_addr)?; + let start = guard.virt_addr().as_mut_ptr::(); + + // SAFETY: this is called from CPU 0, so the underlying physical address + // is not being aliased. We are mapping a full page, which is 4k-aligned, + // and is enough for SnpCpuidTable. We also assert above at compile time + // that SnpCpuidTable fits within a page, so the write is safe. + unsafe { + // Zero target and copy data + start.write_bytes(0, PAGE_SIZE); + start + .cast::() + .copy_from_nonoverlapping(&*CPUID_PAGE, 1); + } + + Ok(()) +} + +fn copy_secrets_page_to_fw(fw_addr: PhysAddr, caa_addr: PhysAddr) -> Result<(), SvsmError> { + let guard = PerCPUPageMappingGuard::create_4k(fw_addr)?; + let start = guard.virt_addr(); + + // Zero target + zero_mem_region(start, start + PAGE_SIZE); + + // Copy secrets page + let mut fw_secrets_page = secrets_page().copy_for_vmpl(GUEST_VMPL); + + let &li = &*LAUNCH_INFO; + + fw_secrets_page.set_svsm_data( + li.kernel_region_phys_start, + li.kernel_region_phys_end - li.kernel_region_phys_start, + u64::from(caa_addr), + ); + + // SAFETY: start points to a new allocated and zeroed page. 
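+    // The copy produced by copy_for_vmpl(GUEST_VMPL) above has the VMPCKs of
+    // all more-privileged VMPLs cleared, so the SVSM's own keys are never
+    // exposed to the guest firmware through this page.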
+ unsafe { + fw_secrets_page.copy_to(start); + } + + Ok(()) +} + +fn zero_caa_page(fw_addr: PhysAddr) -> Result<(), SvsmError> { + let guard = PerCPUPageMappingGuard::create_4k(fw_addr)?; + let vaddr = guard.virt_addr(); + + zero_mem_region(vaddr, vaddr + PAGE_SIZE); + + Ok(()) +} + +fn copy_tables_to_fw(fw_meta: &SevFWMetaData) -> Result<(), SvsmError> { + if let Some(addr) = fw_meta.cpuid_page { + copy_cpuid_table_to_fw(addr)?; + } + + let secrets_page = fw_meta.secrets_page.ok_or(SvsmError::MissingSecrets)?; + let caa_page = fw_meta.caa_page.ok_or(SvsmError::MissingCAA)?; + + copy_secrets_page_to_fw(secrets_page, caa_page)?; + + zero_caa_page(caa_page)?; + + Ok(()) +} + +fn prepare_fw_launch(fw_meta: &SevFWMetaData) -> Result<(), SvsmError> { + if let Some(caa) = fw_meta.caa_page { + this_cpu_shared().update_guest_caa(caa); + } + + this_cpu().alloc_guest_vmsa()?; + update_mappings()?; + + Ok(()) +} + +fn launch_fw(config: &SvsmConfig<'_>) -> Result<(), SvsmError> { + let cpu = this_cpu(); + let mut vmsa_ref = cpu.guest_vmsa_ref(); + let vmsa_pa = vmsa_ref.vmsa_phys().unwrap(); + let vmsa = vmsa_ref.vmsa(); + + config.initialize_guest_vmsa(vmsa)?; + + log::info!("VMSA PA: {:#x}", vmsa_pa); + + let sev_features = vmsa.sev_features; + + log::info!("Launching Firmware"); + current_ghcb().register_guest_vmsa(vmsa_pa, 0, GUEST_VMPL as u64, sev_features)?; + + Ok(()) +} + +fn validate_fw(config: &SvsmConfig<'_>, launch_info: &KernelLaunchInfo) -> Result<(), SvsmError> { + let kernel_region = new_kernel_region(launch_info); + let flash_regions = config.get_fw_regions(&kernel_region); + + for (i, region) in flash_regions.into_iter().enumerate() { + log::info!( + "Flash region {} at {:#018x} size {:018x}", + i, + region.start(), + region.len(), + ); + + for paddr in region.iter_pages(PageSize::Regular) { + let guard = PerCPUPageMappingGuard::create_4k(paddr)?; + let vaddr = guard.virt_addr(); + if let Err(e) = rmp_adjust( + vaddr, + RMPFlags::GUEST_VMPL | RMPFlags::RWX, + PageSize::Regular, + ) { + log::info!("rmpadjust failed for addr {:#018x}", vaddr); + return Err(e); + } + } + } + + Ok(()) +} + +pub fn memory_init(launch_info: &KernelLaunchInfo) { + root_mem_init( + PhysAddr::from(launch_info.heap_area_phys_start), + VirtAddr::from(launch_info.heap_area_virt_start), + launch_info.heap_area_size as usize / PAGE_SIZE, + ); +} + +pub fn boot_stack_info() { + // SAFETY: this is only unsafe because `bsp_stack_end` is an extern + // static, but we're simply printing its address. We are not creating a + // reference so this is safe. + let vaddr = VirtAddr::from(&raw const bsp_stack_end); + log::info!("Boot stack starts @ {:#018x}", vaddr); +} + +fn mapping_info_init(launch_info: &KernelLaunchInfo) { + let kernel_mapping = FixedAddressMappingRange::new( + VirtAddr::from(launch_info.heap_area_virt_start), + VirtAddr::from(launch_info.heap_area_virt_end()), + PhysAddr::from(launch_info.heap_area_phys_start), + ); + init_kernel_mapping_info(kernel_mapping, None); +} + +/// # Panics +/// +/// Panics if the provided address is not aligned to a [`SnpCpuidTable`]. +fn init_cpuid_table(addr: VirtAddr) { + // SAFETY: this is called from the main function for the SVSM and no other + // CPUs have been brought up, so the pointer cannot be aliased. + // `aligned_mut()` will check alignment for us. 
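+    // The loop below patches the table so that CPUID Fn8000_001F EAX reports
+    // bit 28 set, which (per the SVSM specification) is how the guest
+    // discovers that an SVSM is present.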
+ let table = unsafe { + addr.aligned_mut::() + .expect("Misaligned SNP CPUID table address") + }; + + for func in table.func.iter_mut().take(table.count as usize) { + if func.eax_in == 0x8000001f { + func.eax_out |= 1 << 28; + } + } + + CPUID_PAGE + .init(table) + .expect("Already initialized CPUID page"); + register_cpuid_table(&CPUID_PAGE); +} + +#[no_mangle] +pub extern "C" fn svsm_start(li: &KernelLaunchInfo, vb_addr: usize) { + let launch_info: KernelLaunchInfo = *li; + init_platform_type(launch_info.platform_type); + + let vb_ptr = core::ptr::NonNull::new(VirtAddr::new(vb_addr).as_mut_ptr::()).unwrap(); + + mapping_info_init(&launch_info); + + // SAFETY: we trust the previous stage to pass a valid pointer + unsafe { init_valid_bitmap_ptr(new_kernel_region(&launch_info), vb_ptr) }; + + gdt().load(); + early_idt_init(); + + // Capture the debug serial port before the launch info disappears from + // the address space. + let debug_serial_port = li.debug_serial_port; + + LAUNCH_INFO + .init(li) + .expect("Already initialized launch info"); + + let mut platform = SvsmPlatformCell::new(li.platform_type); + + init_cpuid_table(VirtAddr::from(launch_info.cpuid_page)); + + let secrets_page_virt = VirtAddr::from(launch_info.secrets_page); + + // SAFETY: the secrets page address directly comes from IGVM. + // We trust stage 2 to give the value provided by IGVM. + unsafe { + secrets_page_mut().copy_from(secrets_page_virt); + } + + zero_mem_region(secrets_page_virt, secrets_page_virt + PAGE_SIZE); + + cr0_init(); + cr4_init(&*platform); + install_console_logger("SVSM").expect("Console logger already initialized"); + platform + .env_setup(debug_serial_port, launch_info.vtom.try_into().unwrap()) + .expect("Early environment setup failed"); + + memory_init(&launch_info); + migrate_valid_bitmap().expect("Failed to migrate valid-bitmap"); + + let kernel_elf_len = (launch_info.kernel_elf_stage2_virt_end + - launch_info.kernel_elf_stage2_virt_start) as usize; + let kernel_elf_buf_ptr = launch_info.kernel_elf_stage2_virt_start as *const u8; + // SAFETY: we trust stage 2 to pass on a correct pointer and length. This + // cannot be aliased because we are on CPU 0 and other CPUs have not been + // brought up. The resulting slice is &[u8], so there are no alignment + // requirements. 
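+    // The kernel ELF is parsed a second time here (stage 2 already parsed it
+    // once to load the image) because init_page_table() below needs the
+    // segment layout to recreate the kernel mappings in the final page table.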
+ let kernel_elf_buf = unsafe { slice::from_raw_parts(kernel_elf_buf_ptr, kernel_elf_len) }; + let kernel_elf = match elf::Elf64File::read(kernel_elf_buf) { + Ok(kernel_elf) => kernel_elf, + Err(e) => panic!("error reading kernel ELF: {}", e), + }; + + paging_init(&*platform).expect("Failed to initialize paging"); + let init_pgtable = + init_page_table(&launch_info, &kernel_elf).expect("Could not initialize the page table"); + init_pgtable.load(); + + let bsp_percpu = PerCpu::alloc(0).expect("Failed to allocate BSP per-cpu data"); + + bsp_percpu + .setup(&*platform, init_pgtable) + .expect("Failed to setup BSP per-cpu area"); + bsp_percpu + .setup_on_cpu(&*platform) + .expect("Failed to run percpu.setup_on_cpu()"); + bsp_percpu.load(); + + // Idle task must be allocated after PerCPU data is mapped + bsp_percpu + .setup_idle_task(svsm_main) + .expect("Failed to allocate idle task for BSP"); + + idt_init(); + platform + .env_setup_late(debug_serial_port) + .expect("Late environment setup failed"); + + dump_cpuid_table(); + + let mem_info = memory_info(); + print_memory_info(&mem_info); + + boot_stack_info(); + + let bp = this_cpu().get_top_of_stack(); + log::info!("BSP Runtime stack starts @ {:#018x}", bp); + + platform + .configure_alternate_injection(launch_info.use_alternate_injection) + .expect("Alternate injection required but not available"); + + SVSM_PLATFORM + .init(&platform) + .expect("Failed to initialize SVSM platform object"); + + sse_init(); + schedule_init(); + + panic!("SVSM entry point terminated unexpectedly"); +} + +#[no_mangle] +pub extern "C" fn svsm_main() { + // If required, the GDB stub can be started earlier, just after the console + // is initialised in svsm_start() above. + gdbstub_start(&**SVSM_PLATFORM).expect("Could not start GDB stub"); + // Uncomment the line below if you want to wait for + // a remote GDB connection + //debug_break(); + + SVSM_PLATFORM + .env_setup_svsm() + .expect("SVSM platform environment setup failed"); + + let launch_info = &*LAUNCH_INFO; + let config = if launch_info.igvm_params_virt_addr != 0 { + let igvm_params = IgvmParams::new(VirtAddr::from(launch_info.igvm_params_virt_addr)) + .expect("Invalid IGVM parameters"); + if (launch_info.vtom != 0) && (launch_info.vtom != igvm_params.get_vtom()) { + panic!("Launch VTOM does not match VTOM from IGVM parameters"); + } + SvsmConfig::IgvmConfig(igvm_params) + } else { + SvsmConfig::FirmwareConfig(FwCfg::new(SVSM_PLATFORM.get_io_port())) + }; + + init_memory_map(&config, &LAUNCH_INFO).expect("Failed to init guest memory map"); + + initialize_fs(); + + populate_ram_fs(LAUNCH_INFO.kernel_fs_start, LAUNCH_INFO.kernel_fs_end) + .expect("Failed to unpack FS archive"); + + invalidate_early_boot_memory(&**SVSM_PLATFORM, &config, launch_info) + .expect("Failed to invalidate early boot memory"); + + let cpus = config.load_cpu_info().expect("Failed to load ACPI tables"); + let mut nr_cpus = 0; + + for cpu in cpus.iter() { + if cpu.enabled { + nr_cpus += 1; + } + } + + log::info!("{} CPU(s) present", nr_cpus); + + start_secondary_cpus(&**SVSM_PLATFORM, &cpus); + + let fw_metadata = config.get_fw_metadata(); + if let Some(ref fw_meta) = fw_metadata { + print_fw_meta(fw_meta); + write_guest_memory_map(&config).expect("Failed to write guest memory map"); + validate_fw_memory(&config, fw_meta, &LAUNCH_INFO).expect("Failed to validate memory"); + copy_tables_to_fw(fw_meta).expect("Failed to copy firmware tables"); + validate_fw(&config, &LAUNCH_INFO).expect("Failed to validate flash memory"); + } + + if let 
Some(ref fw_meta) = fw_metadata { + prepare_fw_launch(fw_meta).expect("Failed to setup guest VMSA/CAA"); + } + + #[cfg(all(feature = "mstpm", not(test)))] + vtpm_init().expect("vTPM failed to initialize"); + + virt_log_usage(); + + if config.should_launch_fw() { + if let Err(e) = launch_fw(&config) { + panic!("Failed to launch FW: {:#?}", e); + } + } + + create_kernel_task(request_processing_main).expect("Failed to launch request processing task"); + + #[cfg(test)] + crate::test_main(); + + if exec_user("/init").is_err() { + log::info!("Failed to launch /init"); + } + + request_loop(); + + panic!("Road ends here!"); +} + +#[panic_handler] +fn panic(info: &PanicInfo<'_>) -> ! { + secrets_page_mut().clear_vmpck(0); + secrets_page_mut().clear_vmpck(1); + secrets_page_mut().clear_vmpck(2); + secrets_page_mut().clear_vmpck(3); + + log::error!("Panic: CPU[{}] {}", this_cpu().get_apic_id(), info); + + print_stack(3); + + loop { + debug_break(); + platform::halt(); + } +} diff --git a/stage2/src/svsm_paging.rs b/stage2/src/svsm_paging.rs new file mode 100644 index 000000000..8c7772ed6 --- /dev/null +++ b/stage2/src/svsm_paging.rs @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, PhysAddr, VirtAddr}; +use crate::config::SvsmConfig; +use crate::error::SvsmError; +use crate::igvm_params::IgvmParams; +use crate::mm::pagetable::{PTEntryFlags, PageTable}; +use crate::mm::PageBox; +use crate::platform::{PageStateChangeOp, PageValidateOp, SvsmPlatform}; +use crate::types::PageSize; +use crate::utils::MemoryRegion; +use bootlib::kernel_launch::KernelLaunchInfo; + +struct IgvmParamInfo<'a> { + virt_addr: VirtAddr, + igvm_params: Option>, +} + +pub fn init_page_table( + launch_info: &KernelLaunchInfo, + kernel_elf: &elf::Elf64File<'_>, +) -> Result, SvsmError> { + let mut pgtable = PageTable::allocate_new()?; + + let igvm_param_info = if launch_info.igvm_params_virt_addr != 0 { + let addr = VirtAddr::from(launch_info.igvm_params_virt_addr); + IgvmParamInfo { + virt_addr: addr, + igvm_params: Some(IgvmParams::new(addr)?), + } + } else { + IgvmParamInfo { + virt_addr: VirtAddr::null(), + igvm_params: None, + } + }; + + // Install mappings for the kernel's ELF segments each. + // The memory backing the kernel ELF segments gets allocated back to back + // from the physical memory region by the Stage2 loader. + let mut phys = PhysAddr::from(launch_info.kernel_region_phys_start); + for segment in kernel_elf.image_load_segment_iter(launch_info.kernel_region_virt_start) { + let vaddr_start = VirtAddr::from(segment.vaddr_range.vaddr_begin); + let vaddr_end = VirtAddr::from(segment.vaddr_range.vaddr_end); + let aligned_vaddr_end = vaddr_end.page_align_up(); + let segment_len = aligned_vaddr_end - vaddr_start; + let flags = if segment.flags.contains(elf::Elf64PhdrFlags::EXECUTE) { + PTEntryFlags::exec() + } else if segment.flags.contains(elf::Elf64PhdrFlags::WRITE) { + PTEntryFlags::data() + } else { + PTEntryFlags::data_ro() + }; + + let vregion = MemoryRegion::new(vaddr_start, segment_len); + pgtable + .map_region(vregion, phys, flags) + .expect("Failed to map kernel ELF segment"); + + phys = phys + segment_len; + } + + // Map the IGVM parameters if present. 
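+    // Stage 2 already copied the parameter data to igvm_params_phys_addr, so
+    // only the virtual mapping needs to be recreated in the new page table.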
+ if let Some(ref igvm_params) = igvm_param_info.igvm_params { + let vregion = MemoryRegion::new(igvm_param_info.virt_addr, igvm_params.size()); + pgtable + .map_region( + vregion, + PhysAddr::from(launch_info.igvm_params_phys_addr), + PTEntryFlags::data(), + ) + .expect("Failed to map IGVM parameters"); + } + + // Map subsequent heap area. + let heap_vregion = MemoryRegion::new( + VirtAddr::from(launch_info.heap_area_virt_start), + launch_info.heap_area_size as usize, + ); + pgtable + .map_region( + heap_vregion, + PhysAddr::from(launch_info.heap_area_phys_start), + PTEntryFlags::data(), + ) + .expect("Failed to map heap"); + + pgtable.load(); + + Ok(pgtable) +} + +fn invalidate_boot_memory_region( + platform: &dyn SvsmPlatform, + config: &SvsmConfig<'_>, + region: MemoryRegion, +) -> Result<(), SvsmError> { + log::info!( + "Invalidating boot region {:018x}-{:018x}", + region.start(), + region.end() + ); + + if !region.is_empty() { + platform.validate_physical_page_range(region, PageValidateOp::Invalidate)?; + + if config.page_state_change_required() { + platform.page_state_change(region, PageSize::Regular, PageStateChangeOp::Shared)?; + } + } + + Ok(()) +} + +pub fn invalidate_early_boot_memory( + platform: &dyn SvsmPlatform, + config: &SvsmConfig<'_>, + launch_info: &KernelLaunchInfo, +) -> Result<(), SvsmError> { + // Early boot memory must be invalidated after changing to the SVSM page + // page table to avoid invalidating page tables currently in use. Always + // invalidate stage 2 memory, unless firmware is loaded into low memory. + // Also invalidate the boot data if required. + if !config.fw_in_low_memory() { + let lowmem_region = MemoryRegion::new(PhysAddr::null(), 640 * 1024); + invalidate_boot_memory_region(platform, config, lowmem_region)?; + } + + let stage2_base = PhysAddr::from(launch_info.stage2_start); + let stage2_end = PhysAddr::from(launch_info.stage2_end); + let stage2_region = MemoryRegion::from_addresses(stage2_base, stage2_end); + invalidate_boot_memory_region(platform, config, stage2_region)?; + + if config.invalidate_boot_data() { + let kernel_elf_size = + launch_info.kernel_elf_stage2_virt_end - launch_info.kernel_elf_stage2_virt_start; + let kernel_elf_region = MemoryRegion::new( + PhysAddr::new(launch_info.kernel_elf_stage2_virt_start.try_into().unwrap()), + kernel_elf_size.try_into().unwrap(), + ); + invalidate_boot_memory_region(platform, config, kernel_elf_region)?; + + let kernel_fs_size = launch_info.kernel_fs_end - launch_info.kernel_fs_start; + if kernel_fs_size > 0 { + let kernel_fs_region = MemoryRegion::new( + PhysAddr::new(launch_info.kernel_fs_start.try_into().unwrap()), + kernel_fs_size.try_into().unwrap(), + ); + invalidate_boot_memory_region(platform, config, kernel_fs_region)?; + } + + if launch_info.stage2_igvm_params_size > 0 { + let igvm_params_region = MemoryRegion::new( + PhysAddr::new(launch_info.stage2_igvm_params_phys_addr.try_into().unwrap()), + launch_info.stage2_igvm_params_size as usize, + ); + invalidate_boot_memory_region(platform, config, igvm_params_region)?; + } + } + + Ok(()) +} diff --git a/stage2/src/syscall/handlers.rs b/stage2/src/syscall/handlers.rs new file mode 100644 index 000000000..a1031b907 --- /dev/null +++ b/stage2/src/syscall/handlers.rs @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Joerg Roedel + +use crate::task::{current_task_terminated, schedule}; + +pub fn sys_hello() -> usize { + log::info!("Hello, world! 
System call invoked from user-space."); + 0 +} + +pub fn sys_exit() -> ! { + log::info!("Terminating current task"); + unsafe { + current_task_terminated(); + } + schedule(); + panic!("schedule() returned in sys_exit()"); +} diff --git a/stage2/src/syscall/mod.rs b/stage2/src/syscall/mod.rs new file mode 100644 index 000000000..26e5a51bc --- /dev/null +++ b/stage2/src/syscall/mod.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Joerg Roedel + +mod handlers; +mod obj; + +pub use handlers::*; +pub use obj::{Obj, ObjError, ObjHandle}; diff --git a/stage2/src/syscall/obj.rs b/stage2/src/syscall/obj.rs new file mode 100644 index 000000000..f8a248cd0 --- /dev/null +++ b/stage2/src/syscall/obj.rs @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2024 Intel Corporation. +// +// Author: Chuanxiao Dong + +extern crate alloc; + +use crate::cpu::percpu::current_task; +use crate::error::SvsmError; +use alloc::sync::Arc; + +#[derive(Clone, Copy, Debug)] +pub enum ObjError { + InvalidHandle, + NotFound, +} + +/// An object represents the type of resource like file, VM, vCPU in the +/// COCONUT-SVSM kernel which can be accessible by the user mode. The Obj +/// trait is defined for such type of resource, which can be used to define +/// the common functionalities of the objects. With the trait bounds of Send +/// and Sync, the objects implementing Obj trait could be sent to another +/// thread and shared between threads safely. +pub trait Obj: Send + Sync + core::fmt::Debug {} + +/// ObjHandle is a unique identifier for an object in the current process. +/// An ObjHandle can be converted to a u32 id which can be used by the user +/// mode to access this object. The passed id from the user mode by syscalls +/// can be converted to an `ObjHandle`, which is used to access the object in +/// the COCONUT-SVSM kernel. +#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub struct ObjHandle(u32); + +impl ObjHandle { + pub fn new(id: u32) -> Self { + Self(id) + } +} + +impl From for ObjHandle { + #[inline] + fn from(id: u32) -> Self { + Self(id) + } +} + +impl From for u32 { + #[inline] + fn from(obj_handle: ObjHandle) -> Self { + obj_handle.0 + } +} + +/// Add an object to the current process and assigns it an `ObjHandle`. +/// +/// # Arguments +/// +/// * `obj` - An `Arc` representing the object to be added. +/// +/// # Returns +/// +/// * `Result` - Returns the object handle of the +/// added object if successful, or an `SvsmError` on failure. +/// +/// # Errors +/// +/// This function will return an error if adding the object to the +/// current task fails. +#[expect(dead_code)] +pub fn obj_add(obj: Arc) -> Result { + current_task().add_obj(obj) +} + +/// Closes an object identified by its ObjHandle. +/// +/// # Arguments +/// +/// * `id` - The ObjHandle for the object to be closed. +/// +/// # Returns +/// +/// * `Result>, SvsmError>` - Returns the `Arc` +/// on success, or an `SvsmError` on failure. +/// +/// # Errors +/// +/// This function will return an error if removing the object from the +/// current task fails. +#[expect(dead_code)] +pub fn obj_close(id: ObjHandle) -> Result, SvsmError> { + current_task().remove_obj(id) +} + +/// Retrieves an object by its ObjHandle. +/// +/// # Arguments +/// +/// * `id` - The ObjHandle for the object to be retrieved. +/// +/// # Returns +/// +/// * `Result>, SvsmError>` - Returns the `Arc` on +/// success, or an `SvsmError` on failure. 
+/// +/// # Errors +/// +/// This function will return an error if retrieving the object from the +/// current task fails. +#[expect(dead_code)] +pub fn obj_get(id: ObjHandle) -> Result, SvsmError> { + current_task().get_obj(id) +} diff --git a/stage2/src/task/exec.rs b/stage2/src/task/exec.rs new file mode 100644 index 000000000..f3b742f7f --- /dev/null +++ b/stage2/src/task/exec.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, VirtAddr}; +use crate::error::SvsmError; +use crate::fs::open; +use crate::mm::vm::VMFileMappingFlags; +use crate::mm::USER_MEM_END; +use crate::task::{create_user_task, current_task, schedule}; +use crate::types::PAGE_SIZE; +use elf::{Elf64File, Elf64PhdrFlags}; + +fn convert_elf_phdr_flags(flags: Elf64PhdrFlags) -> VMFileMappingFlags { + let mut vm_flags = VMFileMappingFlags::Fixed; + + if flags.contains(Elf64PhdrFlags::WRITE) { + vm_flags |= VMFileMappingFlags::Write | VMFileMappingFlags::Private; + } + + if flags.contains(Elf64PhdrFlags::EXECUTE) { + vm_flags |= VMFileMappingFlags::Execute; + } + + vm_flags +} + +pub fn exec_user(binary: &str) -> Result<(), SvsmError> { + let fh = open(binary)?; + let file_size = fh.size(); + + let task = current_task(); + let vstart = task.mmap_kernel_guard( + VirtAddr::new(0), + Some(&fh), + 0, + file_size, + VMFileMappingFlags::Read, + )?; + let buf = unsafe { vstart.to_slice::(file_size) }; + let elf_bin = Elf64File::read(buf).map_err(|_| SvsmError::Mem)?; + + let alloc_info = elf_bin.image_load_vaddr_alloc_info(); + let virt_base = alloc_info.range.vaddr_begin; + let entry = elf_bin.get_entry(virt_base); + + let task = create_user_task(entry.try_into().unwrap())?; + + for seg in elf_bin.image_load_segment_iter(virt_base) { + let virt_start = VirtAddr::from(seg.vaddr_range.vaddr_begin); + let virt_end = VirtAddr::from(seg.vaddr_range.vaddr_end).align_up(PAGE_SIZE); + let file_offset = seg.file_range.offset_begin; + let len = virt_end - virt_start; + let flags = convert_elf_phdr_flags(seg.flags); + + if !virt_start.is_aligned(PAGE_SIZE) { + return Err(SvsmError::Mem); + } + + if file_offset > 0 { + task.mmap_user(virt_start, Some(&fh), file_offset, len, flags)?; + } else { + task.mmap_user(virt_start, None, 0, len, flags)?; + } + } + + // Make sure the mapping is gone before calling schedule + drop(vstart); + + // Setup 64k of task stack + let user_stack_size: usize = 64 * 1024; + let stack_flags: VMFileMappingFlags = VMFileMappingFlags::Fixed | VMFileMappingFlags::Write; + let stack_addr = USER_MEM_END - user_stack_size; + task.mmap_user(stack_addr, None, 0, user_stack_size, stack_flags)?; + + schedule(); + + Ok(()) +} diff --git a/stage2/src/task/mod.rs b/stage2/src/task/mod.rs new file mode 100644 index 000000000..9e4345f3c --- /dev/null +++ b/stage2/src/task/mod.rs @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +mod exec; +mod schedule; +mod tasks; +mod waiting; + +pub use schedule::{ + create_kernel_task, create_user_task, current_task, current_task_terminated, is_current_task, + schedule, schedule_init, schedule_task, terminate, RunQueue, TASKLIST, +}; + +pub use tasks::{ + is_task_fault, Task, TaskContext, TaskError, TaskListAdapter, TaskPointer, TaskRunListAdapter, + TaskState, INITIAL_TASK_ID, TASK_FLAG_SHARE_PT, +}; + +pub use exec::exec_user; +pub use waiting::WaitQueue; diff --git a/stage2/src/task/schedule.rs 
b/stage2/src/task/schedule.rs new file mode 100644 index 000000000..aedbe86fb --- /dev/null +++ b/stage2/src/task/schedule.rs @@ -0,0 +1,444 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +//! Round-Robin scheduler implementation for COCONUT-SVSM +//! +//! This module implements a round-robin scheduler for cooperative multi-tasking. +//! It works by assigning a single owner for each struct [`Task`]. The owner +//! depends on the state of the task: +//! +//! * [`RUNNING`] A task in running state is owned by the [`RunQueue`] and either +//! stored in the `run_list` (when the task is not actively running) or in +//! `current_task` when it is scheduled on the CPU. +//! * [`BLOCKED`] A task in this state is waiting for an event to become runnable +//! again. It is owned by a wait object when in this state. +//! * [`TERMINATED`] The task is about to be destroyed and owned by the [`RunQueue`]. +//! +//! The scheduler is cooperative. A task runs until it voluntarily calls the +//! [`schedule()`] function. +//! +//! Only when a task is in [`RUNNING`] or [`TERMINATED`] state it is assigned to a +//! specific CPU. Tasks in the [`BLOCKED`] state have no CPU assigned and will run +//! on the CPU where their event is triggered that makes them [`RUNNING`] again. +//! +//! [`RUNNING`]: super::tasks::TaskState::RUNNING +//! [`BLOCKED`]: super::tasks::TaskState::BLOCKED +//! [`TERMINATED`]: super::tasks::TaskState::TERMINATED + +extern crate alloc; + +use super::INITIAL_TASK_ID; +use super::{Task, TaskListAdapter, TaskPointer, TaskRunListAdapter}; +use crate::address::Address; +use crate::cpu::percpu::{irq_nesting_count, this_cpu}; +use crate::cpu::sse::sse_restore_context; +use crate::cpu::sse::sse_save_context; +use crate::cpu::IrqGuard; +use crate::error::SvsmError; +use crate::locking::SpinLock; +use alloc::sync::Arc; +use core::arch::{asm, global_asm}; +use core::cell::OnceCell; +use core::mem::offset_of; +use core::ptr::null_mut; +use intrusive_collections::LinkedList; + +/// A RunQueue implementation that uses an RBTree to efficiently sort the priority +/// of tasks within the queue. +#[derive(Debug, Default)] +pub struct RunQueue { + /// Linked list with runable tasks + run_list: LinkedList, + + /// Pointer to currently running task + current_task: Option, + + /// Idle task - runs when there is no other runnable task + idle_task: OnceCell, + + /// Temporary storage for tasks which are about to be terminated + terminated_task: Option, +} + +impl RunQueue { + /// Create a new runqueue for an id. The id would normally be set + /// to the APIC ID of the CPU that owns the runqueue and is used to + /// determine the affinity of tasks. + pub fn new() -> Self { + Self { + run_list: LinkedList::new(TaskRunListAdapter::new()), + current_task: None, + idle_task: OnceCell::new(), + terminated_task: None, + } + } + + /// Find the next task to run, which is either the task at the front of the + /// run_list or the idle task, if the run_list is empty. + /// + /// # Returns + /// + /// Pointer to next task to run + /// + /// # Panics + /// + /// Panics if there are no tasks to run and no idle task has been + /// allocated via [`set_idle_task()`](Self::set_idle_task). + fn get_next_task(&mut self) -> TaskPointer { + self.run_list + .pop_front() + .unwrap_or_else(|| self.idle_task.get().unwrap().clone()) + } + + /// Update state before a task is scheduled out. Non-idle tasks in RUNNING + /// state will be put at the end of the run_list. 
Terminated tasks will be + /// stored in the terminated_task field of the RunQueue and be destroyed + /// after the task-switch. + fn handle_task(&mut self, task: TaskPointer) { + if task.is_running() && !task.is_idle_task() { + self.run_list.push_back(task); + } else if task.is_terminated() { + self.terminated_task = Some(task); + } + } + + /// Initialized the scheduler for this (RunQueue)[RunQueue]. This method is + /// called on the very first scheduling event when there is no current task + /// yet. + /// + /// # Returns + /// + /// [TaskPointer] to the first task to run + pub fn schedule_init(&mut self) -> TaskPointer { + let task = self.get_next_task(); + self.current_task = Some(task.clone()); + task + } + + /// Prepares a task switch. The function checks if a task switch needs to + /// be done and return pointers to the current and next task. It will + /// also call `handle_task()` on the current task in case a task-switch + /// is requested. + /// + /// # Returns + /// + /// `None` when no task-switch is needed. + /// `Some` with current and next task in case a task-switch is required. + /// + /// # Panics + /// + /// Panics if there is no current task. + pub fn schedule_prepare(&mut self) -> Option<(TaskPointer, TaskPointer)> { + // Remove current and put it back into the RunQueue in case it is still + // runnable. This is important to make sure the last runnable task + // keeps running, even if it calls schedule() + let current = self.current_task.take().unwrap(); + self.handle_task(current.clone()); + + // Get next task and update current_task state + let next = self.get_next_task(); + self.current_task = Some(next.clone()); + + // Check if task switch is needed + if current != next { + Some((current, next)) + } else { + None + } + } + + pub fn current_task_id(&self) -> u32 { + self.current_task + .as_ref() + .map_or(INITIAL_TASK_ID, |t| t.get_task_id()) + } + + /// Sets the idle task for this RunQueue. This function sets a + /// OnceCell at the end and can thus be only called once. + /// + /// # Returns + /// + /// Ok(()) on success, SvsmError on failure + /// + /// # Panics + /// + /// Panics if the idle task was already set. + pub fn set_idle_task(&self, task: TaskPointer) { + task.set_idle_task(); + + // Add idle task to global task list + TASKLIST.lock().list().push_front(task.clone()); + + self.idle_task + .set(task) + .expect("Idle task already allocated"); + } + + /// Gets a pointer to the current task + /// + /// # Panics + /// + /// Panics if there is no current task. + pub fn current_task(&self) -> TaskPointer { + self.current_task.as_ref().unwrap().clone() + } +} + +/// Global task list +/// This contains every task regardless of affinity or run state. +#[derive(Debug, Default)] +pub struct TaskList { + list: Option>, +} + +impl TaskList { + pub const fn new() -> Self { + Self { list: None } + } + + pub fn list(&mut self) -> &mut LinkedList { + self.list + .get_or_insert_with(|| LinkedList::new(TaskListAdapter::new())) + } + + pub fn get_task(&self, id: u32) -> Option { + let task_list = &self.list.as_ref()?; + let mut cursor = task_list.front(); + while let Some(task) = cursor.get() { + if task.get_task_id() == id { + return cursor.clone_pointer(); + } + cursor.move_next(); + } + None + } + + fn terminate(&mut self, task: TaskPointer) { + // Set the task state as terminated. If the task being terminated is the + // current task then the task context will still need to be in scope until + // the next schedule() has completed. 
Schedule will keep a reference to this + // task until some time after the context switch. + task.set_task_terminated(); + let mut cursor = unsafe { self.list().cursor_mut_from_ptr(task.as_ref()) }; + cursor.remove(); + } +} + +pub static TASKLIST: SpinLock = SpinLock::new(TaskList::new()); + +pub fn create_kernel_task(entry: extern "C" fn()) -> Result { + let cpu = this_cpu(); + let task = Task::create(cpu, entry)?; + TASKLIST.lock().list().push_back(task.clone()); + + // Put task on the runqueue of this CPU + cpu.runqueue().lock_write().handle_task(task.clone()); + + schedule(); + + Ok(task) +} + +pub fn create_user_task(user_entry: usize) -> Result { + let cpu = this_cpu(); + let task = Task::create_user(cpu, user_entry)?; + TASKLIST.lock().list().push_back(task.clone()); + + // Put task on the runqueue of this CPU + cpu.runqueue().lock_write().handle_task(task.clone()); + + Ok(task) +} + +pub fn current_task() -> TaskPointer { + this_cpu().current_task() +} + +/// Check to see if the task scheduled on the current processor has the given id +pub fn is_current_task(id: u32) -> bool { + match &this_cpu().runqueue().lock_read().current_task { + Some(current_task) => current_task.get_task_id() == id, + None => id == INITIAL_TASK_ID, + } +} + +/// Terminates the current task. +/// +/// # Safety +/// +/// This function must only be called after scheduling is initialized, otherwise it will panic. +pub unsafe fn current_task_terminated() { + let cpu = this_cpu(); + let mut rq = cpu.runqueue().lock_write(); + let task_node = rq + .current_task + .as_mut() + .expect("Task termination handler called when there is no current task"); + TASKLIST.lock().terminate(task_node.clone()); +} + +pub fn terminate() { + // TODO: re-evaluate whether current_task_terminated() needs to be unsafe + unsafe { + current_task_terminated(); + } + schedule(); +} + +// SAFETY: This function returns a raw pointer to a task. It is safe +// because this function is only used in the task switch code, which also only +// takes a single reference to the next and previous tasks. Also, this +// function works on an Arc, which ensures that only a single mutable reference +// can exist. +unsafe fn task_pointer(taskptr: TaskPointer) -> *const Task { + Arc::as_ptr(&taskptr) +} + +#[inline(always)] +unsafe fn switch_to(prev: *const Task, next: *const Task) { + let cr3: u64 = unsafe { (*next).page_table.lock().cr3_value().bits() as u64 }; + + // Switch to new task + asm!( + r#" + call switch_context + "#, + in("rsi") prev as u64, + in("rdi") next as u64, + in("rdx") cr3, + options(att_syntax)); +} + +/// Initializes the [RunQueue] on the current CPU. It will switch to the idle +/// task and initialize the current_task field of the RunQueue. After this +/// function has ran it is safe to call [`schedule()`] on the current CPU. +pub fn schedule_init() { + unsafe { + let guard = IrqGuard::new(); + let next = task_pointer(this_cpu().schedule_init()); + switch_to(null_mut(), next); + drop(guard); + } +} + +fn preemption_checks() { + assert!(irq_nesting_count() == 0); +} + +/// Perform a task switch and hand the CPU over to the next task on the +/// run-list. In case the current task is terminated, it will be destroyed after +/// the switch to the next task. +pub fn schedule() { + // check if preemption is safe + preemption_checks(); + + let guard = IrqGuard::new(); + + let work = this_cpu().schedule_prepare(); + + // !!! Runqueue lock must be release here !!! 
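+    // (schedule() re-acquires lock_write() on the runqueue further down, and
+    // the next task does the same from its own schedule() calls, so holding
+    // the lock across switch_to() would deadlock.)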
+ if let Some((current, next)) = work { + // Update per-cpu mappings if needed + let apic_id = this_cpu().get_apic_id(); + + if next.update_cpu(apic_id) != apic_id { + // Task has changed CPU, update per-cpu mappings + let mut pt = next.page_table.lock(); + this_cpu().populate_page_table(&mut pt); + } + + this_cpu().set_tss_rsp0(next.stack_bounds.end()); + + // Get task-pointers, consuming the Arcs and release their reference + unsafe { + let a = task_pointer(current); + let b = task_pointer(next); + sse_save_context(u64::from((*a).xsa.vaddr())); + + // Switch tasks + switch_to(a, b); + + // We're now in the context of task pointed to by 'a' + // which was previously scheduled out. + sse_restore_context(u64::from((*a).xsa.vaddr())); + } + } + + drop(guard); + + // If the previous task had terminated then we can release + // it's reference here. + let _ = this_cpu().runqueue().lock_write().terminated_task.take(); +} + +pub fn schedule_task(task: TaskPointer) { + task.set_task_running(); + this_cpu().runqueue().lock_write().handle_task(task); + schedule(); +} + +global_asm!( + r#" + .text + + switch_context: + // Save the current context. The layout must match the TaskContext structure. + pushfq + pushq %rax + pushq %rbx + pushq %rcx + pushq %rdx + pushq %rsi + pushq %rdi + pushq %rbp + pushq %r8 + pushq %r9 + pushq %r10 + pushq %r11 + pushq %r12 + pushq %r13 + pushq %r14 + pushq %r15 + pushq %rsp + + // Save the current stack pointer + testq %rsi, %rsi + jz 1f + movq %rsp, {TASK_RSP_OFFSET}(%rsi) + + 1: + // Switch to the new task state + mov %rdx, %cr3 + + // Switch to the new task stack + movq {TASK_RSP_OFFSET}(%rdi), %rsp + + // We've already restored rsp + addq $8, %rsp + + // Restore the task context + popq %r15 + popq %r14 + popq %r13 + popq %r12 + popq %r11 + popq %r10 + popq %r9 + popq %r8 + popq %rbp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %rbx + popq %rax + popfq + + ret + "#, + TASK_RSP_OFFSET = const offset_of!(Task, rsp), + options(att_syntax) +); diff --git a/stage2/src/task/tasks.rs b/stage2/src/task/tasks.rs new file mode 100644 index 000000000..53c4e7108 --- /dev/null +++ b/stage2/src/task/tasks.rs @@ -0,0 +1,750 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +extern crate alloc; + +use alloc::collections::btree_map::BTreeMap; +use alloc::sync::Arc; +use core::fmt; +use core::mem::size_of; +use core::num::NonZeroUsize; +use core::sync::atomic::{AtomicU32, Ordering}; + +use crate::address::{Address, VirtAddr}; +use crate::cpu::idt::svsm::return_new_task; +use crate::cpu::percpu::PerCpu; +use crate::cpu::sse::{get_xsave_area_size, sse_restore_context}; +use crate::cpu::X86ExceptionContext; +use crate::cpu::{irqs_enable, X86GeneralRegs}; +use crate::error::SvsmError; +use crate::fs::FileHandle; +use crate::locking::{RWLock, SpinLock}; +use crate::mm::pagetable::{PTEntryFlags, PageTable}; +use crate::mm::vm::{Mapping, VMFileMappingFlags, VMKernelStack, VMR}; +use crate::mm::PageBox; +use crate::mm::{ + mappings::create_anon_mapping, mappings::create_file_mapping, VMMappingGuard, + SVSM_PERTASK_BASE, SVSM_PERTASK_END, SVSM_PERTASK_STACK_BASE, USER_MEM_END, USER_MEM_START, +}; +use crate::syscall::{Obj, ObjError, ObjHandle}; +use crate::types::{SVSM_USER_CS, SVSM_USER_DS}; +use crate::utils::MemoryRegion; +use intrusive_collections::{intrusive_adapter, LinkedListAtomicLink}; + +use super::schedule::{current_task_terminated, schedule}; + +pub const INITIAL_TASK_ID: u32 = 1; + 
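+/// Scheduling state of a task. Ownership of a [`Task`] follows its state; see
+/// the round-robin scheduler documentation in `schedule.rs` for the rules.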
+#[derive(PartialEq, Debug, Copy, Clone, Default)] +pub enum TaskState { + RUNNING, + BLOCKED, + #[default] + TERMINATED, +} + +#[derive(Clone, Copy, Debug)] +pub enum TaskError { + // Attempt to close a non-terminated task + NotTerminated, + // A closed task could not be removed from the task list + CloseFailed, +} + +impl From for SvsmError { + fn from(e: TaskError) -> Self { + Self::Task(e) + } +} + +pub const TASK_FLAG_SHARE_PT: u16 = 0x01; + +#[derive(Debug, Default)] +struct TaskIDAllocator { + next_id: AtomicU32, +} + +impl TaskIDAllocator { + const fn new() -> Self { + Self { + next_id: AtomicU32::new(INITIAL_TASK_ID + 1), + } + } + + fn next_id(&self) -> u32 { + let mut id = self.next_id.fetch_add(1, Ordering::Relaxed); + // Reserve IDs of 0 and 1 + while (id == 0_u32) || (id == INITIAL_TASK_ID) { + id = self.next_id.fetch_add(1, Ordering::Relaxed); + } + id + } +} + +static TASK_ID_ALLOCATOR: TaskIDAllocator = TaskIDAllocator::new(); + +#[repr(C)] +#[derive(Default, Debug, Clone, Copy)] +pub struct TaskContext { + pub rsp: u64, + pub regs: X86GeneralRegs, + pub flags: u64, + pub ret_addr: u64, +} + +#[repr(C)] +struct TaskSchedState { + /// Whether this is an idle task + idle_task: bool, + + /// Current state of the task + state: TaskState, + + /// CPU this task is currently assigned to + cpu: u32, +} + +impl TaskSchedState { + pub fn panic_on_idle(&mut self, msg: &str) -> &mut Self { + if self.idle_task { + panic!("{}", msg); + } + self + } +} + +pub struct Task { + pub rsp: u64, + + /// XSave area + pub xsa: PageBox<[u8]>, + + pub stack_bounds: MemoryRegion, + + /// Page table that is loaded when the task is scheduled + pub page_table: SpinLock>, + + /// Task virtual memory range for use at CPL 0 + vm_kernel_range: VMR, + + /// Task virtual memory range for use at CPL 3 - None for kernel tasks + vm_user_range: Option, + + /// State relevant for scheduler + sched_state: RWLock, + + /// ID of the task + id: u32, + + /// Link to global task list + list_link: LinkedListAtomicLink, + + /// Link to scheduler run queue + runlist_link: LinkedListAtomicLink, + + /// Objects shared among threads within the same process + objs: Arc>>>, +} + +// SAFETY: Send + Sync is required for Arc to implement Send. All members +// of `Task` are Send + Sync except for the intrusive_collection links, which +// are only Send. The only access to these is via the intrusive_adapter! +// generated code which does not use them concurrently across threads. 
+unsafe impl Sync for Task {} + +pub type TaskPointer = Arc; + +intrusive_adapter!(pub TaskRunListAdapter = TaskPointer: Task { runlist_link: LinkedListAtomicLink }); +intrusive_adapter!(pub TaskListAdapter = TaskPointer: Task { list_link: LinkedListAtomicLink }); + +impl PartialEq for Task { + fn eq(&self, other: &Self) -> bool { + core::ptr::eq(self, other) + } +} + +impl fmt::Debug for Task { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Task") + .field("rsp", &self.rsp) + .field("state", &self.sched_state.lock_read().state) + .field("id", &self.id) + .finish() + } +} + +impl Task { + pub fn create(cpu: &PerCpu, entry: extern "C" fn()) -> Result { + let mut pgtable = cpu.get_pgtable().clone_shared()?; + + cpu.populate_page_table(&mut pgtable); + + let vm_kernel_range = VMR::new(SVSM_PERTASK_BASE, SVSM_PERTASK_END, PTEntryFlags::empty()); + vm_kernel_range.initialize()?; + + let xsa = Self::allocate_xsave_area(); + let xsa_addr = u64::from(xsa.vaddr()) as usize; + let (stack, raw_bounds, rsp_offset) = Self::allocate_ktask_stack(cpu, entry, xsa_addr)?; + vm_kernel_range.insert_at(SVSM_PERTASK_STACK_BASE, stack)?; + + vm_kernel_range.populate(&mut pgtable); + + // Remap at the per-task offset + let bounds = MemoryRegion::new( + SVSM_PERTASK_STACK_BASE + raw_bounds.start().into(), + raw_bounds.len(), + ); + + Ok(Arc::new(Task { + rsp: bounds + .end() + .checked_sub(rsp_offset) + .expect("Invalid stack offset from task::allocate_ktask_stack()") + .bits() as u64, + xsa, + stack_bounds: bounds, + page_table: SpinLock::new(pgtable), + vm_kernel_range, + vm_user_range: None, + sched_state: RWLock::new(TaskSchedState { + idle_task: false, + state: TaskState::RUNNING, + cpu: cpu.get_apic_id(), + }), + id: TASK_ID_ALLOCATOR.next_id(), + list_link: LinkedListAtomicLink::default(), + runlist_link: LinkedListAtomicLink::default(), + objs: Arc::new(RWLock::new(BTreeMap::new())), + })) + } + + pub fn create_user(cpu: &PerCpu, user_entry: usize) -> Result { + let mut pgtable = cpu.get_pgtable().clone_shared()?; + + cpu.populate_page_table(&mut pgtable); + + let vm_kernel_range = VMR::new(SVSM_PERTASK_BASE, SVSM_PERTASK_END, PTEntryFlags::empty()); + vm_kernel_range.initialize()?; + + let xsa = Self::allocate_xsave_area(); + let xsa_addr = u64::from(xsa.vaddr()) as usize; + let (stack, raw_bounds, stack_offset) = + Self::allocate_utask_stack(cpu, user_entry, xsa_addr)?; + vm_kernel_range.insert_at(SVSM_PERTASK_STACK_BASE, stack)?; + + vm_kernel_range.populate(&mut pgtable); + + let vm_user_range = VMR::new(USER_MEM_START, USER_MEM_END, PTEntryFlags::USER); + vm_user_range.initialize_lazy()?; + + // Remap at the per-task offset + let bounds = MemoryRegion::new( + SVSM_PERTASK_STACK_BASE + raw_bounds.start().into(), + raw_bounds.len(), + ); + + Ok(Arc::new(Task { + rsp: bounds + .end() + .checked_sub(stack_offset) + .expect("Invalid stack offset from task::allocate_utask_stack()") + .bits() as u64, + xsa, + stack_bounds: bounds, + page_table: SpinLock::new(pgtable), + vm_kernel_range, + vm_user_range: Some(vm_user_range), + sched_state: RWLock::new(TaskSchedState { + idle_task: false, + state: TaskState::RUNNING, + cpu: cpu.get_apic_id(), + }), + id: TASK_ID_ALLOCATOR.next_id(), + list_link: LinkedListAtomicLink::default(), + runlist_link: LinkedListAtomicLink::default(), + objs: Arc::new(RWLock::new(BTreeMap::new())), + })) + } + + pub fn stack_bounds(&self) -> MemoryRegion { + self.stack_bounds + } + + pub fn get_task_id(&self) -> u32 { + self.id + } + + pub fn 
set_task_running(&self) { + self.sched_state.lock_write().state = TaskState::RUNNING; + } + + pub fn set_task_terminated(&self) { + self.sched_state + .lock_write() + .panic_on_idle("Trying to terminate idle task") + .state = TaskState::TERMINATED; + } + + pub fn set_task_blocked(&self) { + self.sched_state + .lock_write() + .panic_on_idle("Trying to block idle task") + .state = TaskState::BLOCKED; + } + + pub fn is_running(&self) -> bool { + self.sched_state.lock_read().state == TaskState::RUNNING + } + + pub fn is_terminated(&self) -> bool { + self.sched_state.lock_read().state == TaskState::TERMINATED + } + + pub fn set_idle_task(&self) { + self.sched_state.lock_write().idle_task = true; + } + + pub fn is_idle_task(&self) -> bool { + self.sched_state.lock_read().idle_task + } + + pub fn update_cpu(&self, new_cpu: u32) -> u32 { + let mut state = self.sched_state.lock_write(); + let old_cpu = state.cpu; + state.cpu = new_cpu; + old_cpu + } + + pub fn handle_pf(&self, vaddr: VirtAddr, write: bool) -> Result<(), SvsmError> { + self.vm_kernel_range.handle_page_fault(vaddr, write) + } + + pub fn fault(&self, vaddr: VirtAddr, write: bool) -> Result<(), SvsmError> { + if vaddr >= USER_MEM_START && vaddr < USER_MEM_END && self.vm_user_range.is_some() { + let vmr = self.vm_user_range.as_ref().unwrap(); + let mut pgtbl = self.page_table.lock(); + vmr.populate_addr(&mut pgtbl, vaddr); + vmr.handle_page_fault(vaddr, write)?; + Ok(()) + } else { + Err(SvsmError::Mem) + } + } + + fn allocate_stack_common() -> Result<(Arc, MemoryRegion), SvsmError> { + let stack = VMKernelStack::new()?; + let bounds = stack.bounds(VirtAddr::from(0u64)); + + let mapping = Arc::new(Mapping::new(stack)); + + Ok((mapping, bounds)) + } + + fn allocate_ktask_stack( + cpu: &PerCpu, + entry: extern "C" fn(), + xsa_addr: usize, + ) -> Result<(Arc, MemoryRegion, usize), SvsmError> { + let (mapping, bounds) = Task::allocate_stack_common()?; + + let percpu_mapping = cpu.new_mapping(mapping.clone())?; + + // We need to setup a context on the stack that matches the stack layout + // defined in switch_context below. + let stack_ptr = (percpu_mapping.virt_addr() + bounds.end().bits()).as_mut_ptr::(); + + // 'Push' the task frame onto the stack + unsafe { + let tc_offset: isize = ((size_of::() / size_of::()) + 1) + .try_into() + .unwrap(); + let task_context = stack_ptr.offset(-tc_offset).cast::(); + // The processor flags must always be in a default state, unrelated + // to the flags of the caller. In particular, interrupts must be + // disabled because the task switch code expects to execute a new + // task with interrupts disabled. + (*task_context).flags = 2; + // ret_addr + (*task_context).regs.rdi = entry as *const () as usize; + // xsave area addr + (*task_context).regs.rsi = xsa_addr; + (*task_context).ret_addr = run_kernel_task as *const () as u64; + // Task termination handler for when entry point returns + stack_ptr.offset(-1).write(task_exit as *const () as u64); + } + + Ok((mapping, bounds, size_of::() + size_of::())) + } + + fn allocate_utask_stack( + cpu: &PerCpu, + user_entry: usize, + xsa_addr: usize, + ) -> Result<(Arc, MemoryRegion, usize), SvsmError> { + let (mapping, bounds) = Task::allocate_stack_common()?; + + let percpu_mapping = cpu.new_mapping(mapping.clone())?; + + // We need to setup a context on the stack that matches the stack layout + // defined in switch_context below. 
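+        // Layout built below, from the top of the stack downwards: an
+        // X86ExceptionContext IRET frame returning to the user entry point
+        // with interrupts enabled, followed by a TaskContext whose ret_addr
+        // points at return_new_task and whose rdi carries the XSAVE area
+        // address.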
+ let stack_ptr = (percpu_mapping.virt_addr() + bounds.end().bits()).as_mut_ptr::(); + + let mut stack_offset = size_of::(); + + // 'Push' the task frame onto the stack + unsafe { + // Setup IRQ return frame. User-mode tasks always run with + // interrupts enabled. + let mut iret_frame = X86ExceptionContext::default(); + iret_frame.frame.rip = user_entry; + iret_frame.frame.cs = (SVSM_USER_CS | 3).into(); + iret_frame.frame.flags = 0x202; + iret_frame.frame.rsp = (USER_MEM_END - 8).into(); + iret_frame.frame.ss = (SVSM_USER_DS | 3).into(); + + // Copy IRET frame to stack + let stack_iret_frame = stack_ptr.sub(stack_offset).cast::(); + *stack_iret_frame = iret_frame; + + stack_offset += size_of::(); + + let mut task_context = TaskContext { + ret_addr: VirtAddr::from(return_new_task as *const ()) + .bits() + .try_into() + .unwrap(), + ..Default::default() + }; + + // xsave area addr + task_context.regs.rdi = xsa_addr; + let stack_task_context = stack_ptr.sub(stack_offset).cast::(); + *stack_task_context = task_context; + } + + Ok((mapping, bounds, stack_offset)) + } + + fn allocate_xsave_area() -> PageBox<[u8]> { + let len = get_xsave_area_size() as usize; + let xsa = PageBox::<[u8]>::try_new_slice(0u8, NonZeroUsize::new(len).unwrap()); + if xsa.is_err() { + panic!("Error while allocating xsave area"); + } + xsa.unwrap() + } + + pub fn mmap_common( + vmr: &VMR, + addr: VirtAddr, + file: Option<&FileHandle>, + offset: usize, + size: usize, + flags: VMFileMappingFlags, + ) -> Result { + let mapping = if let Some(f) = file { + create_file_mapping(f, offset, size, flags)? + } else { + create_anon_mapping(size, flags)? + }; + + if flags.contains(VMFileMappingFlags::Fixed) { + Ok(vmr.insert_at(addr, mapping)?) + } else { + Ok(vmr.insert_hint(addr, mapping)?) + } + } + + pub fn mmap_kernel( + &self, + addr: VirtAddr, + file: Option<&FileHandle>, + offset: usize, + size: usize, + flags: VMFileMappingFlags, + ) -> Result { + Self::mmap_common(&self.vm_kernel_range, addr, file, offset, size, flags) + } + + pub fn mmap_kernel_guard<'a>( + &'a self, + addr: VirtAddr, + file: Option<&FileHandle>, + offset: usize, + size: usize, + flags: VMFileMappingFlags, + ) -> Result, SvsmError> { + let vaddr = Self::mmap_common(&self.vm_kernel_range, addr, file, offset, size, flags)?; + Ok(VMMappingGuard::new(&self.vm_kernel_range, vaddr)) + } + + pub fn mmap_user( + &self, + addr: VirtAddr, + file: Option<&FileHandle>, + offset: usize, + size: usize, + flags: VMFileMappingFlags, + ) -> Result { + if self.vm_user_range.is_none() { + return Err(SvsmError::Mem); + } + + let vmr = self.vm_user_range.as_ref().unwrap(); + + Self::mmap_common(vmr, addr, file, offset, size, flags) + } + + pub fn munmap_kernel(&self, addr: VirtAddr) -> Result<(), SvsmError> { + self.vm_kernel_range.remove(addr)?; + Ok(()) + } + + pub fn munmap_user(&self, addr: VirtAddr) -> Result<(), SvsmError> { + if self.vm_user_range.is_none() { + return Err(SvsmError::Mem); + } + + self.vm_user_range.as_ref().unwrap().remove(addr)?; + Ok(()) + } + + /// Adds an object to the current task. + /// + /// # Arguments + /// + /// * `obj` - The object to be added. + /// + /// # Returns + /// + /// * `Result` - Returns the object handle for the object + /// to be added if successful, or an `SvsmError` on failure. + /// + /// # Errors + /// + /// This function will return an error if allocating the object handle fails. 
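+    ///
+    /// Handle ids are assigned as the lowest free value: the first gap in the
+    /// existing ids, or one past the current largest id when there is no gap.
+    ///
+    /// A sketch of the intended call pattern (illustrative only; `FileObj` is
+    /// a placeholder type implementing [`Obj`]):
+    ///
+    /// ```ignore
+    /// let handle = current_task().add_obj(Arc::new(FileObj::new()))?;
+    /// let id: u32 = handle.into(); // id handed back to user mode
+    /// ```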
+ pub fn add_obj(&self, obj: Arc) -> Result { + let mut objs = self.objs.lock_write(); + let last_key = objs + .keys() + .last() + .map_or(Some(0), |k| u32::from(*k).checked_add(1)) + .ok_or(SvsmError::from(ObjError::InvalidHandle))?; + let id = ObjHandle::new(if last_key != objs.len() as u32 { + objs.keys() + .enumerate() + .find(|(i, &key)| *i as u32 != u32::from(key)) + .unwrap() + .0 as u32 + } else { + last_key + }); + + objs.insert(id, obj); + + Ok(id) + } + + /// Removes an object from the current task. + /// + /// # Arguments + /// + /// * `id` - The ObjHandle for the object to be removed. + /// + /// # Returns + /// + /// * `Result>, SvsmError>` - Returns the removed `Arc` + /// on success, or an `SvsmError` on failure. + /// + /// # Errors + /// + /// This function will return an error if the object handle id does not + /// exist in the current task. + pub fn remove_obj(&self, id: ObjHandle) -> Result, SvsmError> { + self.objs + .lock_write() + .remove(&id) + .ok_or(ObjError::NotFound.into()) + } + + /// Retrieves an object from the current task. + /// + /// # Arguments + /// + /// * `id` - The ObjHandle for the object to be retrieved. + /// + /// # Returns + /// + /// * `Result>, SvsmError>` - Returns the `Arc` on + /// success, or an `SvsmError` on failure. + /// + /// # Errors + /// + /// This function will return an error if the object handle id does not exist + /// in the current task. + pub fn get_obj(&self, id: ObjHandle) -> Result, SvsmError> { + self.objs + .lock_read() + .get(&id) + .cloned() + .ok_or(ObjError::NotFound.into()) + } +} + +pub fn is_task_fault(vaddr: VirtAddr) -> bool { + (vaddr >= USER_MEM_START && vaddr < USER_MEM_END) + || (vaddr >= SVSM_PERTASK_BASE && vaddr < SVSM_PERTASK_END) +} + +/// Runs the first time a new task is scheduled, in the context of the new +/// task. Any first-time initialization and setup work for a new task that +/// needs to happen in its context must be done here. +#[no_mangle] +fn setup_new_task(xsa_addr: u64) { + // Re-enable IRQs here, as they are still disabled from the + // schedule()/sched_init() functions. After the context switch the IrqGuard + // from the previous task is not dropped, which causes IRQs to stay + // disabled in the new task. + // This only needs to be done for the first time a task runs. Any + // subsequent task switches will go through schedule() and there the guard + // is dropped, re-enabling IRQs. + + // SAFETY: Safe because this matches the IrqGuard drop in + // schedule()/schedule_init(). See description above. 
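+    // sse_restore_context() loads the XSAVE area whose address was stored in
+    // the task's initial stack frame by allocate_ktask_stack() or
+    // allocate_utask_stack().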
+ unsafe { + irqs_enable(); + sse_restore_context(xsa_addr); + } +} + +extern "C" fn run_kernel_task(entry: extern "C" fn(), xsa_addr: u64) { + setup_new_task(xsa_addr); + entry(); +} + +extern "C" fn task_exit() { + unsafe { + current_task_terminated(); + } + schedule(); +} + +#[cfg(test)] +mod tests { + use crate::task::create_kernel_task; + use core::arch::asm; + use core::arch::global_asm; + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_media_and_x87_instructions() { + let ret: u64; + unsafe { + asm!("call test_fpu", out("rax") ret, options(att_syntax)); + } + + assert_eq!(ret, 0); + } + + global_asm!( + r#" + .text + test_fpu: + movq $0x3ff, %rax + shl $52, %rax + // rax contains 1 in Double Precison FP representation + movd %rax, %xmm1 + movapd %xmm1, %xmm3 + + movq $0x400, %rax + shl $52, %rax + // rax contains 2 in Double Precison FP representation + movd %rax, %xmm2 + + divsd %xmm2, %xmm3 + movq $0, %rax + ret + "#, + options(att_syntax) + ); + + global_asm!( + r#" + .text + check_fpu: + movq $1, %rax + movq $0x3ff, %rbx + shl $52, %rbx + // rbx contains 1 in Double Precison FP representation + movd %rbx, %xmm4 + movapd %xmm4, %xmm6 + comisd %xmm4, %xmm1 + jnz 1f + + movq $0x400, %rbx + shl $52, %rbx + // rbx contains 2 in Double Precison FP representation + movd %rbx, %xmm5 + comisd %xmm5, %xmm2 + jnz 1f + + divsd %xmm5, %xmm6 + comisd %xmm6, %xmm3 + jnz 1f + movq $0, %rax + 1: + ret + "#, + options(att_syntax) + ); + + global_asm!( + r#" + .text + alter_fpu: + movq $0x400, %rax + shl $52, %rax + // rax contains 2 in Double Precison FP representation + movd %rax, %xmm1 + movapd %xmm1, %xmm3 + + movq $0x3ff, %rax + shl $52, %rax + // rax contains 1 in Double Precison FP representation + movd %rax, %xmm2 + divsd %xmm3, %xmm2 + movq $0, %rax + ret + "#, + options(att_syntax) + ); + + #[test] + #[cfg_attr(not(test_in_svsm), ignore = "Can only be run inside guest")] + fn test_fpu_context_switch() { + create_kernel_task(task1).expect("Failed to launch request processing task"); + } + + extern "C" fn task1() { + let ret: u64; + unsafe { + asm!("call test_fpu", options(att_syntax)); + } + + create_kernel_task(task2).expect("Failed to launch request processing task"); + + unsafe { + asm!("call check_fpu", out("rax") ret, options(att_syntax)); + } + assert_eq!(ret, 0); + } + + extern "C" fn task2() { + unsafe { + asm!("call alter_fpu", options(att_syntax)); + } + } +} diff --git a/stage2/src/task/waiting.rs b/stage2/src/task/waiting.rs new file mode 100644 index 000000000..8bbd1aebe --- /dev/null +++ b/stage2/src/task/waiting.rs @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2024 SUSE LLC +// +// Author: Joerg Roedel + +use super::tasks::TaskPointer; + +#[derive(Debug, Default)] +pub struct WaitQueue { + waiter: Option, +} + +impl WaitQueue { + pub const fn new() -> Self { + Self { waiter: None } + } + + pub fn wait_for_event(&mut self, current_task: TaskPointer) { + assert!(self.waiter.is_none()); + + current_task.set_task_blocked(); + self.waiter = Some(current_task); + } + + pub fn wakeup(&mut self) -> Option { + self.waiter.take() + } +} diff --git a/stage2/src/testing.rs b/stage2/src/testing.rs new file mode 100644 index 000000000..661924f89 --- /dev/null +++ b/stage2/src/testing.rs @@ -0,0 +1,94 @@ +use log::info; +use test::ShouldPanic; + +use crate::{ + cpu::percpu::current_ghcb, + locking::{LockGuard, SpinLock}, + platform::SVSM_PLATFORM, + serial::SerialPort, + sev::ghcb::GHCBIOSize, +}; + 
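+/// Variant of `assert_eq!` that logs a warning instead of panicking when the
+/// two values differ, so a mismatch does not abort the in-guest test run.
+///
+/// Illustrative use (values are made up):
+///
+/// ```ignore
+/// assert_eq_warn!(measurement.len(), 48);
+/// ```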
+#[macro_export] +macro_rules! assert_eq_warn { + ($left:expr, $right:expr) => { + { + let left = $left; + let right = $right; + if left != right { + log::warn!( + "Assertion warning failed at {}:{}:{}:\nassertion `left == right` failed\n left: {left:?}\n right: {right:?}", + file!(), + line!(), + column!(), + ); + } + } + }; +} +pub use assert_eq_warn; + +static SERIAL_PORT: SpinLock>> = SpinLock::new(None); + +/// Byte used to tell the host the request we need for the test. +/// These values must be aligned with `test_io()` in scripts/test-in-svsm.sh +#[repr(u8)] +#[derive(Clone, Copy, Debug)] +pub enum IORequest { + NOP = 0x00, + /// get SEV-SNP pre-calculated launch measurement (48 bytes) from the host + GetLaunchMeasurement = 0x01, +} + +/// Return the serial port to communicate with the host for a given request +/// used in a test. The request (first byte) is sent by this function, so the +/// caller can start using the serial port according to the request implemented +/// in `test_io()` in scripts/test-in-svsm.sh +pub fn svsm_test_io() -> LockGuard<'static, Option>> { + let mut sp = SERIAL_PORT.lock(); + if sp.is_none() { + let io_port = SVSM_PLATFORM.get_io_port(); + let serial_port = SerialPort::new(io_port, 0x2e8 /*COM4*/); + *sp = Some(serial_port); + serial_port.init(); + } + + sp +} + +pub fn svsm_test_runner(test_cases: &[&test::TestDescAndFn]) { + info!("running {} tests", test_cases.len()); + for mut test_case in test_cases.iter().copied().copied() { + if test_case.desc.should_panic == ShouldPanic::Yes { + test_case.desc.ignore = true; + test_case + .desc + .ignore_message + .get_or_insert("#[should_panic] not supported"); + } + + if test_case.desc.ignore { + if let Some(message) = test_case.desc.ignore_message { + info!("test {} ... ignored, {message}", test_case.desc.name.0); + } else { + info!("test {} ... ignored", test_case.desc.name.0); + } + continue; + } + + info!("test {} ...", test_case.desc.name.0); + (test_case.testfn.0)(); + } + + info!("All tests passed!"); + + exit(); +} + +fn exit() -> ! { + const QEMU_EXIT_PORT: u16 = 0xf4; + current_ghcb() + .ioio_out(QEMU_EXIT_PORT, GHCBIOSize::Size32, 0) + .unwrap(); + unreachable!(); +} diff --git a/stage2/src/types.rs b/stage2/src/types.rs new file mode 100644 index 000000000..ec42b99f9 --- /dev/null +++ b/stage2/src/types.rs @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::error::SvsmError; +use crate::sev::vmsa::VMPL_MAX; + +pub const PAGE_SHIFT: usize = 12; +pub const PAGE_SHIFT_2M: usize = 21; +pub const PAGE_SHIFT_1G: usize = 30; +pub const PAGE_SIZE: usize = 1 << PAGE_SHIFT; +pub const PAGE_SIZE_2M: usize = 1 << PAGE_SHIFT_2M; +pub const PAGE_SIZE_1G: usize = 1 << PAGE_SHIFT_1G; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum PageSize { + Regular, + Huge, +} + +impl From for usize { + fn from(psize: PageSize) -> Self { + match psize { + PageSize::Regular => PAGE_SIZE, + PageSize::Huge => PAGE_SIZE_2M, + } + } +} + +#[expect(clippy::identity_op)] +pub const SVSM_CS: u16 = 1 * 8; +pub const SVSM_DS: u16 = 2 * 8; +pub const SVSM_USER_CS: u16 = 3 * 8; +pub const SVSM_USER_DS: u16 = 4 * 8; +pub const SVSM_TSS: u16 = 6 * 8; + +pub const SVSM_CS_FLAGS: u16 = 0x29b; +pub const SVSM_DS_FLAGS: u16 = 0xc93; +pub const SVSM_TR_FLAGS: u16 = 0x89; + +/// VMPL level the guest OS will be executed at. +/// Keep VMPL 1 for the SVSM and execute the OS at VMPL-2. This leaves VMPL-3 +/// free for the OS to use in the future. 
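+///
+/// The compile-time assertion below keeps this value inside the architectural
+/// range (0 < GUEST_VMPL < VMPL_MAX).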
+pub const GUEST_VMPL: usize = 2; + +const _: () = assert!(GUEST_VMPL > 0 && GUEST_VMPL < VMPL_MAX); + +pub const MAX_CPUS: usize = 512; + +/// Length in byte which represents maximum 8 bytes(u64) +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)] +pub enum Bytes { + #[default] + Zero, + One, + Two, + Four = 4, + Eight = 8, +} + +impl Bytes { + pub fn mask(&self) -> u64 { + match self { + Bytes::Zero => 0, + Bytes::One => (1 << 8) - 1, + Bytes::Two => (1 << 16) - 1, + Bytes::Four => (1 << 32) - 1, + Bytes::Eight => u64::MAX, + } + } +} + +impl TryFrom for Bytes { + type Error = SvsmError; + + fn try_from(val: usize) -> Result { + match val { + 0 => Ok(Bytes::Zero), + 1 => Ok(Bytes::One), + 2 => Ok(Bytes::Two), + 4 => Ok(Bytes::Four), + 8 => Ok(Bytes::Eight), + _ => Err(SvsmError::InvalidBytes), + } + } +} diff --git a/stage2/src/utils/bitmap_allocator.rs b/stage2/src/utils/bitmap_allocator.rs new file mode 100644 index 000000000..2965b095d --- /dev/null +++ b/stage2/src/utils/bitmap_allocator.rs @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Roy Hopkins + +use core::fmt::Debug; + +pub trait BitmapAllocator { + const CAPACITY: usize; + + fn alloc(&mut self, entries: usize, align: usize) -> Option; + fn free(&mut self, start: usize, entries: usize); + + fn set(&mut self, start: usize, entries: usize, value: bool); + fn next_free(&self, start: usize) -> Option; + fn get(&self, offset: usize) -> bool; + fn empty(&self) -> bool; + fn capacity(&self) -> usize; + fn used(&self) -> usize; + #[cfg(fuzzing)] + fn max_align(&self) -> usize { + (::CAPACITY.ilog2() - 1) as usize + } +} + +pub type BitmapAllocator1024 = BitmapAllocatorTree; + +#[derive(Debug, Default, Copy, Clone)] +pub struct BitmapAllocator64 { + bits: u64, +} + +impl BitmapAllocator64 { + pub const fn new() -> Self { + Self { bits: u64::MAX } + } + + #[cfg(fuzzing)] + pub fn get_bits(&self) -> u64 { + self.bits + } +} + +impl BitmapAllocator for BitmapAllocator64 { + const CAPACITY: usize = u64::BITS as usize; + + fn alloc(&mut self, entries: usize, align: usize) -> Option { + alloc_aligned(self, entries, align) + } + + fn free(&mut self, start: usize, entries: usize) { + self.set(start, entries, false); + } + + fn set(&mut self, start: usize, entries: usize, value: bool) { + assert!(entries > 0); + assert!((start + entries) <= BitmapAllocator64::CAPACITY); + // Create a mask for changing the bitmap + let start_mask = !((1 << start) - 1); + // Need to do some bit shifting to avoid overflow when top bit set + let end_mask = (((1 << (start + entries - 1)) - 1) << 1) + 1; + let mask = start_mask & end_mask; + + if value { + self.bits |= mask; + } else { + self.bits &= !mask; + } + } + + fn next_free(&self, start: usize) -> Option { + assert!(start < Self::CAPACITY); + let mask: u64 = (1 << start) - 1; + let idx = (self.bits | mask).trailing_ones() as usize; + (idx < Self::CAPACITY).then_some(idx) + } + + fn get(&self, offset: usize) -> bool { + assert!(offset < BitmapAllocator64::CAPACITY); + (self.bits & (1 << offset)) != 0 + } + + fn empty(&self) -> bool { + self.bits == 0 + } + + fn capacity(&self) -> usize { + Self::CAPACITY + } + + fn used(&self) -> usize { + self.bits.count_ones() as usize + } +} + +#[derive(Debug, Default, Clone)] +pub struct BitmapAllocatorTree { + bits: u16, + child: [T; 16], +} + +impl BitmapAllocatorTree { + pub const fn new() -> Self { + Self { + bits: u16::MAX, + child: [BitmapAllocator64::new(); 16], + } + } + + #[cfg(fuzzing)] 
+ pub fn get_child(&self, index: usize) -> BitmapAllocator64 { + self.child[index] + } +} + +impl BitmapAllocator for BitmapAllocatorTree { + const CAPACITY: usize = T::CAPACITY * 16; + + fn alloc(&mut self, entries: usize, align: usize) -> Option { + alloc_aligned(self, entries, align) + } + + fn free(&mut self, start: usize, entries: usize) { + self.set(start, entries, false); + } + + fn set(&mut self, start: usize, entries: usize, value: bool) { + assert!((start + entries) <= Self::CAPACITY); + let mut offset = start % T::CAPACITY; + let mut remain = entries; + for index in (start / T::CAPACITY)..16 { + let child_size = if remain > (T::CAPACITY - offset) { + T::CAPACITY - offset + } else { + remain + }; + remain -= child_size; + + self.child[index].set(offset, child_size, value); + if self.child[index].empty() { + self.bits &= !(1 << index); + } else { + self.bits |= 1 << index; + } + if remain == 0 { + break; + } + // Only the first loop iteration uses a non-zero offset + offset = 0; + } + } + + fn next_free(&self, start: usize) -> Option { + assert!(start < Self::CAPACITY); + let mut offset = start % T::CAPACITY; + for index in (start / T::CAPACITY)..16 { + if let Some(next_offset) = self.child[index].next_free(offset) { + return Some(next_offset + (index * T::CAPACITY)); + } + // Only the first loop iteration uses a non-zero offset + offset = 0; + } + None + } + + fn get(&self, offset: usize) -> bool { + assert!(offset < Self::CAPACITY); + let index = offset / T::CAPACITY; + self.child[index].get(offset % T::CAPACITY) + } + + fn empty(&self) -> bool { + self.bits == 0 + } + + fn capacity(&self) -> usize { + Self::CAPACITY + } + + fn used(&self) -> usize { + self.child.iter().map(|c| c.used()).sum() + } +} + +fn alloc_aligned(ba: &mut impl BitmapAllocator, entries: usize, align: usize) -> Option { + // Iterate through the bitmap checking on each alignment boundary + // for a free range of the requested size + if align >= (ba.capacity().ilog2() as usize) { + return None; + } + let align_mask = (1 << align) - 1; + let mut offset = 0; + while (offset + entries) <= ba.capacity() { + if let Some(offset_free) = ba.next_free(offset) { + // If the next free offset does not match the current aligned + // offset then move forward to the next aligned offset + if offset_free != offset { + offset = ((offset_free - 1) & !align_mask) + (1 << align); + continue; + } + // The aligned offset is free. 
Keep checking the next bit until we + // reach the requested size + assert!((offset + entries) <= ba.capacity()); + let mut free_entries = 0; + for size_check in offset..(offset + entries) { + if !ba.get(size_check) { + free_entries += 1; + } else { + break; + } + } + if free_entries == entries { + // Mark the range as in-use + ba.set(offset, entries, true); + return Some(offset); + } + } + offset += 1 << align; + } + None +} + +// +// Tests +// + +#[cfg(test)] +mod tests { + use super::{BitmapAllocator, BitmapAllocator64}; + + use super::BitmapAllocatorTree; + + #[test] + fn test_set_single() { + let mut b = BitmapAllocator64 { bits: 0 }; + b.set(0, 1, true); + assert_eq!(b.bits, 0x0000000000000001); + b.set(8, 1, true); + assert_eq!(b.bits, 0x0000000000000101); + b.set(63, 1, true); + assert_eq!(b.bits, 0x8000000000000101); + assert_eq!(b.used(), 3); + } + + #[test] + fn test_clear_single() { + let mut b = BitmapAllocator64 { bits: u64::MAX }; + b.set(0, 1, false); + assert_eq!(b.bits, 0xfffffffffffffffe); + b.set(8, 1, false); + assert_eq!(b.bits, 0xfffffffffffffefe); + b.set(63, 1, false); + assert_eq!(b.bits, 0x7ffffffffffffefe); + assert_eq!(b.used(), 64 - 3); + } + + #[test] + fn test_set_range() { + let mut b = BitmapAllocator64 { bits: 0 }; + b.set(0, 9, true); + assert_eq!(b.bits, 0x00000000000001ff); + b.set(11, 4, true); + assert_eq!(b.bits, 0x00000000000079ff); + b.set(61, 3, true); + assert_eq!(b.bits, 0xe0000000000079ff); + assert_eq!(b.used(), 16); + } + + #[test] + fn test_clear_range() { + let mut b = BitmapAllocator64 { bits: u64::MAX }; + b.set(0, 9, false); + assert_eq!(b.bits, !0x00000000000001ff); + b.set(11, 4, false); + assert_eq!(b.bits, !0x00000000000079ff); + b.set(61, 3, false); + assert_eq!(b.bits, !0xe0000000000079ff); + assert_eq!(b.used(), 64 - 16); + } + + #[test] + #[should_panic] + fn test_exceed_range() { + let mut b = BitmapAllocator64 { bits: 0 }; + b.set(0, 65, true); + } + + #[test] + #[should_panic] + fn test_exceed_start() { + let mut b = BitmapAllocator64 { bits: 0 }; + b.set(64, 1, true); + } + + #[test] + fn test_next_free() { + let mut b = BitmapAllocator64 { + bits: !0x8000000000000101, + }; + assert_eq!(b.next_free(0), Some(0)); + assert_eq!(b.next_free(1), Some(8)); + assert_eq!(b.next_free(9), Some(63)); + b.set(63, 1, true); + assert_eq!(b.next_free(9), None); + } + + #[test] + fn alloc_simple() { + let mut b = BitmapAllocator64 { bits: 0 }; + assert_eq!(b.alloc(1, 0), Some(0)); + assert_eq!(b.alloc(1, 0), Some(1)); + assert_eq!(b.alloc(1, 0), Some(2)); + } + + #[test] + fn alloc_aligned() { + let mut b = BitmapAllocator64 { bits: 0 }; + // Alignment of 1 << 4 bits : 16 bit alignment + assert_eq!(b.alloc(1, 4), Some(0)); + assert_eq!(b.alloc(1, 4), Some(16)); + assert_eq!(b.alloc(1, 4), Some(32)); + } + + #[test] + fn alloc_large_aligned() { + let mut b = BitmapAllocator64 { bits: 0 }; + // Alignment of 1 << 4 bits : 16 bit alignment + assert_eq!(b.alloc(17, 4), Some(0)); + assert_eq!(b.alloc(1, 4), Some(32)); + } + + #[test] + fn alloc_out_of_space() { + let mut b = BitmapAllocator64 { bits: 0 }; + // Alignment of 1 << 4 bits : 16 bit alignment + assert_eq!(b.alloc(50, 4), Some(0)); + assert_eq!(b.alloc(1, 4), None); + } + + #[test] + fn free_space() { + let mut b = BitmapAllocator64 { bits: 0 }; + // Alignment of 1 << 4 bits : 16 bit alignment + assert_eq!(b.alloc(50, 4), Some(0)); + assert_eq!(b.alloc(1, 4), None); + b.free(0, 50); + assert_eq!(b.alloc(1, 4), Some(0)); + } + + #[test] + fn free_multiple() { + let mut b = 
BitmapAllocator64 { bits: u64::MAX }; + b.free(0, 16); + b.free(23, 16); + b.free(41, 16); + assert_eq!(b.alloc(16, 0), Some(0)); + assert_eq!(b.alloc(16, 0), Some(23)); + assert_eq!(b.alloc(16, 0), Some(41)); + assert_eq!(b.alloc(16, 0), None); + } + + #[test] + fn tree_set_all() { + let mut b = BitmapAllocatorTree::::new(); + b.set(0, 64 * 16, false); + for i in 0..16 { + assert_eq!(b.child[i].bits, 0); + } + assert_eq!(b.bits, 0); + assert_eq!(b.used(), 0); + } + + #[test] + fn tree_clear_all() { + let mut b = BitmapAllocatorTree::::new(); + b.set(0, 64 * 16, true); + for i in 0..16 { + assert_eq!(b.child[i].bits, u64::MAX); + } + assert_eq!(b.bits, u16::MAX); + assert_eq!(b.used(), 1024); + } + + #[test] + fn tree_set_some() { + let mut b = BitmapAllocatorTree::::new(); + + // First child + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + b.set(11, 17, true); + for i in 0..16 { + if i == 0 { + assert_eq!(b.child[i].bits, 0x000000000ffff800); + } else { + assert_eq!(b.child[i].bits, 0); + } + } + assert_eq!(b.bits, 0x0001); + assert_eq!(b.used(), 17); + + // Last child + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + b.set((15 * 64) + 11, 17, true); + for i in 0..16 { + if i == 15 { + assert_eq!(b.child[i].bits, 0x000000000ffff800); + } else { + assert_eq!(b.child[i].bits, 0); + } + } + assert_eq!(b.bits, 0x8000); + assert_eq!(b.used(), 17); + + // Traverse child boundary + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + b.set(50, 28, true); + for i in 0..16 { + if i == 0 { + assert_eq!(b.child[i].bits, 0xfffc000000000000); + } else if i == 1 { + assert_eq!(b.child[i].bits, 0x0000000000003fff); + } else { + assert_eq!(b.child[i].bits, 0); + } + } + assert_eq!(b.bits, 0x0003); + assert_eq!(b.used(), 28); + } + + #[test] + fn tree_alloc_simple() { + let mut b = BitmapAllocatorTree::::new(); + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + for i in 0..256 { + assert_eq!(b.alloc(1, 0), Some(i)); + } + assert_eq!(b.used(), 256); + } + + #[test] + fn tree_alloc_aligned() { + let mut b = BitmapAllocatorTree::::new(); + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + // Alignment of 1 << 5 bits : 32 bit alignment + assert_eq!(b.alloc(1, 5), Some(0)); + assert_eq!(b.alloc(1, 5), Some(32)); + assert_eq!(b.alloc(1, 5), Some(64)); + assert_eq!(b.alloc(1, 5), Some(96)); + assert_eq!(b.alloc(1, 5), Some(128)); + assert_eq!(b.alloc(1, 0), Some(1)); + assert_eq!(b.used(), 6); + } + + #[test] + fn tree_alloc_large_aligned() { + let mut b = BitmapAllocatorTree::::new(); + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + // Alignment of 1 << 4 bits : 16 bit alignment + assert_eq!(b.alloc(500, 4), Some(0)); + assert_eq!(b.alloc(400, 4), Some(512)); + assert_eq!(b.used(), 900); + } + + #[test] + fn tree_alloc_out_of_space() { + let mut b = BitmapAllocatorTree::::new(); + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + // Alignment of 1 << 4 bits : 16 bit alignment + assert_eq!(b.alloc(1000, 4), Some(0)); + assert_eq!( + b.alloc(BitmapAllocatorTree::::CAPACITY - 100, 4), + None + ); + assert_eq!(b.used(), 1000); + } + + #[test] + fn tree_free_space() { + let mut b = BitmapAllocatorTree::::new(); + b.set(0, BitmapAllocatorTree::::CAPACITY, false); + // Alignment of 1 << 4 bits : 16 bit alignment + assert_eq!( + b.alloc(BitmapAllocatorTree::::CAPACITY - 10, 4), + Some(0) + ); + assert_eq!(b.alloc(1, 4), None); + b.free(0, 50); + assert_eq!(b.alloc(1, 4), Some(0)); + assert_eq!(b.used(), 965); + } + + #[test] + fn tree_free_multiple() { + let mut b = BitmapAllocatorTree::::new(); + 
b.set(0, BitmapAllocatorTree::::CAPACITY, true); + b.free(0, 16); + b.free(765, 16); + b.free(897, 16); + assert_eq!(b.alloc(16, 0), Some(0)); + assert_eq!(b.alloc(16, 0), Some(765)); + assert_eq!(b.alloc(16, 0), Some(897)); + assert_eq!(b.alloc(16, 0), None); + assert_eq!(b.used(), 1024); + } +} diff --git a/stage2/src/utils/immut_after_init.rs b/stage2/src/utils/immut_after_init.rs new file mode 100644 index 000000000..e9467dbd1 --- /dev/null +++ b/stage2/src/utils/immut_after_init.rs @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Nicolai Stange + +use core::cell::UnsafeCell; +use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::ops::Deref; +#[cfg(debug_assertions)] +use core::sync::atomic::{AtomicBool, Ordering}; + +#[cfg(not(debug_assertions))] +pub type ImmutAfterInitResult = Result; + +#[cfg(debug_assertions)] +pub type ImmutAfterInitResult = Result; + +#[cfg(debug_assertions)] +static MULTI_THREADED: AtomicBool = AtomicBool::new(false); + +#[cfg(debug_assertions)] +#[derive(Clone, Copy, Debug)] +pub enum ImmutAfterInitError { + AlreadyInit, + Uninitialized, + NotSingleThreaded, +} + +/// A memory location which is effectively immutable after initalization code +/// has run. +/// +/// The use of global variables initialized once from code requires either +/// making them `static mut`, which is (on the way of getting) deprecated +/// (c.f. [Consider deprecation of UB-happy `static +/// mut`](https://github.com/rust-lang/rust/issues/53639)), or wrapping them in +/// one of the [`core::cell`] types. However, those either would require to mark +/// each access as `unsafe{}`, including loads from the value, or incur some +/// runtime and storage overhead. +/// +/// Using `ImmutAfterInitCell` as an alternative makes the intended usage +/// pattern more verbatim and limits the `unsafe{}` regions to the +/// initialization code. The main purpose is to facilitate code review: it must +/// get verified that the value gets initialized only once and before first +/// potential use. +/// +/// # Examples +/// A `ImmutAfterInitCell` may start out in unitialized state and can get +/// initialized at runtime: +/// ``` +/// # use svsm::utils::immut_after_init::ImmutAfterInitCell; +/// static X: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); +/// pub fn main() { +/// unsafe { X.init(&123) }; +/// assert_eq!(*X, 123); +/// } +/// ``` +/// +/// Also, to support early/late initialization scenarios, a +/// `ImmutAfterInitCell`'s value may get reset after having been initialized +/// already: +/// ``` +/// # use svsm::utils::immut_after_init::ImmutAfterInitCell; +/// static X: ImmutAfterInitCell = ImmutAfterInitCell::new(0); +/// pub fn main() { +/// assert_eq!(*X, 0); +/// unsafe { X.reinit(&123) }; +/// assert_eq!(*X, 123); +/// } +/// ``` +#[derive(Debug)] +pub struct ImmutAfterInitCell { + #[doc(hidden)] + data: UnsafeCell>, + // Used to keep track of the initialization state. Even though this + // is atomic, the data structure does not guarantee thread safety. + #[cfg(debug_assertions)] + init: AtomicBool, +} + +impl ImmutAfterInitCell { + /// Create an unitialized `ImmutAfterInitCell` instance. The value must get + /// initialized by means of [`Self::init()`] before first usage. 
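+    ///
+    /// Note that the initialization state is only tracked in builds with
+    /// `debug_assertions`; release builds skip the corresponding checks.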
+ pub const fn uninit() -> Self { + Self { + data: UnsafeCell::new(MaybeUninit::uninit()), + #[cfg(debug_assertions)] + init: AtomicBool::new(false), + } + } + + fn set_init(&self) { + #[cfg(debug_assertions)] + self.init.store(true, Ordering::Relaxed); + } + + fn check_init(&self) -> ImmutAfterInitResult<()> { + #[cfg(debug_assertions)] + if !self.init.load(Ordering::Relaxed) { + return Err(ImmutAfterInitError::Uninitialized); + } + Ok(()) + } + + fn check_uninit(&self) -> ImmutAfterInitResult<()> { + #[cfg(debug_assertions)] + if self.init.load(Ordering::Relaxed) { + return Err(ImmutAfterInitError::AlreadyInit); + } + Ok(()) + } + + fn check_single_threaded(&self) -> ImmutAfterInitResult<()> { + #[cfg(debug_assertions)] + if MULTI_THREADED.load(Ordering::Relaxed) { + return Err(ImmutAfterInitError::NotSingleThreaded); + } + Ok(()) + } + + // The caller must check the initialization status to avoid double init bugs + unsafe fn set_inner(&self, v: &T) { + self.set_init(); + (*self.data.get()) + .as_mut_ptr() + .copy_from_nonoverlapping(v, 1) + } + + // The caller must ensure that the cell is initialized + unsafe fn get_inner(&self) -> &T { + (*self.data.get()).assume_init_ref() + } + + fn try_get_inner(&self) -> ImmutAfterInitResult<&T> { + self.check_init()?; + unsafe { Ok(self.get_inner()) } + } + + /// Initialize an uninitialized `ImmutAfterInitCell` instance from a value. + /// + /// Must **not** get called on an already initialized instance! + /// + /// * `v` - Initialization value. + pub fn init(&self, v: &T) -> ImmutAfterInitResult<()> { + self.check_uninit()?; + self.check_single_threaded()?; + unsafe { self.set_inner(v) }; + Ok(()) + } + + /// Reinitialize an initialized `ImmutAfterInitCell` instance from a value. + /// + /// Must **not** get called while any borrow via [`Self::deref()`] or + /// [`ImmutAfterInitRef::deref()`] is alive! + /// + /// * `v` - Initialization value. + pub fn reinit(&self, v: &T) -> ImmutAfterInitResult<()> { + self.check_single_threaded()?; + unsafe { self.set_inner(v) } + Ok(()) + } + + /// Create an initialized `ImmutAfterInitCell` instance from a value. + /// + /// * `v` - Initialization value. + pub const fn new(v: T) -> Self { + Self { + data: UnsafeCell::new(MaybeUninit::new(v)), + #[cfg(debug_assertions)] + init: AtomicBool::new(true), + } + } +} + +impl Deref for ImmutAfterInitCell { + type Target = T; + + /// Dereference the wrapped value. Must **only ever** get called on an + /// initialized instance! + fn deref(&self) -> &T { + self.try_get_inner().unwrap() + } +} + +unsafe impl Send for ImmutAfterInitCell {} +unsafe impl Sync for ImmutAfterInitCell {} + +/// A reference to a memory location which is effectively immutable after +/// initalization code has run. 
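As defined here, `init()` and `reinit()` are safe methods returning an `ImmutAfterInitResult`, and with `debug_assertions` enabled a second `init()` is reported as `AlreadyInit` instead of silently overwriting the stored value. A minimal illustrative sketch of that flow (not part of the diff; assumes a debug build so the state tracking is active):

```rust
use svsm::utils::immut_after_init::ImmutAfterInitCell;

static CONFIG: ImmutAfterInitCell<u32> = ImmutAfterInitCell::uninit();

fn early_init() {
    // First initialization succeeds while the system is still single-threaded.
    CONFIG.init(&42).expect("CONFIG already initialized");
    assert_eq!(*CONFIG, 42);

    // With debug assertions enabled, a second init() is rejected rather than
    // silently overwriting the stored value.
    assert!(CONFIG.init(&7).is_err());
}
```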
+ +/// A `ImmutAfterInitRef` can either get initialized statically at link time or +/// once from initialization code, basically following the protocol of a +/// [`ImmutAfterInitCell`] itself: +/// +/// # Examples +/// A `ImmutAfterInitRef` can be initialized to either point to a +/// `ImmutAfterInitCell`'s contents, +/// ``` +/// # use svsm::utils::immut_after_init::{ImmutAfterInitCell, ImmutAfterInitRef}; +/// static X: ImmutAfterInitCell = ImmutAfterInitCell::uninit(); +/// static RX: ImmutAfterInitRef<'_, i32> = ImmutAfterInitRef::uninit(); +/// fn main() { +/// unsafe { X.init(&123) }; +/// unsafe { RX.init_from_cell(&X) }; +/// assert_eq!(*RX, 123); +/// } +/// ``` +/// or to plain value directly: +/// ``` +/// # use svsm::utils::immut_after_init::ImmutAfterInitRef; +/// static X: i32 = 123; +/// static RX: ImmutAfterInitRef<'_, i32> = ImmutAfterInitRef::uninit(); +/// fn main() { +/// unsafe { RX.init_from_ref(&X) }; +/// assert_eq!(*RX, 123); +/// } +/// ``` +/// +/// Also, an `ImmutAfterInitRef` can be initialized by obtaining a reference +/// from another `ImmutAfterInitRef`: +/// ``` +/// # use svsm::utils::immut_after_init::ImmutAfterInitRef; +/// static RX : ImmutAfterInitRef::<'static, i32> = ImmutAfterInitRef::uninit(); +// +/// fn init_rx(r : ImmutAfterInitRef<'static, i32>) { +/// unsafe { RX.init_from_ref(r.get()) }; +/// } +/// +/// static X : i32 = 123; +/// +/// fn main() { +/// init_rx(ImmutAfterInitRef::new_from_ref(&X)); +/// assert_eq!(*RX, 123); +/// } +/// ``` +/// +#[derive(Debug)] +pub struct ImmutAfterInitRef<'a, T: Copy> { + #[doc(hidden)] + ptr: ImmutAfterInitCell<*const T>, + #[doc(hidden)] + _phantom: PhantomData<&'a &'a T>, +} + +impl<'a, T: Copy> ImmutAfterInitRef<'a, T> { + /// Create an unitialized `ImmutAfterInitRef` instance. The reference itself + /// must get initialized via either of [`Self::init_from_ref()`] or + /// [`Self::init_from_cell()`] before first dereferencing it. + pub const fn uninit() -> Self { + ImmutAfterInitRef { + ptr: ImmutAfterInitCell::uninit(), + _phantom: PhantomData, + } + } + + /// Initialize an uninitialized `ImmutAfterInitRef` instance to point to value + /// specified by a regular reference. + /// + /// Must **not** get called on an already initialized `ImmutAfterInitRef` + /// instance! + /// + /// * `r` - Reference to the value to make the `ImmutAfterInitRef` to refer + /// to. By convention, the referenced value must have been + /// initialized already. + pub fn init_from_ref<'b>(&self, r: &'b T) -> ImmutAfterInitResult<()> + where + 'b: 'a, + { + self.ptr.init(&(r as *const T)) + } + + /// Create an initialized `ImmutAfterInitRef` instance pointing to a value + /// specified by a regular reference. + /// + /// * `r` - Reference to the value to make the `ImmutAfterInitRef` to refer + /// to. By convention, the referenced value must have been + /// initialized already. + pub const fn new_from_ref(r: &'a T) -> Self { + Self { + ptr: ImmutAfterInitCell::new(r as *const T), + _phantom: PhantomData, + } + } + + /// Dereference the referenced value with lifetime propagation. Must **only + /// ever** get called on an initialized `ImmutAfterInitRef` instance! Moreover, + /// the value referenced must have been initialized as well. + pub fn get(&self) -> &'a T { + unsafe { &**self.ptr } + } +} + +impl<'a, T: Copy> ImmutAfterInitRef<'a, T> { + /// Initialize an uninitialized `ImmutAfterInitRef` instance to point to + /// value wrapped in a [`ImmutAfterInitCell`]. 
+ /// + /// Must **not** get called on an already initialized `ImmutAfterInitRef` instance! + /// + /// * `cell` - The value to make the `ImmutAfterInitRef` to refer to. By + /// convention, the referenced value must have been initialized + /// already. + pub fn init_from_cell<'b>(&self, cell: &'b ImmutAfterInitCell) -> ImmutAfterInitResult<()> + where + 'b: 'a, + { + self.ptr.init(&(cell.try_get_inner()? as *const T)) + } +} + +impl Deref for ImmutAfterInitRef<'_, T> { + type Target = T; + + /// Dereference the referenced value *without* lifetime propagation. Must + /// **only ever** get called on an initialized `ImmutAfterInitRef` instance! + /// Moreover, the value referenced must have been initialized as well. If + /// lifetime propagation is needed, use [`ImmutAfterInitRef::get()`]. + fn deref(&self) -> &T { + self.get() + } +} + +unsafe impl Send for ImmutAfterInitRef<'_, T> {} +unsafe impl Sync for ImmutAfterInitRef<'_, T> {} + +pub fn immut_after_init_set_multithreaded() { + #[cfg(debug_assertions)] + MULTI_THREADED.store(true, Ordering::Relaxed); +} diff --git a/stage2/src/utils/memory_region.rs b/stage2/src/utils/memory_region.rs new file mode 100644 index 000000000..305380a2a --- /dev/null +++ b/stage2/src/utils/memory_region.rs @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Carlos López + +use crate::address::Address; +use crate::types::PageSize; + +/// An abstraction over a memory region, expressed in terms of physical +/// ([`PhysAddr`](crate::address::PhysAddr)) or virtual +/// ([`VirtAddr`](crate::address::VirtAddr)) addresses. +#[derive(Clone, Copy, Debug)] +pub struct MemoryRegion { + start: A, + end: A, +} + +impl MemoryRegion +where + A: Address, +{ + /// Create a new memory region starting at address `start`, spanning `len` + /// bytes. + pub fn new(start: A, len: usize) -> Self { + let end = A::from(start.bits() + len); + Self { start, end } + } + + /// Create a new memory region with overflow checks. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let start = VirtAddr::from(u64::MAX); + /// let region = MemoryRegion::checked_new(start, PAGE_SIZE); + /// assert!(region.is_none()); + /// ``` + pub fn checked_new(start: A, len: usize) -> Option { + let end = start.checked_add(len)?; + Some(Self { start, end }) + } + + /// Create a memory region from two raw addresses. + pub const fn from_addresses(start: A, end: A) -> Self { + Self { start, end } + } + + /// The base address of the memory region, originally set in + /// [`MemoryRegion::new()`]. + #[inline] + pub const fn start(&self) -> A { + self.start + } + + /// The length of the memory region in bytes, originally set in + /// [`MemoryRegion::new()`]. + #[inline] + pub fn len(&self) -> usize { + self.end.bits().saturating_sub(self.start.bits()) + } + + /// Returns whether the region spans any actual memory. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::utils::MemoryRegion; + /// let r = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), 0); + /// assert!(r.is_empty()); + /// ``` + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The end address of the memory region. 
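`from_addresses()` is the one constructor above without a doctest; an illustrative sketch (the addresses are arbitrary example values) of how it relates to `new()`:

```rust
use svsm::address::VirtAddr;
use svsm::types::PAGE_SIZE;
use svsm::utils::MemoryRegion;

fn equivalent_regions() {
    let start = VirtAddr::from(0xffffff0000u64);
    let end = VirtAddr::from(0xffffff1000u64);

    // Building the same one-page region from a length and from raw bounds.
    let by_len = MemoryRegion::new(start, PAGE_SIZE);
    let by_bounds = MemoryRegion::from_addresses(start, end);

    assert_eq!(by_len.start(), by_bounds.start());
    assert_eq!(by_len.end(), by_bounds.end());
    assert_eq!(by_bounds.len(), PAGE_SIZE);
}
```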
+ /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let base = VirtAddr::from(0xffffff0000u64); + /// let region = MemoryRegion::new(base, PAGE_SIZE); + /// assert_eq!(region.end(), VirtAddr::from(0xffffff1000u64)); + /// ``` + #[inline] + pub const fn end(&self) -> A { + self.end + } + + /// Checks whether two regions overlap. This does *not* include contiguous + /// regions, use [`MemoryRegion::contiguous()`] for that purpose. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff2000u64), PAGE_SIZE); + /// assert!(!r1.overlap(&r2)); + /// ``` + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE * 2); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// assert!(r1.overlap(&r2)); + /// ``` + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// // Contiguous regions do not overlap + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// assert!(!r1.overlap(&r2)); + /// ``` + pub fn overlap(&self, other: &Self) -> bool { + self.start() < other.end() && self.end() > other.start() + } + + /// Checks whether two regions are contiguous or overlapping. This is a + /// less strict check than [`MemoryRegion::overlap()`]. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// assert!(r1.contiguous(&r2)); + /// ``` + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff2000u64), PAGE_SIZE); + /// assert!(!r1.contiguous(&r2)); + /// ``` + pub fn contiguous(&self, other: &Self) -> bool { + self.start() <= other.end() && self.end() >= other.start() + } + + /// Merge two regions. It does not check whether the two regions are + /// contiguous in the first place, so the resulting region will cover + /// any non-overlapping memory between both. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let r1 = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// let r2 = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// let r3 = r1.merge(&r2); + /// assert_eq!(r3.start(), r1.start()); + /// assert_eq!(r3.len(), r1.len() + r2.len()); + /// assert_eq!(r3.end(), r2.end()); + /// ``` + pub fn merge(&self, other: &Self) -> Self { + let start = self.start.min(other.start); + let end = self.end().max(other.end()); + Self { start, end } + } + + /// Iterate over the addresses covering the memory region in jumps of the + /// specified page size. 
Note that if the base address of the region is not + /// page aligned, returned addresses will not be aligned either. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::{PAGE_SIZE, PageSize}; + /// # use svsm::utils::MemoryRegion; + /// let region = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE * 2); + /// let mut iter = region.iter_pages(PageSize::Regular); + /// assert_eq!(iter.next(), Some(VirtAddr::from(0xffffff0000u64))); + /// assert_eq!(iter.next(), Some(VirtAddr::from(0xffffff1000u64))); + /// assert_eq!(iter.next(), None); + /// ``` + pub fn iter_pages(&self, size: PageSize) -> impl Iterator { + let size = usize::from(size); + (self.start().bits()..self.end().bits()) + .step_by(size) + .map(A::from) + } + + /// Check whether an address is within this region. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::{PAGE_SIZE, PageSize}; + /// # use svsm::utils::MemoryRegion; + /// let region = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// assert!(region.contains(VirtAddr::from(0xffffff0000u64))); + /// assert!(region.contains(VirtAddr::from(0xffffff0fffu64))); + /// assert!(!region.contains(VirtAddr::from(0xffffff1000u64))); + /// ``` + pub fn contains(&self, addr: A) -> bool { + self.start() <= addr && addr < self.end() + } + + /// Check whether an address is within this region, treating `end` as part + /// of the region. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::types::{PAGE_SIZE, PageSize}; + /// # use svsm::utils::MemoryRegion; + /// let region = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE); + /// assert!(region.contains_inclusive(VirtAddr::from(0xffffff0000u64))); + /// assert!(region.contains_inclusive(VirtAddr::from(0xffffff0fffu64))); + /// assert!(region.contains_inclusive(VirtAddr::from(0xffffff1000u64))); + /// assert!(!region.contains_inclusive(VirtAddr::from(0xffffff1001u64))); + /// ``` + pub fn contains_inclusive(&self, addr: A) -> bool { + (self.start()..=self.end()).contains(&addr) + } + + /// Check whether this region fully contains a different region. + /// + /// ```rust + /// # use svsm::address::VirtAddr; + /// # use svsm::utils::MemoryRegion; + /// # use svsm::types::PAGE_SIZE; + /// let big = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE * 2); + /// let small = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// let overlapping = MemoryRegion::new(VirtAddr::from(0xffffff0000u64), PAGE_SIZE * 2); + /// assert!(big.contains_region(&small)); + /// assert!(!small.contains_region(&big)); + /// assert!(!overlapping.contains_region(&big)); + /// assert!(!big.contains_region(&overlapping)); + /// ``` + pub fn contains_region(&self, other: &Self) -> bool { + self.start() <= other.start() && other.end() <= self.end() + } + + /// Returns a new memory region with the specified added length at the end. 
+ /// + /// ``` + /// # use svsm::address::VirtAddr; + /// # use svsm::types::PAGE_SIZE; + /// # use svsm::utils::MemoryRegion; + /// let region = MemoryRegion::new(VirtAddr::from(0xffffff1000u64), PAGE_SIZE); + /// let bigger = region.expand(PAGE_SIZE); + /// assert_eq!(bigger.len(), PAGE_SIZE * 2); + /// ``` + pub fn expand(&self, len: usize) -> Self { + Self::new(self.start(), self.len() + len) + } +} diff --git a/stage2/src/utils/mod.rs b/stage2/src/utils/mod.rs new file mode 100644 index 000000000..e2cca7290 --- /dev/null +++ b/stage2/src/utils/mod.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +pub mod bitmap_allocator; +pub mod immut_after_init; +pub mod memory_region; +pub mod util; + +pub use memory_region::MemoryRegion; +pub use util::{ + align_down, align_up, halt, is_aligned, overlap, page_align_up, page_offset, zero_mem_region, +}; diff --git a/stage2/src/utils/util.rs b/stage2/src/utils/util.rs new file mode 100644 index 000000000..45aa5c827 --- /dev/null +++ b/stage2/src/utils/util.rs @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +// +// Copyright (c) 2022-2023 SUSE LLC +// +// Author: Joerg Roedel + +use crate::address::{Address, VirtAddr}; +use crate::types::PAGE_SIZE; +use core::arch::asm; +use core::ops::{Add, BitAnd, Not, Sub}; + +pub fn align_up(addr: T, align: T) -> T +where + T: Add + Sub + BitAnd + Not + From + Copy, +{ + let mask: T = align - T::from(1u8); + (addr + mask) & !mask +} + +pub fn align_down(addr: T, align: T) -> T +where + T: Sub + Not + BitAnd + From + Copy, +{ + addr & !(align - T::from(1u8)) +} + +pub fn is_aligned(addr: T, align: T) -> bool +where + T: Sub + BitAnd + PartialEq + From, +{ + (addr & (align - T::from(1))) == T::from(0) +} + +pub fn halt() { + unsafe { + asm!("hlt", options(att_syntax)); + } +} + +pub fn page_align_up(x: usize) -> usize { + align_up(x, PAGE_SIZE) +} + +pub fn page_offset(x: usize) -> usize { + x & (PAGE_SIZE - 1) +} + +pub fn overlap(x1: T, x2: T, y1: T, y2: T) -> bool +where + T: PartialOrd, +{ + x1 <= y2 && y1 <= x2 +} + +pub fn zero_mem_region(start: VirtAddr, end: VirtAddr) { + let size = end - start; + if start.is_null() { + panic!("Attempted to zero out a NULL pointer"); + } + + // Zero region + unsafe { start.as_mut_ptr::().write_bytes(0, size) } +} + +/// Obtain bit for a given position +#[macro_export] +macro_rules! BIT { + ($x: expr) => { + (1 << ($x)) + }; +} + +/// Obtain bit mask for the given positions +#[macro_export] +macro_rules! 
BIT_MASK { + ($e: expr, $s: expr) => {{ + assert!( + $s <= 63 && $e <= 63 && $s <= $e, + "Start bit position must be less than or equal to end bit position" + ); + (((1u64 << ($e - $s + 1)) - 1) << $s) + }}; +} + +#[cfg(test)] +mod tests { + + use crate::utils::util::*; + + #[test] + fn test_mem_utils() { + // Align up + assert_eq!(align_up(7, 4), 8); + assert_eq!(align_up(15, 8), 16); + assert_eq!(align_up(10, 2), 10); + // Align down + assert_eq!(align_down(7, 4), 4); + assert_eq!(align_down(15, 8), 8); + assert_eq!(align_down(10, 2), 10); + // Page align up + assert_eq!(page_align_up(4096), 4096); + assert_eq!(page_align_up(4097), 8192); + assert_eq!(page_align_up(0), 0); + // Page offset + assert_eq!(page_offset(4096), 0); + assert_eq!(page_offset(4097), 1); + assert_eq!(page_offset(0), 0); + // Overlaps + assert!(overlap(1, 5, 3, 6)); + assert!(overlap(0, 10, 5, 15)); + assert!(!overlap(1, 5, 6, 8)); + } + + #[test] + fn test_zero_mem_region() { + let mut data: [u8; 10] = [1; 10]; + let start = VirtAddr::from(data.as_mut_ptr()); + let end = start + core::mem::size_of_val(&data); + + zero_mem_region(start, end); + + for byte in &data { + assert_eq!(*byte, 0); + } + } +} diff --git a/stage2/src/vtpm/mod.rs b/stage2/src/vtpm/mod.rs new file mode 100644 index 000000000..1391bb211 --- /dev/null +++ b/stage2/src/vtpm/mod.rs @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! This crate defines the Virtual TPM interfaces and shows what +//! TPM backends are supported + +/// TPM 2.0 Reference Implementation by Microsoft +pub mod mstpm; + +use crate::vtpm::mstpm::MsTpm as Vtpm; +use crate::{locking::LockGuard, protocols::vtpm::TpmPlatformCommand}; +use crate::{locking::SpinLock, protocols::errors::SvsmReqError}; + +/// Basic services required to perform the VTPM Protocol +pub trait VtpmProtocolInterface { + /// Get the list of Platform Commands supported by the TPM implementation. + fn get_supported_commands(&self) -> &[TpmPlatformCommand]; +} + +/// This implements one handler for each [`TpmPlatformCommand`] supported by the +/// VTPM Protocol. These handlers are based on the TPM Simulator +/// interface (by Microsoft), but with a few changes to make it more Rust +/// idiomatic. +/// +/// `ms-tpm-20-ref/TPMCmd/Simulator/include/prototypes/Simulator_fp.h` +pub trait MsTpmSimulatorInterface: VtpmProtocolInterface { + /// Send a command for the TPM to run in a given locality + /// + /// # Arguments + /// + /// * `buffer`: Buffer with the command to be sent to the TPM. It has to be large enough + /// to hold the response received from the TPM. + /// * `length`: The length of the command stored in `buffer`. It will be updated with the + /// size of the TPM response received from the TPM. + /// * `locality`: TPM locality the TPM command will be executed + fn send_tpm_command( + &self, + buffer: &mut [u8], + length: &mut usize, + locality: u8, + ) -> Result<(), SvsmReqError>; + + /// Power-on the TPM, which also triggers a reset + /// + /// # Arguments + /// + /// *`only_reset``: If enabled, it will only reset the vTPM; + /// however, the vtPM has to be powered on previously. + /// Otherwise, it will fail. + fn signal_poweron(&mut self, only_reset: bool) -> Result<(), SvsmReqError>; + + /// In a system where the NV memory used by the TPM is not within the TPM, + /// the NV may not always be available. This function indicates that NV + /// is available. 
+ fn signal_nvon(&self) -> Result<(), SvsmReqError>; +} + +/// Basic TPM driver services +pub trait VtpmInterface: MsTpmSimulatorInterface { + /// Check if the TPM is powered on. + fn is_powered_on(&self) -> bool; + + /// Prepare the TPM to be used for the first time. At this stage, + /// the TPM is manufactured. + fn init(&mut self) -> Result<(), SvsmReqError>; +} + +static VTPM: SpinLock = SpinLock::new(Vtpm::new()); + +/// Initialize the TPM by calling the init() implementation of the +/// [`VtpmInterface`] +pub fn vtpm_init() -> Result<(), SvsmReqError> { + let mut vtpm = VTPM.lock(); + if vtpm.is_powered_on() { + return Ok(()); + } + vtpm.init()?; + Ok(()) +} + +pub fn vtpm_get_locked<'a>() -> LockGuard<'a, Vtpm> { + VTPM.lock() +} diff --git a/stage2/src/vtpm/mstpm/mod.rs b/stage2/src/vtpm/mstpm/mod.rs new file mode 100644 index 000000000..83c0f3ce0 --- /dev/null +++ b/stage2/src/vtpm/mstpm/mod.rs @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! This crate implements the virtual TPM interfaces for the TPM 2.0 +//! Reference Implementation (by Microsoft) + +/// Functions required to build the Microsoft TPM libraries +#[cfg(not(any(test, fuzzing)))] +mod wrapper; + +extern crate alloc; + +use alloc::vec::Vec; +use core::ffi::c_void; +use libmstpm::bindings::{ + TPM_Manufacture, TPM_TearDown, _plat__LocalitySet, _plat__NVDisable, _plat__NVEnable, + _plat__RunCommand, _plat__SetNvAvail, _plat__Signal_PowerOn, _plat__Signal_Reset, +}; + +use crate::{ + address::VirtAddr, + protocols::{errors::SvsmReqError, vtpm::TpmPlatformCommand}, + types::PAGE_SIZE, + vtpm::{MsTpmSimulatorInterface, VtpmInterface, VtpmProtocolInterface}, +}; + +#[derive(Debug, Copy, Clone, Default)] +pub struct MsTpm { + is_powered_on: bool, +} + +impl MsTpm { + pub const fn new() -> MsTpm { + MsTpm { + is_powered_on: false, + } + } + + fn teardown(&self) -> Result<(), SvsmReqError> { + let result = unsafe { TPM_TearDown() }; + match result { + 0 => Ok(()), + rc => { + log::error!("TPM_Teardown failed rc={}", rc); + Err(SvsmReqError::incomplete()) + } + } + } + + fn manufacture(&self, first_time: i32) -> Result { + let result = unsafe { TPM_Manufacture(first_time) }; + match result { + // TPM manufactured successfully + 0 => Ok(0), + // TPM already manufactured + 1 => Ok(1), + // TPM failed to manufacture + rc => { + log::error!("TPM_Manufacture failed rc={}", rc); + Err(SvsmReqError::incomplete()) + } + } + } +} + +const TPM_CMDS_SUPPORTED: &[TpmPlatformCommand] = &[TpmPlatformCommand::SendCommand]; + +impl VtpmProtocolInterface for MsTpm { + fn get_supported_commands(&self) -> &[TpmPlatformCommand] { + TPM_CMDS_SUPPORTED + } +} + +pub const TPM_BUFFER_MAX_SIZE: usize = PAGE_SIZE; + +impl MsTpmSimulatorInterface for MsTpm { + fn send_tpm_command( + &self, + buffer: &mut [u8], + length: &mut usize, + locality: u8, + ) -> Result<(), SvsmReqError> { + if !self.is_powered_on { + return Err(SvsmReqError::invalid_request()); + } + if *length > TPM_BUFFER_MAX_SIZE || *length > buffer.len() { + return Err(SvsmReqError::invalid_parameter()); + } + + let mut request_ffi = buffer[..*length].to_vec(); + + let mut response_ffi = Vec::::with_capacity(TPM_BUFFER_MAX_SIZE); + let mut response_ffi_p = response_ffi.as_mut_ptr(); + let mut response_ffi_size = TPM_BUFFER_MAX_SIZE as u32; + + unsafe { + _plat__LocalitySet(locality); + _plat__RunCommand( + request_ffi.len() as u32, + request_ffi.as_mut_ptr().cast::(), + &raw mut response_ffi_size, + &raw mut 
response_ffi_p, + ); + if response_ffi_size == 0 || response_ffi_size as usize > response_ffi.capacity() { + return Err(SvsmReqError::invalid_request()); + } + response_ffi.set_len(response_ffi_size as usize); + } + + buffer.fill(0); + buffer + .get_mut(..response_ffi.len()) + .ok_or_else(SvsmReqError::invalid_request)? + .copy_from_slice(response_ffi.as_slice()); + *length = response_ffi.len(); + + Ok(()) + } + + fn signal_poweron(&mut self, only_reset: bool) -> Result<(), SvsmReqError> { + if self.is_powered_on && !only_reset { + return Ok(()); + } + if only_reset && !self.is_powered_on { + return Err(SvsmReqError::invalid_request()); + } + if !only_reset { + unsafe { _plat__Signal_PowerOn() }; + } + // It calls TPM_init() within to indicate that a TPM2_Startup is required. + unsafe { _plat__Signal_Reset() }; + self.is_powered_on = true; + + Ok(()) + } + + fn signal_nvon(&self) -> Result<(), SvsmReqError> { + if !self.is_powered_on { + return Err(SvsmReqError::invalid_request()); + } + unsafe { _plat__SetNvAvail() }; + + Ok(()) + } +} + +impl VtpmInterface for MsTpm { + fn is_powered_on(&self) -> bool { + self.is_powered_on + } + + fn init(&mut self) -> Result<(), SvsmReqError> { + // Initialize the MS TPM following the same steps done in the Simulator: + // + // 1. Manufacture it for the first time + // 2. Make sure it does not fail if it is re-manufactured + // 3. Teardown to indicate it needs to be manufactured + // 4. Manufacture it for the first time + // 5. Power it on indicating it requires startup. By default, OVMF will start + // and selftest it. + + unsafe { _plat__NVEnable(VirtAddr::null().as_mut_ptr::()) }; + + let mut rc = self.manufacture(1)?; + if rc != 0 { + unsafe { _plat__NVDisable(1) }; + return Err(SvsmReqError::incomplete()); + } + + rc = self.manufacture(0)?; + if rc != 1 { + return Err(SvsmReqError::incomplete()); + } + + self.teardown()?; + rc = self.manufacture(1)?; + if rc != 0 { + return Err(SvsmReqError::incomplete()); + } + + self.signal_poweron(false)?; + self.signal_nvon()?; + + log::info!("VTPM: Microsoft TPM 2.0 initialized"); + + Ok(()) + } +} diff --git a/stage2/src/vtpm/mstpm/wrapper.rs b/stage2/src/vtpm/mstpm/wrapper.rs new file mode 100644 index 000000000..0a1baef7a --- /dev/null +++ b/stage2/src/vtpm/mstpm/wrapper.rs @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2023 IBM +// +// Authors: Claudio Carvalho + +//! Implement functions required to build the Microsoft TPM libraries. +//! All these functionalities are owned by the SVSM Rust code, +//! so we just need to create wrappers for them. 
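The central assumption behind these wrappers is that an allocation made through `malloc()` can later be traced back to its `Layout` via `layout_from_ptr()`, so `free()` and `realloc()` work without the C caller supplying a size. A purely illustrative round trip using only the functions defined below:

```rust
use core::ffi::c_void;

fn malloc_free_round_trip() {
    // 64 bytes requested by the (hypothetical) C caller.
    let p: *mut c_void = malloc(64);
    assert!(!p.is_null());

    // ... the TPM library would write into the buffer here ...

    // free() recovers the layout from the pointer and releases the memory.
    unsafe { free(p) };
}
```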
+ +use crate::{ + console::_print, + mm::alloc::{layout_from_ptr, layout_from_size}, + sev::msr_protocol::request_termination_msr, +}; + +use core::{ + alloc::Layout, + ffi::{c_char, c_int, c_ulong, c_void}, + ptr, + slice::from_raw_parts, + str::from_utf8, +}; + +extern crate alloc; +use alloc::alloc::{alloc, alloc_zeroed, dealloc, realloc as _realloc}; + +#[no_mangle] +pub extern "C" fn malloc(size: c_ulong) -> *mut c_void { + let layout: Layout = layout_from_size(size as usize); + unsafe { alloc(layout).cast() } +} + +#[no_mangle] +pub extern "C" fn calloc(items: c_ulong, size: c_ulong) -> *mut c_void { + if let Some(new_size) = items.checked_mul(size) { + let layout = layout_from_size(new_size as usize); + return unsafe { alloc_zeroed(layout).cast() }; + } + ptr::null_mut() +} + +#[no_mangle] +pub unsafe extern "C" fn realloc(p: *mut c_void, size: c_ulong) -> *mut c_void { + let ptr = p as *mut u8; + let new_size = size as usize; + if let Some(layout) = layout_from_ptr(ptr) { + return unsafe { _realloc(ptr, layout, new_size).cast() }; + } + ptr::null_mut() +} + +#[no_mangle] +pub unsafe extern "C" fn free(p: *mut c_void) { + if p.is_null() { + return; + } + let ptr = p as *mut u8; + if let Some(layout) = layout_from_ptr(ptr.cast()) { + unsafe { dealloc(ptr, layout) } + } +} + +#[no_mangle] +pub unsafe extern "C" fn serial_out(s: *const c_char, size: c_int) { + let str_slice: &[u8] = unsafe { from_raw_parts(s as *const u8, size as usize) }; + if let Ok(rust_str) = from_utf8(str_slice) { + _print(format_args!("[SVSM] {}", rust_str)); + } else { + log::error!("ERR: BUG: serial_out arg1 is not a valid utf8 string"); + } +} + +#[no_mangle] +pub extern "C" fn abort() -> ! { + request_termination_msr(); +}
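One detail worth calling out in `calloc()` is the overflow guard: the `items * size` multiplication goes through `checked_mul()`, and failure is reported as a NULL pointer, matching C `calloc` semantics. An illustrative sketch (not an actual unit test, since this module is compiled out of test builds):

```rust
use core::ffi::c_ulong;

fn calloc_overflow_sketch() {
    // items * size overflows c_ulong, so no allocation is attempted.
    assert!(calloc(c_ulong::MAX, 2).is_null());

    // A modest request succeeds and must be paired with free().
    let p = calloc(4, 8);
    assert!(!p.is_null());
    unsafe { free(p) };
}
```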