diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a14306657b..151f272993 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -129,8 +129,6 @@ impl Tensor {
                     | Op::Permute(node, _)
                     | Op::Narrow(node, _, _, _)
                     | Op::Unary(node, _)
-                    | Op::Elu(node, _)
-                    | Op::Powf(node, _)
                     | Op::CustomOp1(node, _) => {
                         let (tg, nodes) = walk(node, nodes, already_seen);
                         track_grad |= tg;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 8e24368ff1..2ea80b3f75 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -4,7 +4,7 @@
 use crate::Tensor;
 use float8::F8E4M3;
 use half::{bf16, f16};
-use num_traits::float::Float;
+use num_traits::{float::Float, PrimInt};
 
 #[derive(Clone, Copy, PartialEq, Eq)]
 pub enum CmpOp {
@@ -72,10 +72,20 @@ pub enum UnaryOp {
     Sign,
 }
 
+/// Unary op that takes one extra scalar argument shared by every element of the input tensor,
+/// e.g. the exponent of `Powf`.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum UnaryScalarOp {
+    Elu,
+    Powf,
+}
+
 #[derive(Clone)]
 pub enum Op {
     Binary(Tensor, Tensor, BinaryOp),
     Unary(Tensor, UnaryOp),
+    // NOTE(review): presumably (input, scalar, op) — confirm operand order against callers.
+    UnaryScalar(Tensor, Tensor, UnaryScalarOp),
     Cmp(Tensor, CmpOp),
     // The third argument is the reduced shape with `keepdim=true`.
     Reduce(Tensor, ReduceOp, Vec),
@@ -164,8 +174,6 @@ pub enum Op {
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
     Permute(Tensor, Vec),
-    Elu(Tensor, f64),
-    Powf(Tensor, f64),
     CustomOp1(
         Tensor,
         std::sync::Arc>,
@@ -241,6 +249,35 @@ pub trait BinaryOpT {
     fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
 }
 
+/// Per-dtype element-wise kernels for a [`UnaryScalarOp`]. Each scalar function receives the
+/// tensor element as `v1` and the tensor-wide scalar argument as `v2`.
+pub trait UnaryScalarOpT {
+    const NAME: &'static str;
+    const KERNEL: &'static str;
+    const V: Self;
+    fn bf16(v1: bf16, v2: bf16) -> bf16;
+    fn f16(v1: f16, v2: f16) -> f16;
+    fn f32(v1: f32, v2: f32) -> f32;
+    fn f64(v1: f64, v2: f64) -> f64;
+    fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3;
+    fn u8(v1: u8, v2: u8) -> u8;
+    fn u32(v1: u32, v2: u32) -> u32;
+    fn i64(v1: i64, v2: i64) -> i64;
+
+    // There is no very good way to represent optional functions in traits so we go for an explicit
+    // boolean flag to mark the function as existing.
+    const BF16_VEC: bool = false;
+    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16], _: bf16) {}
+    const F16_VEC: bool = false;
+    fn f16_vec(_xs: &[f16], _ys: &mut [f16], _: f16) {}
+    const F8E4M3_VEC: bool = false;
+    fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3], _: F8E4M3) {}
+    const F32_VEC: bool = false;
+    fn f32_vec(_xs: &[f32], _ys: &mut [f32], _: f32) {}
+    const F64_VEC: bool = false;
+    fn f64_vec(_xs: &[f64], _ys: &mut [f64], _: f64) {}
+}
+
 pub(crate) struct Add;
 pub(crate) struct Div;
 pub(crate) struct Mul;
@@ -266,6 +303,8 @@ pub(crate) struct Floor;
 pub(crate) struct Ceil;
 pub(crate) struct Round;
 pub(crate) struct Sign;
+pub(crate) struct Elu;
+pub(crate) struct Powf;
 
 macro_rules! bin_op {
     ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
@@ -929,6 +968,51 @@ impl UnaryOpT for Relu {
     }
 }
 
+/// Element-wise power. Float dtypes use `powf`; integer dtypes use `pow`, whose overflow
+/// behavior follows the build's overflow-checks setting.
+impl UnaryScalarOpT for Powf {
+    const NAME: &'static str = "powf";
+    const KERNEL: &'static str = "upowf";
+    const V: Self = Powf;
+    #[inline(always)]
+    fn bf16(v: bf16, exponent: bf16) -> bf16 {
+        v.powf(exponent)
+    }
+    #[inline(always)]
+    fn f16(v: f16, exponent: f16) -> f16 {
+        v.powf(exponent)
+    }
+    #[inline(always)]
+    fn f8e4m3(v: F8E4M3, exponent: F8E4M3) -> F8E4M3 {
+        v.powf(exponent)
+    }
+    #[inline(always)]
+    fn f32(v: f32, exponent: f32) -> f32 {
+        v.powf(exponent)
+    }
+    #[inline(always)]
+    fn f64(v: f64, exponent: f64) -> f64 {
+        v.powf(exponent)
+    }
+    #[inline(always)]
+    fn u8(v: u8, exponent: u8) -> u8 {
+        v.pow(u32::from(exponent))
+    }
+    #[inline(always)]
+    fn u32(v: u32, exponent: u32) -> u32 {
+        v.pow(exponent)
+    }
+    /// # Panics
+    ///
+    /// Panics if `exponent` is negative or does not fit in a `u32`.
+    #[inline(always)]
+    fn i64(v: i64, exponent: i64) -> i64 {
+        let exponent =
+            u32::try_from(exponent).expect("exponent must be non-negative and fit in u32");
+        v.pow(exponent)
+    }
+}
+
 /// `BackpropOp` is a wrapper around `Option`. The main goal is to ensure that dependencies are
 /// properly checked when creating a new value
 #[derive(Clone)]