Skip to content

Commit

Permalink
fix(ai): seq2seq models causing null pointer error
Browse files Browse the repository at this point in the history
- Solved the "`GetMutableData` should not be a null pointer" error while
executing seq2seq models.
- Ref.: pykeio/ort#185
  • Loading branch information
kallebysantos committed Oct 31, 2024
1 parent 5123bb0 commit 2c945d7
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions crates/sb_ai/onnxruntime/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,46 @@ use serde::{Deserialize, Serialize};
// but it only allows [u8] instead of [T], so we need to get into `unsafe` path.
macro_rules! v8_slice_from {
(tensor::<$type:ident>($tensor:expr)) => {{
let (_, raw_tensor) = $tensor
.try_extract_raw_tensor_mut::<$type>()
.map_err(AnyError::from)?;

let tensor_ptr = raw_tensor.as_ptr();
let tensor_len = raw_tensor.len();
let tensor_rc = Rc::into_raw(Rc::new(raw_tensor)) as *const c_void;
// We must ensure there's some detection to avoid `null pointer` errors
// https://github.com/pykeio/ort/issues/185
let n_detections = $tensor.shape()?[0];
if n_detections == 0 {
let buf_store = v8::ArrayBuffer::new_backing_store_from_vec(vec![]).make_shared();
let buffer_slice =
unsafe { deno_core::serde_v8::V8Slice::<u8>::from_parts(buf_store, 0..0) };

buffer_slice
} else {
let (_, raw_tensor) = $tensor
.try_extract_raw_tensor_mut::<$type>()
.map_err(AnyError::from)?;

let tensor_ptr = raw_tensor.as_ptr();
let tensor_len = raw_tensor.len();
let tensor_rc = Rc::into_raw(Rc::new(raw_tensor)) as *const c_void;

let buffer_len = tensor_len * size_of::<$type>();

extern "C" fn drop_tensor(_ptr: *mut c_void, _len: usize, data: *mut c_void) {
// SAFETY: We know that data is a raw Rc from above
unsafe { drop(Rc::from_raw(data.cast::<$type>())) }
}

let buffer_len = tensor_len * size_of::<$type>();
let buf_store = unsafe {
v8::ArrayBuffer::new_backing_store_from_ptr(
tensor_ptr as _,
buffer_len,
drop_tensor,
tensor_rc as _,
)
}
.make_shared();

extern "C" fn drop_tensor(_ptr: *mut c_void, _len: usize, data: *mut c_void) {
// SAFETY: We know that data is a raw Rc from above
unsafe { drop(Rc::from_raw(data.cast::<$type>())) }
}
let buffer_slice =
unsafe { deno_core::serde_v8::V8Slice::<u8>::from_parts(buf_store, 0..buffer_len) };

// Zero-Copying using ptr
let buf_store = unsafe {
v8::ArrayBuffer::new_backing_store_from_ptr(
tensor_ptr as _,
buffer_len,
drop_tensor,
tensor_rc as _,
)
buffer_slice
}
.make_shared();

let buffer_slice =
unsafe { deno_core::serde_v8::V8Slice::<u8>::from_parts(buf_store, 0..buffer_len) };

buffer_slice
}};
}

Expand Down Expand Up @@ -152,7 +162,7 @@ pub struct ToJsTensor {
#[serde(rename = "type", with = "JsTensorType")]
data_type: TensorElementType,
data: ToJsBuffer,
dims: Vec<i64>,
pub dims: Vec<i64>,
}

impl ToJsTensor {
Expand Down

0 comments on commit 2c945d7

Please sign in to comment.