Skip to content

Commit

Permalink
Add std::ops traits to Tensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
KerfuffleV2 committed Apr 12, 2023
1 parent 825a33b commit 7e2d3d0
Showing 1 changed file with 138 additions and 0 deletions.
138 changes: 138 additions & 0 deletions ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,3 +588,141 @@ fn i32_to_usize(val: i32) -> usize {
fn i64_to_usize(val: i64) -> usize {
usize::try_from(val).unwrap()
}

mod map_ops {
use super::{Context, Tensor};
use std::{os::raw::c_int, slice};

unsafe extern "C" fn sub_fun(n: c_int, dst: *mut f32, src0: *mut f32, src1: *mut f32) {
let n = n as usize;
let dst = slice::from_raw_parts_mut(dst, n);
let src0 = slice::from_raw_parts(src0, n);
let src1 = slice::from_raw_parts(src1, n);

dst.iter_mut()
.zip(src0.iter())
.zip(src1.iter())
.for_each(|((dstel, src0el), src1el)| {
*dstel = *src0el - *src1el;
});
}

unsafe extern "C" fn div_fun(n: c_int, dst: *mut f32, src0: *mut f32, src1: *mut f32) {
let n = n as usize;
let dst = slice::from_raw_parts_mut(dst, n);
let src0 = slice::from_raw_parts(src0, n);
let src1 = slice::from_raw_parts(src1, n);

dst.iter_mut()
.zip(src0.iter())
.zip(src1.iter())
.for_each(|((dstel, src0el), src1el)| {
*dstel = *src0el / *src1el;
});
}

pub fn sub(ctx: &Context, tensor1: &Tensor, tensor2: &Tensor) -> Tensor {
unsafe { ctx.op_map_binary(tensor1, tensor2, sub_fun) }
}

pub fn div(ctx: &Context, tensor1: &Tensor, tensor2: &Tensor) -> Tensor {
unsafe { ctx.op_map_binary(tensor1, tensor2, div_fun) }
}
}

use std::ops::{Add, BitXor, Div, Mul, Sub};

impl Add for &Tensor {
type Output = Tensor;

fn add(self, rhs: Self) -> Self::Output {
Context {
ptr: self.ctx.upgrade().expect("Couldn't get context!"),
}
.op_add(self, rhs)
}
}

impl Add for Tensor {
type Output = Tensor;

fn add(self, rhs: Self) -> Self::Output {
&self + &rhs
}
}

impl Sub for &Tensor {
type Output = Tensor;

fn sub(self, rhs: Self) -> Self::Output {
let ctx = Context {
ptr: self.ctx.upgrade().expect("Couldn't get context!"),
};
map_ops::sub(&ctx, self, rhs)
}
}

impl Sub for Tensor {
type Output = Tensor;

fn sub(self, rhs: Self) -> Self::Output {
&self - &rhs
}
}

impl Mul for &Tensor {
type Output = Tensor;

fn mul(self, rhs: Self) -> Self::Output {
Context {
ptr: self.ctx.upgrade().expect("Couldn't get context!"),
}
.op_mul(self, rhs)
}
}

impl Mul for Tensor {
type Output = Tensor;

fn mul(self, rhs: Self) -> Self::Output {
&self * &rhs
}
}

impl Div for &Tensor {
type Output = Tensor;

fn div(self, rhs: Self) -> Self::Output {
let ctx = Context {
ptr: self.ctx.upgrade().expect("Couldn't get context!"),
};
map_ops::div(&ctx, self, rhs)
}
}

impl Div for Tensor {
type Output = Tensor;

fn div(self, rhs: Self) -> Self::Output {
&self / &rhs
}
}

impl BitXor for &Tensor {
type Output = Tensor;

fn bitxor(self, rhs: Self) -> Self::Output {
Context {
ptr: self.ctx.upgrade().expect("Couldn't get context!"),
}
.op_mul_mat(self, rhs)
}
}

impl BitXor for Tensor {
type Output = Tensor;

fn bitxor(self, rhs: Self) -> Self::Output {
&self ^ &rhs
}
}

0 comments on commit 7e2d3d0

Please sign in to comment.