Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid zeroing large stack buffers in stdio on Windows #101193

Merged
merged 2 commits into from
Aug 31, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions library/std/src/sys/windows/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::char::decode_utf16;
use crate::cmp;
use crate::io;
use crate::mem::MaybeUninit;
use crate::os::windows::io::{FromRawHandle, IntoRawHandle};
use crate::ptr;
use crate::str;
Expand Down Expand Up @@ -169,13 +170,14 @@ fn write(
}

fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
let mut utf16 = [0u16; MAX_BUFFER_SIZE / 2];
let mut utf16: [MaybeUninit<u16>; MAX_BUFFER_SIZE / 2] = MaybeUninit::uninit_array();
thomcc marked this conversation as resolved.
Show resolved Hide resolved
let mut len_utf16 = 0;
for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) {
*dest = chr;
*dest = MaybeUninit::new(chr);
len_utf16 += 1;
}
let utf16 = &utf16[..len_utf16];
// Safety: We've initialized `len_utf16` values.
let utf16: &[u16] = unsafe { MaybeUninit::slice_assume_init_ref(&utf16[..len_utf16]) };

let mut written = write_u16s(handle, &utf16)?;

Expand Down Expand Up @@ -250,27 +252,33 @@ impl io::Read for Stdin {
return Ok(bytes_copied);
} else if buf.len() - bytes_copied < 4 {
// Not enough space to get a UTF-8 byte. We will use the incomplete UTF8.
let mut utf16_buf = [0u16; 1];
let mut utf16_buf = [MaybeUninit::new(1); 1];
thomcc marked this conversation as resolved.
Show resolved Hide resolved
// Read one u16 character.
let read = read_u16s_fixup_surrogates(handle, &mut utf16_buf, 1, &mut self.surrogate)?;
// Read bytes, using the (now-empty) self.incomplete_utf8 as extra space.
let read_bytes = utf16_to_utf8(&utf16_buf[..read], &mut self.incomplete_utf8.bytes)?;
let read_bytes = utf16_to_utf8(
unsafe { MaybeUninit::slice_assume_init_ref(&utf16_buf[..read]) },
&mut self.incomplete_utf8.bytes,
)?;

// Read in the bytes from incomplete_utf8 until the buffer is full.
self.incomplete_utf8.len = read_bytes as u8;
// No-op if no bytes.
bytes_copied += self.incomplete_utf8.read(&mut buf[bytes_copied..]);
Ok(bytes_copied)
} else {
let mut utf16_buf = [0u16; MAX_BUFFER_SIZE / 2];
let mut utf16_buf: [MaybeUninit<u16>; MAX_BUFFER_SIZE / 2] =
MaybeUninit::uninit_array();
thomcc marked this conversation as resolved.
Show resolved Hide resolved
// In the worst case, a UTF-8 string can take 3 bytes for every `u16` of a UTF-16. So
// we can read at most a third of `buf.len()` chars and uphold the guarantee no data gets
// lost.
let amount = cmp::min(buf.len() / 3, utf16_buf.len());
let read =
read_u16s_fixup_surrogates(handle, &mut utf16_buf, amount, &mut self.surrogate)?;

match utf16_to_utf8(&utf16_buf[..read], buf) {
// Safety `read_u16s_fixup_surrogates` returns the number of items
// initialized.
let utf16s = unsafe { MaybeUninit::slice_assume_init_ref(&utf16_buf[..read]) };
match utf16_to_utf8(utf16s, buf) {
Ok(value) => return Ok(bytes_copied + value),
Err(e) => return Err(e),
}
Expand All @@ -283,14 +291,14 @@ impl io::Read for Stdin {
// This is a best effort, and might not work if we are not the only reader on Stdin.
fn read_u16s_fixup_surrogates(
handle: c::HANDLE,
buf: &mut [u16],
buf: &mut [MaybeUninit<u16>],
mut amount: usize,
surrogate: &mut u16,
) -> io::Result<usize> {
// Insert possibly remaining unpaired surrogate from last read.
let mut start = 0;
if *surrogate != 0 {
buf[0] = *surrogate;
buf[0] = MaybeUninit::new(*surrogate);
*surrogate = 0;
start = 1;
if amount == 1 {
Expand All @@ -303,7 +311,10 @@ fn read_u16s_fixup_surrogates(
let mut amount = read_u16s(handle, &mut buf[start..amount])? + start;

if amount > 0 {
let last_char = buf[amount - 1];
// Safety: The returned `amount` is the number of values initialized,
// and it is not 0, so we know that `buf[amount - 1]` have been
// initialized.
let last_char = unsafe { buf[amount - 1].assume_init() };
if last_char >= 0xD800 && last_char <= 0xDBFF {
// high surrogate
*surrogate = last_char;
Expand All @@ -313,7 +324,8 @@ fn read_u16s_fixup_surrogates(
Ok(amount)
}

fn read_u16s(handle: c::HANDLE, buf: &mut [u16]) -> io::Result<usize> {
// Returns `Ok(n)` if it initialized `n` values in `buf`.
fn read_u16s(handle: c::HANDLE, buf: &mut [MaybeUninit<u16>]) -> io::Result<usize> {
// Configure the `pInputControl` parameter to not only return on `\r\n` but also Ctrl-Z, the
// traditional DOS method to indicate end of character stream / user input (SUB).
// See #38274 and https://stackoverflow.com/questions/43836040/win-api-readconsole.
Expand Down Expand Up @@ -346,8 +358,9 @@ fn read_u16s(handle: c::HANDLE, buf: &mut [u16]) -> io::Result<usize> {
}
break;
}

if amount > 0 && buf[amount as usize - 1] == CTRL_Z {
// Safety: if `amount > 0`, then that many bytes were written, so
// `buf[amount as usize - 1]` has been initialized.
if amount > 0 && unsafe { buf[amount as usize - 1].assume_init() } == CTRL_Z {
amount -= 1;
}
Ok(amount as usize)
Expand Down