diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index a1128c54f3..eedf3e9945 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -358,6 +358,12 @@ fn simple_eval_( let output = input0.broadcast_div(input1)?; values.insert(node.output[0].clone(), output); } + "Reciprocal" => { + let xs = get(&node.input[0])?; + let ones = Tensor::ones_like(&xs)?; + let output = ones.div(xs)?; + values.insert(node.output[0].clone(), output); + } "Pow" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; @@ -488,6 +494,29 @@ fn simple_eval_( }; values.insert(node.output[0].clone(), ys); } + "GlobalAveragePool" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool + let xs = get(&node.input[0])?; + let [_n_dim, _c_dim, kernel_shape @ ..] = xs.dims() else { + bail!( + "only 2d GlobalAveragePool is supported, kernel shape {:?}", + xs.dims() + ); + }; + let ys = match kernel_shape { + [d1, d2] => xs.avg_pool2d((*d1, *d2)), + [d1] => { + let xs = xs.unsqueeze(1)?; + let xs = xs.avg_pool2d((1, *d1))?; + xs.squeeze(1) + } + _ => bail!( + "only 2d GlobalAveragePool is supported, kernel shape {:?}", + xs.dims() + ), + }?; + values.insert(node.output[0].clone(), ys); + } "AveragePool" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool let dilations = get_attr_opt::<[i64]>(node, "dilations")?; @@ -1121,13 +1150,6 @@ fn simple_eval_( let mode = get_attr_opt(node, "mode")?.unwrap_or("constant"); let data = get(&node.input[0])?; let pads = get(&node.input[1])?; - if node.input.len() > 2 { - bail!( - "unsupported number of inputs {} for Pad node {:?}, expected 2", - node.input.len(), - node.name - ); - } if pads.rank() != 1 { bail!("Pad expects 'pads' input to be 1D vector: {pads:?}"); } @@ -1164,6 +1186,34 @@ fn simple_eval_( values.insert(node.output[0].clone(), out); } + "constant" => { + let value = if node.input.len() > 2 { + get(&node.input[2])?.to_vec0::()? + } else { + 0.0 + }; + + let mut out = data.clone(); + for (axis, (pad_pre, pad_post)) in + pads_pre.iter().zip(pads_post).enumerate() + { + if *pad_pre == 0 && *pad_post == 0 { + continue; + } + + let mut new_dims = out.dims().to_vec(); + new_dims[axis] += (*pad_pre + *pad_post) as usize; + + out = Tensor::full(value, new_dims, out.device())?.slice_scatter( + &out, + axis, + *pad_pre as usize, + )?; + } + + values.insert(node.output[0].clone(), out); + } + _ => bail!( "unsupported 'mode' value {mode:?} for Pad node {:?}", node.name