Add support for f16 tensors
jleibs committed Feb 28, 2023
1 parent 1820cae commit 9b83da8
Showing 7 changed files with 60 additions and 19 deletions.
6 changes: 2 additions & 4 deletions Cargo.lock

(Generated file; diff not rendered.)

4 changes: 3 additions & 1 deletion Cargo.toml
@@ -49,7 +49,9 @@ rerun = { path = "crates/rerun", version = "0.2.0" }

 anyhow = "1.0"
 arrow2 = "0.16"
-arrow2_convert = "0.4.2"
+# TODO(jleibs): Land upstream: https://github.com/DataEngineeringLabs/arrow2-convert/pull/104
+# arrow2_convert = "0.4.2"
+arrow2_convert = { git = "https://github.com/rerun-io/arrow2-convert", rev = "93f9f85b55c3a51ea1" }
 clap = "4.0"
 comfy-table = { version = "6.1", default-features = false }
 ecolor = "0.21.0"
28 changes: 23 additions & 5 deletions crates/re_log_types/src/component_types/tensor.rs
@@ -137,6 +137,11 @@ impl ArrowDeserialize for TensorId {
 ///         false
 ///     ),
+///     Field::new(
+///         "F16",
+///         DataType::List(Box::new(Field::new("item", DataType::Float16, false))),
+///         false
+///     ),
 ///     Field::new(
 ///         "F32",
 ///         DataType::List(Box::new(Field::new("item", DataType::Float32, false))),
 ///         false
@@ -166,8 +171,7 @@ pub enum TensorData {
     I32(Buffer<i32>),
     I64(Buffer<i64>),
     // ---
-    // TODO(#854): Native F16 support for arrow tensors
-    //F16(Vec<arrow2::types::f16>),
+    F16(Buffer<arrow2::types::f16>),
     F32(Buffer<f32>),
     F64(Buffer<f64>),
     JPEG(Vec<u8>),
@@ -386,6 +390,7 @@ impl TensorTrait for Tensor {
             TensorData::I16(buf) => Some(TensorElement::I16(buf[offset])),
             TensorData::I32(buf) => Some(TensorElement::I32(buf[offset])),
             TensorData::I64(buf) => Some(TensorElement::I64(buf[offset])),
+            TensorData::F16(buf) => Some(TensorElement::F16(buf[offset])),
             TensorData::F32(buf) => Some(TensorElement::F32(buf[offset])),
             TensorData::F64(buf) => Some(TensorElement::F64(buf[offset])),
             TensorData::JPEG(_) => None, // Too expensive to unpack here.
@@ -402,6 +407,7 @@ impl TensorTrait for Tensor {
             TensorData::I16(_) => TensorDataType::I16,
             TensorData::I32(_) => TensorDataType::I32,
             TensorData::I64(_) => TensorDataType::I64,
+            TensorData::F16(_) => TensorDataType::F16,
             TensorData::F32(_) => TensorDataType::F32,
             TensorData::F64(_) => TensorDataType::F64,
         }
@@ -470,6 +476,10 @@ impl From<&Tensor> for ClassicTensor {
                 crate::TensorDataType::I64,
                 TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
             ),
+            TensorData::F16(data) => (
+                crate::TensorDataType::F16,
+                TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
+            ),
             TensorData::F32(data) => (
                 crate::TensorDataType::F32,
                 TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
@@ -572,15 +582,23 @@ tensor_type!(i16, I16);
 tensor_type!(i32, I32);
 tensor_type!(i64, I64);
 
+tensor_type!(arrow2::types::f16, F16);
 tensor_type!(f32, F32);
 tensor_type!(f64, F64);
 
-// TODO(#854) Switch back to `tensor_type!` once we have F16 tensors
+// Support for `half::f16` instead of `arrow2::types::f16`
 impl<'a> TryFrom<&'a Tensor> for ::ndarray::ArrayViewD<'a, half::f16> {
     type Error = TensorCastError;
 
-    fn try_from(_: &'a Tensor) -> Result<Self, Self::Error> {
-        Err(TensorCastError::F16NotSupported)
+    fn try_from(value: &'a Tensor) -> Result<Self, Self::Error> {
+        let shape: Vec<_> = value.shape.iter().map(|d| d.size as usize).collect();
+
+        if let TensorData::F16(data) = &value.data {
+            ndarray::ArrayViewD::from_shape(shape, bytemuck::cast_slice(data.as_slice()))
+                .map_err(|err| TensorCastError::BadTensorShape { source: err })
+        } else {
+            Err(TensorCastError::TypeMismatch)
+        }
     }
 }

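With the placeholder `TryFrom` replaced by a real implementation, downstream code can borrow an f16 tensor as an `ndarray` view without copying. A minimal usage sketch, not part of the diff, assuming the `Tensor` re-export and the `half`/`ndarray` dependencies used above (`mean_f16` is a hypothetical helper name):

use re_log_types::component_types::Tensor;

// Average an f16 tensor by borrowing it as an ndarray view. The impl above
// casts the arrow2 f16 buffer to `half::f16` (both wrap a u16, so the
// `bytemuck` cast is layout-compatible) and validates the shape.
fn mean_f16(tensor: &Tensor) -> Option<f32> {
    let view: ndarray::ArrayViewD<'_, half::f16> = tensor.try_into().ok()?;
    if view.is_empty() {
        return None;
    }
    let sum: f32 = view.iter().map(|v| v.to_f32()).sum();
    Some(sum / view.len() as f32)
}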
7 changes: 3 additions & 4 deletions crates/re_log_types/src/data.rs
@@ -171,7 +171,6 @@ impl TensorDataTypeTrait for f64 {

 /// The data that can be stored in a [`ClassicTensor`].
 #[derive(Clone, Copy, Debug, PartialEq)]
-#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
 pub enum TensorElement {
     /// Unsigned 8 bit integer.
     ///
@@ -205,7 +204,7 @@ pub enum TensorElement {
     ///
     /// Uses the standard IEEE 754-2008 binary16 format.
     /// See <https://en.wikipedia.org/wiki/Half-precision_floating-point_format>.
-    F16(f16),
+    F16(arrow2::types::f16),
 
     /// 32-bit floating point number.
     F32(f32),
@@ -228,7 +227,7 @@ impl TensorElement {
             Self::I32(value) => *value as _,
             Self::I64(value) => *value as _,
 
-            Self::F16(value) => value.to_f64(),
+            Self::F16(value) => value.to_f32() as _,
             Self::F32(value) => *value as _,
             Self::F64(value) => *value,
         }
@@ -253,7 +252,7 @@ impl TensorElement {
             Self::I32(value) => u16::try_from(*value).ok(),
             Self::I64(value) => u16::try_from(*value).ok(),
 
-            Self::F16(value) => u16_from_f64(value.to_f64()),
+            Self::F16(value) => u16_from_f64(value.to_f32() as f64),
             Self::F32(value) => u16_from_f64(*value as f64),
             Self::F64(value) => u16_from_f64(*value),
         }
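The two `F16` arms above now widen through `f32` because the payload changed: `half::f16` has a `to_f64()`, while the conversion available on `arrow2::types::f16` goes through `f32`. Nothing is lost on the detour, since every f16 value is exactly representable as an f32 and every f32 as an f64. A one-line sketch of the same path:

// Sketch of the widening used above: f16 -> f32 is exact, and the
// subsequent `as f64` cast is also lossless.
fn f16_to_f64(v: arrow2::types::f16) -> f64 {
    v.to_f32() as f64
}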
26 changes: 26 additions & 0 deletions crates/re_viewer/src/misc/caches/tensor_image_cache.rs
@@ -391,6 +391,17 @@ impl AsDynamicImage for ClassicTensor {
                     }
                 }
             }
+            (1, TensorDataType::F16, _) => {
+                let l: &[half::f16] = bytemuck::cast_slice(bytes);
+                let colors: Vec<u8> = l
+                    .iter()
+                    .copied()
+                    .map(|f| linear_u8_from_linear_f32(f.to_f32()))
+                    .collect();
+                image::GrayImage::from_raw(width, height, colors)
+                    .context("Bad Luminance f16")
+                    .map(DynamicImage::ImageLuma8)
+            }
             (1, TensorDataType::F32, _) => {
                 let l: &[f32] = bytemuck::cast_slice(bytes);
                 let colors: Vec<u8> =
@@ -409,6 +420,21 @@
.context("Bad RGB16 image")
.map(DynamicImage::ImageRgb16)
}
(3, TensorDataType::F16, _) => {
let rgb: &[[half::f16; 3]] = bytemuck::cast_slice(bytes);
let colors: Vec<u8> = rgb
.iter()
.flat_map(|&[r, g, b]| {
let r = gamma_u8_from_linear_f32(r.to_f32());
let g = gamma_u8_from_linear_f32(g.to_f32());
let b = gamma_u8_from_linear_f32(b.to_f32());
[r, g, b]
})
.collect();
image::RgbImage::from_raw(width, height, colors)
.context("Bad RGB f16")
.map(DynamicImage::ImageRgb8)
}
(3, TensorDataType::F32, _) => {
let rgb: &[[f32; 3]] = bytemuck::cast_slice(bytes);
let colors: Vec<u8> = rgb
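Both new arms reinterpret the raw byte buffer as f16 pixels with `bytemuck::cast_slice` before tone-mapping to 8-bit. A standalone sketch of the luminance path, substituting a plain clamp for the viewer's `linear_u8_from_linear_f32` helper:

// Sketch only: `half::f16` is `bytemuck::Pod` (with the `bytemuck` feature
// enabled), so an even-length, 2-byte-aligned `&[u8]` casts directly to
// `&[half::f16]`; `cast_slice` panics on misaligned or odd-length input.
fn luminance_to_u8(bytes: &[u8]) -> Vec<u8> {
    let pixels: &[half::f16] = bytemuck::cast_slice(bytes);
    pixels
        .iter()
        .map(|f| (f.to_f32().clamp(0.0, 1.0) * 255.0).round() as u8)
        .collect()
}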
3 changes: 3 additions & 0 deletions crates/re_viewer/src/ui/view_bar_chart/ui.rs
@@ -85,6 +85,9 @@ pub(crate) fn view_bar_chart(
                 instance_key,
                 data.iter().copied().map(|v| v as f64),
             ),
+            component_types::TensorData::F16(data) => {
+                create_bar_chart(ent_path, instance_key, data.iter().map(|f| f.to_f32()))
+            }
             component_types::TensorData::F32(data) => {
                 create_bar_chart(ent_path, instance_key, data.iter().copied())
             }
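The new arm maps each sample through `to_f32()` so the chart sees a plain float series; since f16 -> f32 is exact, the plotted values match the logged data. A sketch of the conversion in isolation (`Buffer` here is `arrow2::buffer::Buffer`, which derefs to a slice):

// Sketch: widen an arrow2 f16 buffer to f64 plot values; both steps of
// f16 -> f32 -> f64 are lossless.
fn to_plot_values(data: &arrow2::buffer::Buffer<arrow2::types::f16>) -> Vec<f64> {
    data.iter().map(|f| f.to_f32() as f64).collect()
}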
5 changes: 0 additions & 5 deletions rerun_py/rerun_sdk/rerun/log/tensor.py
@@ -116,11 +116,6 @@ def _log_tensor(
         np.float64,
     ]
 
-    # We don't support float16 -- upscale to f32
-    # TODO(#854): Native F16 support for arrow tensors
-    if tensor.dtype == np.float16:
-        tensor = np.asarray(tensor, dtype="float32")
-
     if tensor.dtype not in SUPPORTED_DTYPES:
         _send_warning(f"Unsupported dtype: {tensor.dtype}. Expected a numeric type. Skipping this tensor.", 2)
         return
