Skip to content
128 changes: 127 additions & 1 deletion src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ where
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isnt possible.
/// **Panics** if broadcasting isn't possible.
#[track_caller]
pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
where
Expand Down Expand Up @@ -1067,3 +1067,129 @@ mod blas_tests
}
}
}

impl<A, S, S2> Dot<ArrayBase<S2, IxDyn>> for ArrayBase<S, IxDyn>
where
S: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
type Output = Array<A, IxDyn>;

fn dot(&self, rhs: &ArrayBase<S2, IxDyn>) -> Self::Output {
match (self.ndim(), rhs.ndim()) {
(1, 1) => {
// Vector-vector dot product
if self.len() != rhs.len() {
panic!("Vector lengths must match for dot product");
}
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
ArrayD::from_elem(vec![], result)
}
(2, 2) => {
// Matrix-matrix multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(2, 1) => {
// Matrix-vector multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(1, 2) => {
// Vector-matrix multiplication
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}

#[cfg(test)]
mod arrayd_dot_tests {
use super::*;
use crate::ArrayD;

#[test]
fn test_arrayd_dot_2d() {
// Test case from the original issue
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();

let result = mat1.dot(&mat2);

// Verify the result is correct
assert_eq!(result.ndim(), 2);
assert_eq!(result.shape(), &[3, 3]);

// Compare with Array2 implementation
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
let expected = mat1_2d.dot(&mat2_2d);

assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
}

#[test]
fn test_arrayd_dot_1d() {
// Test 1D array dot product
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();

let result = vec1.dot(&vec2);

// Verify scalar result
assert_eq!(result.ndim(), 0);
assert_eq!(result.shape(), &[]);
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
}

#[test]
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
fn test_arrayd_dot_3d() {
// Test that 3D arrays are not supported
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
fn test_arrayd_dot_incompatible_dims() {
// Test arrays with incompatible dimensions
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
fn test_arrayd_dot_matrix_vector() {
// Test matrix-vector multiplication
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();

let result = mat.dot(&vec);

// Verify result
assert_eq!(result.ndim(), 1);
assert_eq!(result.shape(), &[3]);

// Compare with Array2 implementation
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
let expected = mat_2d.dot(&vec_1d);

assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
}
}