diff --git a/lib/virtual-fs/src/mem_fs/file.rs b/lib/virtual-fs/src/mem_fs/file.rs index 7c8bd230159..fe3c5779875 100644 --- a/lib/virtual-fs/src/mem_fs/file.rs +++ b/lib/virtual-fs/src/mem_fs/file.rs @@ -1059,13 +1059,13 @@ mod test_read_write_seek { matches!(file.write(b"baz").await, Ok(3)), "writing `baz` at the beginning of the file", ); - assert_eq!(file.size(), 9, "checking the size of the file"); + assert_eq!(file.size(), 6, "checking the size of the file"); assert!( matches!(file.write(b"qux").await, Ok(3)), "writing `qux` in the middle of the file", ); - assert_eq!(file.size(), 12, "checking the size of the file"); + assert_eq!(file.size(), 6, "checking the size of the file"); assert!( matches!(file.seek(io::SeekFrom::Start(0)).await, Ok(0)), @@ -1074,26 +1074,26 @@ mod test_read_write_seek { let mut string = String::new(); assert!( - matches!(file.read_to_string(&mut string).await, Ok(12)), - "reading `bazquxfoobar`", + matches!(file.read_to_string(&mut string).await, Ok(6)), + "reading `bazqux`", ); - assert_eq!(string, "bazquxfoobar"); + assert_eq!(string, "bazqux"); assert!( - matches!(file.seek(io::SeekFrom::Current(-6)).await, Ok(6)), - "seeking to 6", + matches!(file.seek(io::SeekFrom::Current(-3)).await, Ok(3)), + "seeking to 3", ); let mut string = String::new(); assert!( - matches!(file.read_to_string(&mut string).await, Ok(6)), - "reading `foobar`", + matches!(file.read_to_string(&mut string).await, Ok(3)), + "reading `qux`", ); - assert_eq!(string, "foobar"); + assert_eq!(string, "qux"); assert!( - matches!(file.seek(io::SeekFrom::End(0)).await, Ok(12)), - "seeking to 12", + matches!(file.seek(io::SeekFrom::End(0)).await, Ok(6)), + "seeking to 6", ); let mut string = String::new(); @@ -1104,6 +1104,49 @@ mod test_read_write_seek { assert_eq!(string, ""); } + #[test] + pub fn writing_to_middle() { + fn assert_contents(file: &File, expected: &[u8]) { + let mut buf = vec![0; expected.len() + 1]; + let mut cursor = 0; + let read = file.read(buf.as_mut(), &mut cursor).unwrap(); + assert_eq!(read, expected.len(), "Must have the same amount of data"); + assert_eq!(buf[0..expected.len()], *expected); + } + + let mut file = File::new(None); + + let mut cursor = 0; + + // Write to empty file + file.write(b"hello, world.", &mut cursor).unwrap(); + assert_eq!(cursor, 13); + assert_contents(&file, b"hello, world."); + + // Write to end of file + file.write(b"goodbye!", &mut cursor).unwrap(); + assert_eq!(cursor, 21); + assert_contents(&file, b"hello, world.goodbye!"); + + // Write to middle of file + cursor = 5; + file.write(b"BOOM", &mut cursor).unwrap(); + assert_eq!(cursor, 9); + assert_contents(&file, b"helloBOOMrld.goodbye!"); + + // Write to middle of file until last byte + cursor = 17; + file.write(b"BANG", &mut cursor).unwrap(); + assert_eq!(cursor, 21); + assert_contents(&file, b"helloBOOMrld.goodBANG"); + + // Write to middle past end of file + cursor = 17; + file.write(b"OUCH!", &mut cursor).unwrap(); + assert_eq!(cursor, 22); + assert_contents(&file, b"helloBOOMrld.goodOUCH!"); + } + #[tokio::test] async fn test_reading() { let fs = FileSystem::default(); @@ -1343,36 +1386,16 @@ impl File { impl File { pub fn write(&mut self, buf: &[u8], cursor: &mut u64) -> io::Result { - match *cursor { - // The cursor is at the end of the buffer: happy path! - position if position == self.buffer.len() as u64 => { - self.buffer.extend_from_slice(buf)?; - } - - // The cursor is at the beginning of the buffer (and the - // buffer is not empty, otherwise it would have been - // caught by the previous arm): almost a happy path! - 0 => { - // FIXME(perf,theduke): make this faster, it's horrible! - let mut new_buffer = TrackedVec::with_capacity( - self.buffer.len() + buf.len(), - self.buffer.limiter().cloned(), - )?; - new_buffer.extend_from_slice(buf)?; - new_buffer.append(&mut self.buffer)?; - - self.buffer = new_buffer; - } - - // The cursor is somewhere in the buffer: not the happy path. - position => { - self.buffer.reserve_exact(buf.len())?; - - // FIXME(perf,theduke): make this faster, it's horrible! - let mut remainder = self.buffer.split_off(position as usize)?; - self.buffer.extend_from_slice(buf)?; - self.buffer.append(&mut remainder)?; - } + let position = *cursor as usize; + + if position + buf.len() > self.buffer.len() { + // Writing past the end of the current buffer, must reallocate + let len_after_end = (position + buf.len()) - self.buffer.len(); + let let_to_end = buf.len() - len_after_end; + self.buffer[position..position + let_to_end].copy_from_slice(&buf[0..let_to_end]); + self.buffer.extend_from_slice(&buf[let_to_end..buf.len()])?; + } else { + self.buffer[position..position + buf.len()].copy_from_slice(buf); } *cursor += buf.len() as u64;