Skip to content

Commit

Permalink
implement Slice op (#2260)
Browse files Browse the repository at this point in the history
  • Loading branch information
shua committed Jun 12, 2024
1 parent 9f804af commit 2b10aaa
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 0 deletions.
80 changes: 80 additions & 0 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
DataType::Bool => Some(DType::U8),
_ => None,
}
}
Expand Down Expand Up @@ -1053,6 +1054,85 @@ fn simple_eval_(
),
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
"Slice" => {
let data = get(&node.input[0])?;
let starts = get(&node.input[1])?;
let ends = get(&node.input[2])?;
let default_axes;
let default_steps;
let axes: &Tensor;
let steps: &Tensor;
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
// they are set to [1, ..., 1] of length len(starts)
match node.input.len() {
3 => {
let len = starts.dims()[0];
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
axes = default_axes.as_ref().unwrap();
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
4 => {
let len = starts.dims()[0];
axes = get(&node.input[3])?;
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
5 => {
steps = get(&node.input[4])?;
axes = get(&node.input[3])?;
}
_ => bail!(
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
node.input.len(),
node
),
}

let mut out = data.clone();
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
// All negative elements of axes are made non-negative by
// adding r to them, where r = rank(input).
let axis = if axis < 0 {
axis + data.rank() as i64
} else {
axis
} as usize;

let data_dim = data.dims()[axis] as i64;
let mut s = starts.get(i)?.to_scalar::<i64>()?;
let mut e = ends.get(i)?.to_scalar::<i64>()?;
// All negative values in starts[i] and ends[i] have
// dims[axes[i]] added to them, where dims are the
// dimensions of input.
if s < 0 {
s += data_dim;
}
if e < 0 {
e += data_dim;
}

let p = steps.get(i)?.to_scalar::<i64>()?;
// starts[i] is clamped into the range [0, dims[axes[i]]]
// for positive stepping and [0, dims[axes[i]]-1] for
// negative stepping.
// for positive stepping ends[axes[i]] is clamped to
// [0, dims[axes[i]]], while for negative stepping it is
// clamped to [-1, dims[axes[i]]-1].
if p >= 0 {
s = s.clamp(0, data_dim);
e = e.clamp(0, data_dim);
} else {
s = s.clamp(0, data_dim - 1);
e = e.clamp(-1, data_dim - 1);
}

let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
Expand Down
135 changes: 135 additions & 0 deletions candle-onnx/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> {
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
Ok(())
}

#[test]
fn test_slice() -> Result<()> {
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Slice".to_string(),
input: vec![
"data".to_string(),
"starts".to_string(),
"ends".to_string(),
"axes".to_string(),
"steps".to_string(),
],
output: vec!["result".to_string()],
..NodeProto::default()
}],
input: ["data", "starts", "ends", "axes", "steps"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
output: ["result"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
..GraphProto::default()
}));

/*
data = [
[1, 2, 3, 4],
[5, 6, 7, 8],
]
axes = [0, 1]
starts = [1, 0]
ends = [2, 3]
steps = [1, 2]
result = [
[5, 7],
]
*/

let outputs = candle_onnx::simple_eval(
&model,
HashMap::from_iter([
(
"data".to_string(),
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
),
(
"starts".to_string(),
Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?,
),
(
"ends".to_string(),
Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?,
),
(
"axes".to_string(),
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
),
(
"steps".to_string(),
Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?,
),
]),
)?;
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
assert_eq!(actual, vec![vec![5i64, 7]]);

/*
data = [
[1, 2, 3, 4],
[5, 6, 7, 8],
]
starts = [0, 1]
ends = [-1, 1000]
result = [
[2, 3, 4],
]
*/
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Slice".to_string(),
input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()],
output: vec!["result".to_string()],
..NodeProto::default()
}],
input: ["data", "starts", "ends"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
output: ["result"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
..GraphProto::default()
}));
let outputs = candle_onnx::simple_eval(
&model,
HashMap::from_iter([
(
"data".to_string(),
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
),
(
"starts".to_string(),
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
),
(
"ends".to_string(),
Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?,
),
]),
)?;
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
assert_eq!(actual, vec![vec![2i64, 3, 4]]);

Ok(())
}

0 comments on commit 2b10aaa

Please sign in to comment.