From 7e2d3d0ad24c7597376fe2dd3b8adb413c411c80 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Wed, 12 Apr 2023 03:05:33 -0600 Subject: [PATCH] Add std::ops traits to Tensor. --- ggml/src/lib.rs | 138 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 6d466247..86a304fd 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -588,3 +588,141 @@ fn i32_to_usize(val: i32) -> usize { fn i64_to_usize(val: i64) -> usize { usize::try_from(val).unwrap() } + +mod map_ops { + use super::{Context, Tensor}; + use std::{os::raw::c_int, slice}; + + unsafe extern "C" fn sub_fun(n: c_int, dst: *mut f32, src0: *mut f32, src1: *mut f32) { + let n = n as usize; + let dst = slice::from_raw_parts_mut(dst, n); + let src0 = slice::from_raw_parts(src0, n); + let src1 = slice::from_raw_parts(src1, n); + + dst.iter_mut() + .zip(src0.iter()) + .zip(src1.iter()) + .for_each(|((dstel, src0el), src1el)| { + *dstel = *src0el - *src1el; + }); + } + + unsafe extern "C" fn div_fun(n: c_int, dst: *mut f32, src0: *mut f32, src1: *mut f32) { + let n = n as usize; + let dst = slice::from_raw_parts_mut(dst, n); + let src0 = slice::from_raw_parts(src0, n); + let src1 = slice::from_raw_parts(src1, n); + + dst.iter_mut() + .zip(src0.iter()) + .zip(src1.iter()) + .for_each(|((dstel, src0el), src1el)| { + *dstel = *src0el / *src1el; + }); + } + + pub fn sub(ctx: &Context, tensor1: &Tensor, tensor2: &Tensor) -> Tensor { + unsafe { ctx.op_map_binary(tensor1, tensor2, sub_fun) } + } + + pub fn div(ctx: &Context, tensor1: &Tensor, tensor2: &Tensor) -> Tensor { + unsafe { ctx.op_map_binary(tensor1, tensor2, div_fun) } + } +} + +use std::ops::{Add, BitXor, Div, Mul, Sub}; + +impl Add for &Tensor { + type Output = Tensor; + + fn add(self, rhs: Self) -> Self::Output { + Context { + ptr: self.ctx.upgrade().expect("Couldn't get context!"), + } + .op_add(self, rhs) + } +} + +impl Add for Tensor { + type Output = Tensor; + + fn add(self, rhs: Self) -> Self::Output { + &self + &rhs + } +} + +impl Sub for &Tensor { + type Output = Tensor; + + fn sub(self, rhs: Self) -> Self::Output { + let ctx = Context { + ptr: self.ctx.upgrade().expect("Couldn't get context!"), + }; + map_ops::sub(&ctx, self, rhs) + } +} + +impl Sub for Tensor { + type Output = Tensor; + + fn sub(self, rhs: Self) -> Self::Output { + &self - &rhs + } +} + +impl Mul for &Tensor { + type Output = Tensor; + + fn mul(self, rhs: Self) -> Self::Output { + Context { + ptr: self.ctx.upgrade().expect("Couldn't get context!"), + } + .op_mul(self, rhs) + } +} + +impl Mul for Tensor { + type Output = Tensor; + + fn mul(self, rhs: Self) -> Self::Output { + &self * &rhs + } +} + +impl Div for &Tensor { + type Output = Tensor; + + fn div(self, rhs: Self) -> Self::Output { + let ctx = Context { + ptr: self.ctx.upgrade().expect("Couldn't get context!"), + }; + map_ops::div(&ctx, self, rhs) + } +} + +impl Div for Tensor { + type Output = Tensor; + + fn div(self, rhs: Self) -> Self::Output { + &self / &rhs + } +} + +impl BitXor for &Tensor { + type Output = Tensor; + + fn bitxor(self, rhs: Self) -> Self::Output { + Context { + ptr: self.ctx.upgrade().expect("Couldn't get context!"), + } + .op_mul_mat(self, rhs) + } +} + +impl BitXor for Tensor { + type Output = Tensor; + + fn bitxor(self, rhs: Self) -> Self::Output { + &self ^ &rhs + } +}