diff --git a/src/uu/od/src/formatteriteminfo.rs b/src/uu/od/src/formatteriteminfo.rs index d650a6cec87..89cd645273b 100644 --- a/src/uu/od/src/formatteriteminfo.rs +++ b/src/uu/od/src/formatteriteminfo.rs @@ -11,6 +11,7 @@ use std::fmt; pub enum FormatWriter { IntWriter(fn(u64) -> String), FloatWriter(fn(f64) -> String), + BFloatWriter(fn(f64) -> String), MultibyteWriter(fn(&[u8]) -> String), } @@ -25,6 +26,10 @@ impl fmt::Debug for FormatWriter { f.write_str("FloatWriter:")?; fmt::Pointer::fmt(p, f) } + Self::BFloatWriter(ref p) => { + f.write_str("BFloatWriter:")?; + fmt::Pointer::fmt(p, f) + } Self::MultibyteWriter(ref p) => { f.write_str("MultibyteWriter:")?; fmt::Pointer::fmt(&(*p as *const ()), f) diff --git a/src/uu/od/src/inputdecoder.rs b/src/uu/od/src/inputdecoder.rs index 44ad2922843..fbdbcdf47ce 100644 --- a/src/uu/od/src/inputdecoder.rs +++ b/src/uu/od/src/inputdecoder.rs @@ -2,7 +2,8 @@ // // For the full copyright and license information, please view the LICENSE // file that was distributed with this source code. -use half::f16; +// spell-checker:ignore bfloat +use half::{bf16, f16}; use std::io; use crate::byteorder_io::ByteOrder; @@ -155,6 +156,13 @@ impl MemoryDecoder<'_> { _ => panic!("Invalid byte_size: {byte_size}"), } } + + /// Returns a bfloat16 as f64 from the internal buffer at position `start`. + pub fn read_bfloat(&self, start: usize) -> f64 { + let bits = self.byte_order.read_u16(&self.data[start..start + 2]); + let val = f32::from(bf16::from_bits(bits)); + f64::from(val) + } } #[cfg(test)] diff --git a/src/uu/od/src/od.rs b/src/uu/od/src/od.rs index 8f1f2f10b4c..8fd5761f5ce 100644 --- a/src/uu/od/src/od.rs +++ b/src/uu/od/src/od.rs @@ -5,7 +5,7 @@ // spell-checker:ignore (clap) dont // spell-checker:ignore (ToDO) formatteriteminfo inputdecoder inputoffset mockstream nrofbytes partialreader odfunc multifile exitcode -// spell-checker:ignore Anone +// spell-checker:ignore Anone bfloat mod byteorder_io; mod formatteriteminfo; @@ -576,6 +576,10 @@ fn print_bytes(prefix: &str, input_decoder: &MemoryDecoder, output_info: &Output let p = input_decoder.read_float(b, f.formatter_item_info.byte_size); output_text.push_str(&func(p)); } + FormatWriter::BFloatWriter(func) => { + let p = input_decoder.read_bfloat(b); + output_text.push_str(&func(p)); + } FormatWriter::MultibyteWriter(func) => { output_text.push_str(&func(input_decoder.get_full_buffer(b))); } diff --git a/src/uu/od/src/parse_formats.rs b/src/uu/od/src/parse_formats.rs index 7558f8e7f14..eb7643ec672 100644 --- a/src/uu/od/src/parse_formats.rs +++ b/src/uu/od/src/parse_formats.rs @@ -235,6 +235,10 @@ fn is_format_size_char( *byte_size = 8; true } + (FormatTypeCategory::Float, Some('H' | 'B')) => { + *byte_size = 2; + true + } // FormatTypeCategory::Float, 'L' => *byte_size = 16, // TODO support f128 _ => false, } @@ -290,7 +294,45 @@ fn parse_type_string(params: &str) -> Result, Strin let mut byte_size = 0u8; let mut show_ascii_dump = false; - if is_format_size_char(ch, type_cat, &mut byte_size) { + let mut float_variant = None; + if type_cat == FormatTypeCategory::Float { + match ch { + Some(var @ ('B' | 'H')) => { + byte_size = 2; + float_variant = Some(var); + ch = chars.next(); + } + Some('F') => { + byte_size = 4; + ch = chars.next(); + } + Some('D') => { + byte_size = 8; + ch = chars.next(); + } + _ => { + if is_format_size_char(ch, type_cat, &mut byte_size) { + ch = chars.next(); + } else { + let mut decimal_size = String::new(); + while is_format_size_decimal(ch, type_cat, &mut decimal_size) { + ch = chars.next(); + } + if !decimal_size.is_empty() { + byte_size = decimal_size.parse().map_err(|_| { + get_message_with_args( + "od-error-invalid-number", + HashMap::from([ + ("number".to_string(), decimal_size.quote().to_string()), + ("spec".to_string(), params.quote().to_string()), + ]), + ) + })?; + } + } + } + } + } else if is_format_size_char(ch, type_cat, &mut byte_size) { ch = chars.next(); } else { let mut decimal_size = String::new(); @@ -313,15 +355,23 @@ fn parse_type_string(params: &str) -> Result, Strin ch = chars.next(); } - let ft = od_format_type(type_char, byte_size).ok_or_else(|| { - get_message_with_args( - "od-error-invalid-size", - HashMap::from([ - ("size".to_string(), byte_size.to_string()), - ("spec".to_string(), params.quote().to_string()), - ]), - ) - })?; + let ft = if let Some(v) = float_variant { + match v { + 'B' => FORMAT_ITEM_BF16, + 'H' => FORMAT_ITEM_F16, + _ => unreachable!(), + } + } else { + od_format_type(type_char, byte_size).ok_or_else(|| { + get_message_with_args( + "od-error-invalid-size", + HashMap::from([ + ("size".to_string(), byte_size.to_string()), + ("spec".to_string(), params.quote().to_string()), + ]), + ) + })? + }; formats.push(ParsedFormatterItemInfo::new(ft, show_ascii_dump)); } diff --git a/src/uu/od/src/prn_float.rs b/src/uu/od/src/prn_float.rs index 938e029a2d3..44cafaf736e 100644 --- a/src/uu/od/src/prn_float.rs +++ b/src/uu/od/src/prn_float.rs @@ -25,6 +25,12 @@ pub static FORMAT_ITEM_F64: FormatterItemInfo = FormatterItemInfo { formatter: FormatWriter::FloatWriter(format_item_flo64), }; +pub static FORMAT_ITEM_BF16: FormatterItemInfo = FormatterItemInfo { + byte_size: 2, + print_width: 15, + formatter: FormatWriter::BFloatWriter(format_item_bf16), +}; + pub fn format_item_flo16(f: f64) -> String { format!(" {}", format_flo16(f16::from_f64(f))) } @@ -64,6 +70,10 @@ fn format_flo64_exp_precision(f: f64, width: usize, precision: usize) -> String formatted.replace('e', "e+") } +pub fn format_item_bf16(f: f64) -> String { + format!(" {}", format_flo32(f as f32)) +} + fn format_flo16(f: f16) -> String { format_float(f64::from(f), 9, 4) } diff --git a/tests/by-util/test_od.rs b/tests/by-util/test_od.rs index d2700df5dba..2c3ae2a959f 100644 --- a/tests/by-util/test_od.rs +++ b/tests/by-util/test_od.rs @@ -221,6 +221,62 @@ fn test_f16() { .stdout_is(expected_output); } +#[test] +fn test_fh() { + let input: [u8; 14] = [ + 0x00, 0x3c, // 0x3C00 1.0 + 0x00, 0x00, // 0x0000 0.0 + 0x00, 0x80, // 0x8000 -0.0 + 0x00, 0x7c, // 0x7C00 Inf + 0x00, 0xfc, // 0xFC00 -Inf + 0x00, 0xfe, // 0xFE00 NaN + 0x00, 0x84, + ]; // 0x8400 -6.104e-5 + let expected_output = unindent( + " + 0000000 1.000 0 -0 inf + 0000010 -inf NaN -6.104e-5 + 0000016 + ", + ); + new_ucmd!() + .arg("--endian=little") + .arg("-tfH") + .arg("-w8") + .run_piped_stdin(&input[..]) + .success() + .no_stderr() + .stdout_is(expected_output); +} + +#[test] +fn test_fb() { + let input: [u8; 14] = [ + 0x80, 0x3f, // 1.0 + 0x00, 0x00, // 0.0 + 0x00, 0x80, // -0.0 + 0x80, 0x7f, // Inf + 0x80, 0xff, // -Inf + 0xc0, 0x7f, // NaN + 0x80, 0xb8, + ]; // -6.1035156e-5 + let expected_output = unindent( + " + 0000000 1.0000000 0 -0 inf + 0000010 -inf NaN -6.1035156e-5 + 0000016 + ", + ); + new_ucmd!() + .arg("--endian=little") + .arg("-tfB") + .arg("-w8") + .run_piped_stdin(&input[..]) + .success() + .no_stderr() + .stdout_is(expected_output); +} + #[test] fn test_f32() { let input: [u8; 28] = [