Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 57 additions & 7 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,12 @@ fn simple_eval_(
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"Reciprocal" => {
let xs = get(&node.input[0])?;
let ones = Tensor::ones_like(&xs)?;
let output = ones.div(xs)?;
values.insert(node.output[0].clone(), output);
}
"Pow" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
Expand Down Expand Up @@ -488,6 +494,29 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), ys);
}
"GlobalAveragePool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool
let xs = get(&node.input[0])?;
let [_n_dim, _c_dim, kernel_shape @ ..] = xs.dims() else {
bail!(
"only 2d GlobalAveragePool is supported, kernel shape {:?}",
xs.dims()
);
};
let ys = match kernel_shape {
[d1, d2] => xs.avg_pool2d((*d1, *d2)),
[d1] => {
let xs = xs.unsqueeze(1)?;
let xs = xs.avg_pool2d((1, *d1))?;
xs.squeeze(1)
}
_ => bail!(
"only 2d GlobalAveragePool is supported, kernel shape {:?}",
xs.dims()
),
}?;
values.insert(node.output[0].clone(), ys);
}
"AveragePool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
Expand Down Expand Up @@ -1121,13 +1150,6 @@ fn simple_eval_(
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
let data = get(&node.input[0])?;
let pads = get(&node.input[1])?;
if node.input.len() > 2 {
bail!(
"unsupported number of inputs {} for Pad node {:?}, expected 2",
node.input.len(),
node.name
);
}
if pads.rank() != 1 {
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
}
Expand Down Expand Up @@ -1164,6 +1186,34 @@ fn simple_eval_(

values.insert(node.output[0].clone(), out);
}
"constant" => {
let value = if node.input.len() > 2 {
get(&node.input[2])?.to_vec0::<f32>()?
} else {
0.0
};

let mut out = data.clone();
for (axis, (pad_pre, pad_post)) in
pads_pre.iter().zip(pads_post).enumerate()
{
if *pad_pre == 0 && *pad_post == 0 {
continue;
}

let mut new_dims = out.dims().to_vec();
new_dims[axis] += (*pad_pre + *pad_post) as usize;

out = Tensor::full(value, new_dims, out.device())?.slice_scatter(
&out,
axis,
*pad_pre as usize,
)?;
}

values.insert(node.output[0].clone(), out);
}

_ => bail!(
"unsupported 'mode' value {mode:?} for Pad node {:?}",
node.name
Expand Down
Loading