Skip to content

Commit 318e893

Browse files
committed
Head: ensure stdin input stream is correct on exit
The head tool now ensures that, on exit, the position of the stdin stream is set just past the last byte the tool output. As a result, any subsequent tools that read from the same input stream start at the correct point in the stream.
1 parent f94ff78 commit 318e893

File tree

6 files changed

+339
-38
lines changed

6 files changed

+339
-38
lines changed

src/uu/head/src/head.rs

Lines changed: 73 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77

88
use clap::{crate_version, Arg, ArgAction, ArgMatches, Command};
99
use std::ffi::OsString;
10+
#[cfg(unix)]
11+
use std::fs::File;
1012
use std::io::{self, ErrorKind, Read, Seek, SeekFrom, Write};
1113
use std::num::TryFromIntError;
14+
#[cfg(unix)]
15+
use std::os::fd::{AsRawFd, FromRawFd};
1216
use thiserror::Error;
1317
use uucore::display::Quotable;
1418
use uucore::error::{FromIo, UError, UResult};
@@ -239,7 +243,7 @@ impl HeadOptions {
239243
}
240244
}
241245

242-
fn read_n_bytes<R>(input: R, n: u64) -> std::io::Result<()>
246+
fn read_n_bytes<R>(input: R, n: u64) -> std::io::Result<u64>
243247
where
244248
R: Read,
245249
{
@@ -250,31 +254,31 @@ where
250254
let stdout = std::io::stdout();
251255
let mut stdout = stdout.lock();
252256

253-
io::copy(&mut reader, &mut stdout)?;
257+
let bytes_copied = io::copy(&mut reader, &mut stdout)?;
254258

255259
// Make sure we finish writing everything to the target before
256260
// exiting. Otherwise, when Rust is implicitly flushing, any
257261
// error will be silently ignored.
258262
stdout.flush()?;
259263

260-
Ok(())
264+
Ok(bytes_copied)
261265
}
262266

263-
fn read_n_lines(input: &mut impl std::io::BufRead, n: u64, separator: u8) -> std::io::Result<()> {
267+
fn read_n_lines(input: &mut impl std::io::BufRead, n: u64, separator: u8) -> std::io::Result<u64> {
264268
// Read the first `n` lines from the `input` reader.
265269
let mut reader = take_lines(input, n, separator);
266270

267271
// Write those bytes to `stdout`.
268272
let mut stdout = std::io::stdout();
269273

270-
io::copy(&mut reader, &mut stdout)?;
274+
let bytes_copied = io::copy(&mut reader, &mut stdout)?;
271275

272276
// Make sure we finish writing everything to the target before
273277
// exiting. Otherwise, when Rust is implicitly flushing, any
274278
// error will be silently ignored.
275279
stdout.flush()?;
276280

277-
Ok(())
281+
Ok(bytes_copied)
278282
}
279283

280284
fn catch_too_large_numbers_in_backwards_bytes_or_lines(n: u64) -> Option<usize> {
@@ -288,7 +292,7 @@ fn catch_too_large_numbers_in_backwards_bytes_or_lines(n: u64) -> Option<usize>
288292
}
289293

290294
/// Print to stdout all but the last `n` bytes from the given reader.
291-
fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::Result<()> {
295+
fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::Result<u64> {
292296
if n == 0 {
293297
//prints everything
294298
return read_n_bytes(input, u64::MAX);
@@ -302,6 +306,7 @@ fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::
302306

303307
let mut buffer = [0u8; BUF_SIZE];
304308
let mut total_read = 0;
309+
let mut total_written = 0;
305310

306311
loop {
307312
let read = match input.read(&mut buffer) {
@@ -321,30 +326,37 @@ fn read_but_last_n_bytes(input: &mut impl std::io::BufRead, n: u64) -> std::io::
321326
ring_buffer.extend_from_slice(&buffer[..read - overflow]);
322327
} else {
323328
// Write the ring buffer and the part of the buffer that exceeds n
329+
total_written += ring_buffer.len();
324330
stdout.write_all(&ring_buffer)?;
331+
total_written += read - n + ring_buffer.len();
325332
stdout.write_all(&buffer[..read - n + ring_buffer.len()])?;
326333
ring_buffer.clear();
327334
ring_buffer.extend_from_slice(&buffer[read - n + ring_buffer.len()..read]);
328335
}
329336
}
337+
eprintln!("Returning {}", total_written);
338+
return Ok(u64::try_from(total_written).unwrap());
330339
}
331340

332-
Ok(())
341+
Ok(0)
333342
}
334343

335344
fn read_but_last_n_lines(
336345
input: impl std::io::BufRead,
337346
n: u64,
338347
separator: u8,
339-
) -> std::io::Result<()> {
348+
) -> std::io::Result<u64> {
349+
let mut bytes_read: u64 = 0;
340350
if let Some(n) = catch_too_large_numbers_in_backwards_bytes_or_lines(n) {
341351
let stdout = std::io::stdout();
342352
let mut stdout = stdout.lock();
343353
for bytes in take_all_but(lines(input, separator), n) {
344-
stdout.write_all(&bytes?)?;
354+
let bytes = bytes?;
355+
bytes_read += u64::try_from(bytes.len()).unwrap();
356+
stdout.write_all(&bytes)?;
345357
}
346358
}
347-
Ok(())
359+
Ok(bytes_read)
348360
}
349361

350362
/// Return the index in `input` just after the `n`th line from the end.
@@ -425,61 +437,61 @@ fn is_seekable(input: &mut std::fs::File) -> bool {
425437
&& input.seek(SeekFrom::Start(current_pos.unwrap())).is_ok()
426438
}
427439

428-
fn head_backwards_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<()> {
440+
fn head_backwards_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<u64> {
429441
let st = input.metadata()?;
430442
let seekable = is_seekable(input);
431443
let blksize_limit = uucore::fs::sane_blksize::sane_blksize_from_metadata(&st);
432444
if !seekable || st.len() <= blksize_limit {
433-
return head_backwards_without_seek_file(input, options);
445+
eprintln!("head_backwards_without_seek_file");
446+
head_backwards_without_seek_file(input, options)
447+
} else {
448+
eprintln!("head_backwards_on_seekable_file");
449+
head_backwards_on_seekable_file(input, options)
434450
}
435-
436-
head_backwards_on_seekable_file(input, options)
437451
}
438452

439453
fn head_backwards_without_seek_file(
440454
input: &mut std::fs::File,
441455
options: &HeadOptions,
442-
) -> std::io::Result<()> {
456+
) -> std::io::Result<u64> {
443457
let reader = &mut std::io::BufReader::with_capacity(BUF_SIZE, &*input);
444458

445459
match options.mode {
446-
Mode::AllButLastBytes(n) => read_but_last_n_bytes(reader, n)?,
447-
Mode::AllButLastLines(n) => read_but_last_n_lines(reader, n, options.line_ending.into())?,
460+
Mode::AllButLastBytes(n) => read_but_last_n_bytes(reader, n),
461+
Mode::AllButLastLines(n) => read_but_last_n_lines(reader, n, options.line_ending.into()),
448462
_ => unreachable!(),
449463
}
450-
451-
Ok(())
452464
}
453465

454466
fn head_backwards_on_seekable_file(
455467
input: &mut std::fs::File,
456468
options: &HeadOptions,
457-
) -> std::io::Result<()> {
469+
) -> std::io::Result<u64> {
458470
match options.mode {
459471
Mode::AllButLastBytes(n) => {
460472
let size = input.metadata()?.len();
461473
if n >= size {
462-
return Ok(());
474+
Ok(0)
463475
} else {
464476
read_n_bytes(
465477
&mut std::io::BufReader::with_capacity(BUF_SIZE, input),
466478
size - n,
467-
)?;
479+
)
468480
}
469481
}
470482
Mode::AllButLastLines(n) => {
471483
let found = find_nth_line_from_end(input, n, options.line_ending.into())?;
472484
read_n_bytes(
473485
&mut std::io::BufReader::with_capacity(BUF_SIZE, input),
474486
found,
475-
)?;
487+
)
476488
}
477489
_ => unreachable!(),
478490
}
479-
Ok(())
480491
}
481492

482-
fn head_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<()> {
493+
fn head_file(input: &mut std::fs::File, options: &HeadOptions) -> std::io::Result<u64> {
494+
eprintln!("head_file");
483495
match options.mode {
484496
Mode::FirstBytes(n) => {
485497
read_n_bytes(&mut std::io::BufReader::with_capacity(BUF_SIZE, input), n)
@@ -506,16 +518,41 @@ fn uu_head(options: &HeadOptions) -> UResult<()> {
506518
println!("==> standard input <==");
507519
}
508520
let stdin = std::io::stdin();
509-
let mut stdin = stdin.lock();
510-
511-
match options.mode {
512-
Mode::FirstBytes(n) => read_n_bytes(&mut stdin, n),
513-
Mode::AllButLastBytes(n) => read_but_last_n_bytes(&mut stdin, n),
514-
Mode::FirstLines(n) => read_n_lines(&mut stdin, n, options.line_ending.into()),
515-
Mode::AllButLastLines(n) => {
516-
read_but_last_n_lines(&mut stdin, n, options.line_ending.into())
521+
522+
#[cfg(unix)]
523+
{
524+
let stdin_raw_fd = stdin.as_raw_fd();
525+
let mut stdin_file = unsafe { File::from_raw_fd(stdin_raw_fd) };
526+
let current_pos = stdin_file.stream_position();
527+
if let Ok(current_pos) = current_pos {
528+
// We have a seekable file. Ensure we set the input stream to the
529+
// last byte read so that any tools that parse the remainder of
530+
// the stdin stream read from the correct place.
531+
532+
let bytes_read = head_file(&mut stdin_file, options)?;
533+
stdin_file.seek(SeekFrom::Start(current_pos + bytes_read))?;
534+
} else {
535+
let _bytes_read = head_file(&mut stdin_file, options)?;
517536
}
518537
}
538+
539+
#[cfg(not(unix))]
540+
{
541+
let mut stdin = stdin.lock();
542+
543+
match options.mode {
544+
Mode::FirstBytes(n) => read_n_bytes(&mut stdin, n),
545+
Mode::AllButLastBytes(n) => read_but_last_n_bytes(&mut stdin, n),
546+
Mode::FirstLines(n) => {
547+
read_n_lines(&mut stdin, n, options.line_ending.into())
548+
}
549+
Mode::AllButLastLines(n) => {
550+
read_but_last_n_lines(&mut stdin, n, options.line_ending.into())
551+
}
552+
}?;
553+
}
554+
555+
Ok(())
519556
}
520557
(name, false) => {
521558
let mut file = match std::fs::File::open(name) {
@@ -534,7 +571,8 @@ fn uu_head(options: &HeadOptions) -> UResult<()> {
534571
}
535572
println!("==> {name} <==");
536573
}
537-
head_file(&mut file, options)
574+
head_file(&mut file, options)?;
575+
Ok(())
538576
}
539577
};
540578
if let Err(e) = res {

0 commit comments

Comments
 (0)